transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
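A comparison like the one below can be reproduced locally. Here is a minimal sketch, assuming both wheels have already been downloaded (e.g. with `pip download transformers==5.0.0rc2 --no-deps` and `pip download transformers==5.1.0 --no-deps`); the wheel filenames are illustrative, and this only recovers the added/removed/shared file sets, not the per-file +/- line counts, which require diffing each file's contents:

```python
import zipfile

def wheel_files(path: str) -> set[str]:
    # A wheel is a plain zip archive; namelist() returns every file inside it.
    with zipfile.ZipFile(path) as wheel:
        return set(wheel.namelist())

# Illustrative local paths, not part of the diff itself.
old = wheel_files("transformers-5.0.0rc2-py3-none-any.whl")
new = wheel_files("transformers-5.1.0-py3-none-any.whl")

print(f"added:   {len(new - old)}")
print(f"removed: {len(old - new)}")
# Files present in both versions; diff their contents for the +/- counts.
print(f"in both: {len(old & new)}")
```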
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/models/detr/modeling_detr.py
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,33 +14,35 @@
 """PyTorch DETR model."""

 import math
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Optional, Union

 import torch
-from torch import Tensor, nn
+import torch.nn as nn

 from ... import initialization as init
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...backbone_utils import load_backbone
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithCrossAttentions,
+    Seq2SeqModelOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import (
     ModelOutput,
+    TransformersKwargs,
     auto_docstring,
-    is_timm_available,
     logging,
-    requires_backends,
 )
-from ...utils.backbone_utils import load_backbone
+from ...utils.generic import can_return_tuple, check_model_inputs
 from .configuration_detr import DetrConfig


-if is_timm_available():
-    from timm import create_model
-
-
 logger = logging.get_logger(__name__)

@@ -64,7 +65,7 @@ class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
             layernorm.
     """

-    intermediate_hidden_states: Optional[torch.FloatTensor] = None
+    intermediate_hidden_states: torch.FloatTensor | None = None


 @dataclass
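
This `Optional[torch.FloatTensor]` → `torch.FloatTensor | None` rewrite repeats across the output dataclasses in the hunks below and is behavior-preserving: PEP 604 unions (Python 3.10+) compare equal to their `typing` counterparts at runtime, which is why the `typing.Optional` import could be dropped in the import hunk above. A minimal illustration (the `Example` class here is hypothetical, not from the diff):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class Example:
        old_style: Optional[float] = None  # pre-PEP 604 spelling
        new_style: float | None = None  # PEP 604 spelling, Python 3.10+


    # Both spellings denote the same type at runtime:
    assert Optional[float] == (float | None)
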
@@ -84,7 +85,7 @@ class DetrModelOutput(Seq2SeqModelOutput):
             layernorm.
     """

-    intermediate_hidden_states: Optional[torch.FloatTensor] = None
+    intermediate_hidden_states: torch.FloatTensor | None = None


 @dataclass
@@ -116,18 +117,18 @@ class DetrObjectDetectionOutput(ModelOutput):
             Sequence of hidden-states at the output of the last layer of the decoder of the model.
     """

-    loss: Optional[torch.FloatTensor] = None
-    loss_dict: Optional[dict] = None
-    logits: Optional[torch.FloatTensor] = None
-    pred_boxes: Optional[torch.FloatTensor] = None
-    auxiliary_outputs: Optional[list[dict]] = None
-    last_hidden_state: Optional[torch.FloatTensor] = None
-    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
-    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
-    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
-    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
-    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+    loss: torch.FloatTensor | None = None
+    loss_dict: dict | None = None
+    logits: torch.FloatTensor | None = None
+    pred_boxes: torch.FloatTensor | None = None
+    auxiliary_outputs: list[dict] | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    decoder_attentions: tuple[torch.FloatTensor] | None = None
+    cross_attentions: tuple[torch.FloatTensor] | None = None
+    encoder_last_hidden_state: torch.FloatTensor | None = None
+    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    encoder_attentions: tuple[torch.FloatTensor] | None = None


 @dataclass
@@ -165,23 +166,21 @@ class DetrSegmentationOutput(ModelOutput):
  Sequence of hidden-states at the output of the last layer of the decoder of the model.
  """

- loss: Optional[torch.FloatTensor] = None
- loss_dict: Optional[dict] = None
- logits: Optional[torch.FloatTensor] = None
- pred_boxes: Optional[torch.FloatTensor] = None
- pred_masks: Optional[torch.FloatTensor] = None
- auxiliary_outputs: Optional[list[dict]] = None
- last_hidden_state: Optional[torch.FloatTensor] = None
- decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
- decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
- cross_attentions: Optional[tuple[torch.FloatTensor]] = None
- encoder_last_hidden_state: Optional[torch.FloatTensor] = None
- encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
- encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
-
-
- # BELOW: utilities copied from
- # https://github.com/facebookresearch/detr/blob/master/backbone.py
+ loss: torch.FloatTensor | None = None
+ loss_dict: dict | None = None
+ logits: torch.FloatTensor | None = None
+ pred_boxes: torch.FloatTensor | None = None
+ pred_masks: torch.FloatTensor | None = None
+ auxiliary_outputs: list[dict] | None = None
+ last_hidden_state: torch.FloatTensor | None = None
+ decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+ decoder_attentions: tuple[torch.FloatTensor] | None = None
+ cross_attentions: tuple[torch.FloatTensor] | None = None
+ encoder_last_hidden_state: torch.FloatTensor | None = None
+ encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+ encoder_attentions: tuple[torch.FloatTensor] | None = None
+
+
  class DetrFrozenBatchNorm2d(nn.Module):
  """
  BatchNorm2d where the batch statistics and the affine parameters are fixed.
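`DetrFrozenBatchNorm2d` itself is unchanged here (it appears as context). For orientation, a minimal sketch of the idea — batch norm whose statistics and affine parameters live in buffers, so the optimizer never touches them — not the library's exact implementation:

```python
import torch
from torch import nn


class FrozenBatchNorm2d(nn.Module):
    """Sketch: BatchNorm2d with fixed statistics and affine parameters."""

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        # Buffers, not parameters: never updated by the optimizer.
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = (x - mean) / sqrt(var + eps) * weight + bias, folded into scale/shift
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        shift = self.bias - self.running_mean * scale
        return x * scale[None, :, None, None] + shift[None, :, None, None]


print(FrozenBatchNorm2d(8)(torch.randn(2, 8, 4, 4)).shape)  # torch.Size([2, 8, 4, 4])
```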
@@ -258,47 +257,25 @@ class DetrConvEncoder(nn.Module):

  self.config = config

- # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
- if config.use_timm_backbone:
- # We default to values which were previously hard-coded. This enables configurability from the config
- # using backbone arguments, while keeping the default behavior the same.
- requires_backends(self, ["timm"])
- kwargs = getattr(config, "backbone_kwargs", {})
- kwargs = {} if kwargs is None else kwargs.copy()
- out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
- num_channels = kwargs.pop("in_chans", config.num_channels)
- if config.dilation:
- kwargs["output_stride"] = kwargs.get("output_stride", 16)
- backbone = create_model(
- config.backbone,
- pretrained=config.use_pretrained_backbone,
- features_only=True,
- out_indices=out_indices,
- in_chans=num_channels,
- **kwargs,
- )
- else:
- backbone = load_backbone(config)
+ backbone = load_backbone(config)
+ self.intermediate_channel_sizes = backbone.channels

  # replace batch norm by frozen batch norm
  with torch.no_grad():
  replace_batch_norm(backbone)
- self.model = backbone
- self.intermediate_channel_sizes = (
- self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
- )

- backbone_model_type = None
- if config.backbone is not None:
- backbone_model_type = config.backbone
- elif config.backbone_config is not None:
- backbone_model_type = config.backbone_config.model_type
- else:
- raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
+ # We used to load with timm library directly instead of the AutoBackbone API
+ # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
+ is_timm_model = False
+ if hasattr(backbone, "_backbone"):
+ backbone = backbone._backbone
+ is_timm_model = True
+ self.model = backbone

+ backbone_model_type = config.backbone_config.model_type
  if "resnet" in backbone_model_type:
  for name, parameter in self.model.named_parameters():
- if config.use_timm_backbone:
+ if is_timm_model:
  if "layer2" not in name and "layer3" not in name and "layer4" not in name:
  parameter.requires_grad_(False)
  else:
@@ -307,7 +284,9 @@ class DetrConvEncoder(nn.Module):

  def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
  # send pixel_values through the model to get list of feature maps
- features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
+ features = self.model(pixel_values)
+ if isinstance(features, dict):
+ features = features.feature_maps

  out = []
  for feature_map in features:
@@ -317,61 +296,55 @@ class DetrConvEncoder(nn.Module):
  return out


- class DetrConvModel(nn.Module):
- """
- This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
- """
-
- def __init__(self, conv_encoder, position_embedding):
- super().__init__()
- self.conv_encoder = conv_encoder
- self.position_embedding = position_embedding
-
- def forward(self, pixel_values, pixel_mask):
- # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
- out = self.conv_encoder(pixel_values, pixel_mask)
- pos = []
- for feature_map, mask in out:
- # position encoding
- pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
-
- return out, pos
-
-
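The rewritten `forward` above returns `(feature_map, mask)` pairs; the loop body omitted from the hunk resizes the padding mask to each feature map's resolution. A standalone sketch of that resizing step, under the assumption that the mask uses 1 for real pixels and 0 for padding (names are illustrative, not the module's API):

```python
import torch
from torch import nn


def downsample_mask(pixel_mask: torch.Tensor, feature_map: torch.Tensor) -> torch.Tensor:
    # pixel_mask: (batch, H, W) with 1 = real pixel, 0 = padding.
    # Nearest-neighbor resize to the feature map's (H', W'), back to bool.
    return nn.functional.interpolate(
        pixel_mask[None].float(), size=feature_map.shape[-2:]
    ).to(torch.bool)[0]


pixel_mask = torch.ones(2, 64, 64, dtype=torch.bool)
feature_map = torch.randn(2, 256, 8, 8)
print(downsample_mask(pixel_mask, feature_map).shape)  # torch.Size([2, 8, 8])
```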
  class DetrSinePositionEmbedding(nn.Module):
  """
  This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
  need paper, generalized to work on images.
  """

- def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
+ def __init__(
+ self,
+ num_position_features: int = 64,
+ temperature: int = 10000,
+ normalize: bool = False,
+ scale: float | None = None,
+ ):
  super().__init__()
- self.embedding_dim = embedding_dim
- self.temperature = temperature
- self.normalize = normalize
  if scale is not None and normalize is False:
  raise ValueError("normalize should be True if scale is passed")
- if scale is None:
- scale = 2 * math.pi
- self.scale = scale
+ self.num_position_features = num_position_features
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = 2 * math.pi if scale is None else scale

- def forward(self, pixel_values, pixel_mask):
- if pixel_mask is None:
- raise ValueError("No pixel mask provided")
- y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
- x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
+ @compile_compatible_method_lru_cache(maxsize=1)
+ def forward(
+ self,
+ shape: torch.Size,
+ device: torch.device | str,
+ dtype: torch.dtype,
+ mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if mask is None:
+ mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+ y_embed = mask.cumsum(1, dtype=dtype)
+ x_embed = mask.cumsum(2, dtype=dtype)
  if self.normalize:
- y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
- x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

- dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
- dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
+ dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)

  pos_x = x_embed[:, :, :, None] / dim_t
  pos_y = y_embed[:, :, :, None] / dim_t
  pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
  pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
  pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+ # expected by the encoder
+ pos = pos.flatten(2).permute(0, 2, 1)
  return pos

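For reference, a self-contained sketch of the 2D sine embedding this class computes, written in the classic DETR formulation with an all-ones mask; the class above derives the same quantities from `shape`, `device` and `dtype` alone, which is what makes the `lru_cache` on `forward` safe:

```python
import math

import torch


def sine_position_embedding(height: int, width: int, num_pos_feats: int = 64, temperature: float = 10000.0):
    # Cumulative sums give each pixel its (normalized) x/y coordinate.
    not_mask = torch.ones(1, height, width)
    y_embed = not_mask.cumsum(1) / (height + 1e-6) * (2 * math.pi)
    x_embed = not_mask.cumsum(2) / (width + 1e-6) * (2 * math.pi)

    # Geometric frequency ladder shared by sin/cos pairs.
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)

    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    # Flattened to (1, height * width, 2 * num_pos_feats), as the encoder expects.
    return torch.cat((pos_y, pos_x), dim=3).flatten(1, 2)


print(sine_position_embedding(8, 8).shape)  # torch.Size([1, 64, 128])
```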
@@ -385,207 +358,260 @@ class DetrLearnedPositionEmbedding(nn.Module):
  self.row_embeddings = nn.Embedding(50, embedding_dim)
  self.column_embeddings = nn.Embedding(50, embedding_dim)

- def forward(self, pixel_values, pixel_mask=None):
- height, width = pixel_values.shape[-2:]
- width_values = torch.arange(width, device=pixel_values.device)
- height_values = torch.arange(height, device=pixel_values.device)
+ @compile_compatible_method_lru_cache(maxsize=1)
+ def forward(
+ self,
+ shape: torch.Size,
+ device: torch.device | str,
+ dtype: torch.dtype,
+ mask: torch.Tensor | None = None,
+ ):
+ height, width = shape[-2:]
+ width_values = torch.arange(width, device=device)
+ height_values = torch.arange(height, device=device)
  x_emb = self.column_embeddings(width_values)
  y_emb = self.row_embeddings(height_values)
  pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
  pos = pos.permute(2, 0, 1)
  pos = pos.unsqueeze(0)
- pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
+ pos = pos.repeat(shape[0], 1, 1, 1)
+ # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+ # expected by the encoder
+ pos = pos.flatten(2).permute(0, 2, 1)
  return pos


- def build_position_encoding(config):
- n_steps = config.d_model // 2
- if config.position_embedding_type == "sine":
- # TODO find a better way of exposing other arguments
- position_embedding = DetrSinePositionEmbedding(n_steps, normalize=True)
- elif config.position_embedding_type == "learned":
- position_embedding = DetrLearnedPositionEmbedding(n_steps)
- else:
- raise ValueError(f"Not supported {config.position_embedding_type}")
+ # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
+ def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float | None = None,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

- return position_embedding
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+ attn_weights = attn_weights + attention_mask

+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

- class DetrAttention(nn.Module):
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
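A self-contained sketch of the tensor contract `eager_attention_forward` implements: queries/keys/values come in as `(batch, heads, seq, head_dim)`, scores are scaled dot products, and the output is transposed back to `(batch, seq, heads, head_dim)` before the caller flattens the head dimension:

```python
import torch
from torch import nn

batch, heads, seq, head_dim = 2, 8, 16, 32
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

scaling = head_dim**-0.5
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling  # (batch, heads, seq, seq)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value).transpose(1, 2)    # (batch, seq, heads, head_dim)

print(attn_output.shape)   # torch.Size([2, 16, 8, 32])
print(attn_weights.shape)  # torch.Size([2, 8, 16, 16])
```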
+ class DetrSelfAttention(nn.Module):
  """
- Multi-headed attention from 'Attention Is All You Need' paper.
+ Multi-headed self-attention from 'Attention Is All You Need' paper.

- Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
+ In DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
  """

  def __init__(
  self,
- embed_dim: int,
- num_heads: int,
+ config: DetrConfig,
+ hidden_size: int,
+ num_attention_heads: int,
  dropout: float = 0.0,
  bias: bool = True,
  ):
  super().__init__()
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.dropout = dropout
- self.head_dim = embed_dim // num_heads
- if self.head_dim * num_heads != self.embed_dim:
- raise ValueError(
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
- f" {num_heads})."
- )
+ self.config = config
+ self.head_dim = hidden_size // num_attention_heads
  self.scaling = self.head_dim**-0.5
+ self.attention_dropout = dropout
+ self.is_causal = False

- self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-
- def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
- return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
- def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
- return tensor if object_queries is None else tensor + object_queries
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+ self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

  def forward(
  self,
  hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- object_queries: Optional[torch.Tensor] = None,
- key_value_states: Optional[torch.Tensor] = None,
- spatial_position_embeddings: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
- batch_size, target_len, embed_dim = hidden_states.size()
-
- # add position embeddings to the hidden states before projecting to queries and keys
- if object_queries is not None:
- hidden_states_original = hidden_states
- hidden_states = self.with_pos_embed(hidden_states, object_queries)
-
- # add key-value position embeddings to the key value states
- if spatial_position_embeddings is not None:
- key_value_states_original = key_value_states
- key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
-
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- if is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
- value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
- else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
- value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
+ attention_mask: torch.Tensor | None = None,
+ position_embeddings: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Position embeddings are added to both queries and keys (but not values).
+ """
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)

- proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
- key_states = key_states.view(*proj_shape)
- value_states = value_states.view(*proj_shape)
+ query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states

- source_len = key_states.size(1)
+ query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )

- if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
- raise ValueError(
- f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
- f" {attn_weights.size()}"
- )
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )

- if attention_mask is not None:
- if attention_mask.size() != (batch_size, 1, target_len, source_len):
- raise ValueError(
- f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
- f" {attention_mask.size()}"
- )
- if attention_mask.dtype == torch.bool:
- attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
- attention_mask, -torch.inf
- )
- attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
- attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
- attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
- else:
- attn_weights_reshaped = None
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights

- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

- attn_output = torch.bmm(attn_probs, value_states)
+ class DetrCrossAttention(nn.Module):
+ """
+ Multi-headed cross-attention from 'Attention Is All You Need' paper.

- if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
+ In DETR, queries get their own position embeddings, while keys get encoder position embeddings.
+ Values don't get any position embeddings.
+ """

- attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+ def __init__(
+ self,
+ config: DetrConfig,
+ hidden_size: int,
+ num_attention_heads: int,
+ dropout: float = 0.0,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.config = config
+ self.head_dim = hidden_size // num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = dropout
+ self.is_causal = False

- attn_output = self.out_proj(attn_output)
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+ self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

- return attn_output, attn_weights_reshaped
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ position_embeddings: torch.Tensor | None = None,
+ encoder_position_embeddings: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Position embeddings logic:
+ - Queries get position_embeddings
+ - Keys get encoder_position_embeddings
+ - Values don't get any position embeddings
+ """
+ query_input_shape = hidden_states.shape[:-1]
+ query_hidden_shape = (*query_input_shape, -1, self.head_dim)
+
+ kv_input_shape = key_value_states.shape[:-1]
+ kv_hidden_shape = (*kv_input_shape, -1, self.head_dim)
+
+ query_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
+ key_input = (
+ key_value_states + encoder_position_embeddings
+ if encoder_position_embeddings is not None
+ else key_value_states
+ )

+ query_states = self.q_proj(query_input).view(query_hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(key_input).view(kv_hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(kv_hidden_shape).transpose(1, 2)

- class DetrEncoderLayer(nn.Module):
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+ class DetrMLP(nn.Module):
+ def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int):
+ super().__init__()
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
+ self.fc2 = nn.Linear(intermediate_size, hidden_size)
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.dropout = config.dropout
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ return hidden_states
+
+
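Both new attention classes share the convention that replaces the old `with_pos_embed` helper: position embeddings are added to the input of the query and key projections, never to the input of the value projection. A minimal sketch of that asymmetry:

```python
import torch
from torch import nn

hidden_size, seq = 256, 10
hidden_states = torch.randn(1, seq, hidden_size)
position_embeddings = torch.randn(1, seq, hidden_size)

q_proj = nn.Linear(hidden_size, hidden_size)
k_proj = nn.Linear(hidden_size, hidden_size)
v_proj = nn.Linear(hidden_size, hidden_size)

query_key_input = hidden_states + position_embeddings
query = q_proj(query_key_input)  # position-aware
key = k_proj(query_key_input)    # position-aware
value = v_proj(hidden_states)    # content only: no position information

print(query.shape, key.shape, value.shape)
```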
+ class DetrEncoderLayer(GradientCheckpointingLayer):
  def __init__(self, config: DetrConfig):
  super().__init__()
- self.embed_dim = config.d_model
- self.self_attn = DetrAttention(
- embed_dim=self.embed_dim,
- num_heads=config.encoder_attention_heads,
+ self.hidden_size = config.d_model
+ self.self_attn = DetrSelfAttention(
+ config=config,
+ hidden_size=self.hidden_size,
+ num_attention_heads=config.encoder_attention_heads,
  dropout=config.attention_dropout,
  )
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
  self.dropout = config.dropout
- self.activation_fn = ACT2FN[config.activation_function]
- self.activation_dropout = config.activation_dropout
- self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
- self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.mlp = DetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size)

  def forward(
  self,
  hidden_states: torch.Tensor,
  attention_mask: torch.Tensor,
- object_queries: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ):
+ spatial_position_embeddings: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
  """
  Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
  attention_mask (`torch.FloatTensor`): attention mask of size
  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
  values.
- object_queries (`torch.FloatTensor`, *optional*):
- Object queries (also called content embeddings), to be added to the hidden states.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
+ spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+ Spatial position embeddings (2D positional encodings of image locations), to be added to both
+ the queries and keys in self-attention (but not to values).
  """
  residual = hidden_states
- hidden_states, attn_weights = self.self_attn(
+ hidden_states, _ = self.self_attn(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
- object_queries=object_queries,
- output_attentions=output_attentions,
+ position_embeddings=spatial_position_embeddings,
+ **kwargs,
  )

  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -593,12 +619,7 @@ class DetrEncoderLayer(nn.Module):
  hidden_states = self.self_attn_layer_norm(hidden_states)

  residual = hidden_states
- hidden_states = self.activation_fn(self.fc1(hidden_states))
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
- hidden_states = self.fc2(hidden_states)
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
+ hidden_states = self.mlp(hidden_states)
  hidden_states = residual + hidden_states
  hidden_states = self.final_layer_norm(hidden_states)

@@ -607,78 +628,69 @@ class DetrEncoderLayer(nn.Module):
  clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (attn_weights,)
-
- return outputs
+ return hidden_states


  class DetrDecoderLayer(GradientCheckpointingLayer):
  def __init__(self, config: DetrConfig):
  super().__init__()
- self.embed_dim = config.d_model
+ self.hidden_size = config.d_model

- self.self_attn = DetrAttention(
- embed_dim=self.embed_dim,
- num_heads=config.decoder_attention_heads,
+ self.self_attn = DetrSelfAttention(
+ config=config,
+ hidden_size=self.hidden_size,
+ num_attention_heads=config.decoder_attention_heads,
  dropout=config.attention_dropout,
  )
  self.dropout = config.dropout
- self.activation_fn = ACT2FN[config.activation_function]
- self.activation_dropout = config.activation_dropout

- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = DetrAttention(
- self.embed_dim,
- config.decoder_attention_heads,
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+ self.encoder_attn = DetrCrossAttention(
+ config=config,
+ hidden_size=self.hidden_size,
+ num_attention_heads=config.decoder_attention_heads,
  dropout=config.attention_dropout,
  )
- self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
- self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+ self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size)

  def forward(
  self,
  hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- object_queries: Optional[torch.Tensor] = None,
- query_position_embeddings: Optional[torch.Tensor] = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.Tensor] = None,
- output_attentions: Optional[bool] = False,
- ):
+ attention_mask: torch.Tensor | None = None,
+ spatial_position_embeddings: torch.Tensor | None = None,
+ object_queries_position_embeddings: torch.Tensor | None = None,
+ encoder_hidden_states: torch.Tensor | None = None,
+ encoder_attention_mask: torch.Tensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
  """
  Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
  attention_mask (`torch.FloatTensor`): attention mask of size
  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
  values.
- object_queries (`torch.FloatTensor`, *optional*):
- object_queries that are added to the hidden states
- in the cross-attention layer.
- query_position_embeddings (`torch.FloatTensor`, *optional*):
- position embeddings that are added to the queries and keys
- in the self-attention layer.
+ spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+ Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only
+ in the cross-attention layer (not to values).
+ object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
+ Position embeddings for the object query slots. In self-attention, these are added to both queries
+ and keys (not values). In cross-attention, these are added to queries only (not to keys or values).
  encoder_hidden_states (`torch.FloatTensor`):
- cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
  encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
  values.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
  """
  residual = hidden_states

  # Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, _ = self.self_attn(
  hidden_states=hidden_states,
- object_queries=query_position_embeddings,
+ position_embeddings=object_queries_position_embeddings,
  attention_mask=attention_mask,
- output_attentions=output_attentions,
+ **kwargs,
  )

  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -686,17 +698,16 @@ class DetrDecoderLayer(GradientCheckpointingLayer):
  hidden_states = self.self_attn_layer_norm(hidden_states)

  # Cross-Attention Block
- cross_attn_weights = None
  if encoder_hidden_states is not None:
  residual = hidden_states

- hidden_states, cross_attn_weights = self.encoder_attn(
+ hidden_states, _ = self.encoder_attn(
  hidden_states=hidden_states,
- object_queries=query_position_embeddings,
  key_value_states=encoder_hidden_states,
  attention_mask=encoder_attention_mask,
- spatial_position_embeddings=object_queries,
- output_attentions=output_attentions,
+ position_embeddings=object_queries_position_embeddings,
+ encoder_position_embeddings=spatial_position_embeddings,
+ **kwargs,
  )

  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -705,19 +716,164 @@ class DetrDecoderLayer(GradientCheckpointingLayer):

  # Fully Connected
  residual = hidden_states
- hidden_states = self.activation_fn(self.fc1(hidden_states))
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
- hidden_states = self.fc2(hidden_states)
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = self.mlp(hidden_states)
  hidden_states = residual + hidden_states
  hidden_states = self.final_layer_norm(hidden_states)

- outputs = (hidden_states,)
+ return hidden_states
+
+
+ class DetrConvBlock(nn.Module):
+ """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""
+
+ def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+ self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
+ self.activation = ACT2FN[activation]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.activation(self.norm(self.conv(x)))
+
+
+ class DetrFPNFusionStage(nn.Module):
+ """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""
+
+ def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
+ super().__init__()
+ self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
+ self.refine = DetrConvBlock(current_channels, output_channels, activation)
+
+ def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
+ fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)
+
+ Returns:
+ Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
+ """
+ fpn_features = self.fpn_adapter(fpn_features)
+ features = nn.functional.interpolate(features, size=fpn_features.shape[-2:], mode="nearest")
+ return self.refine(fpn_features + features)
+
760
 
717
- if output_attentions:
718
- outputs += (self_attn_weights, cross_attn_weights)
761
+ class DetrMaskHeadSmallConv(nn.Module):
762
+ """
763
+ Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.
719
764
 
720
- return outputs
765
+ Combines attention maps (spatial localization) with encoder features (semantics) and progressively
766
+ upsamples through multiple scales, fusing with FPN features for high-resolution detail.
767
+ """
768
+
769
+ def __init__(
770
+ self,
771
+ input_channels: int,
772
+ fpn_channels: list[int],
773
+ hidden_size: int,
774
+ activation_function: str = "relu",
775
+ ):
776
+ super().__init__()
777
+ if input_channels % 8 != 0:
778
+ raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")
779
+
780
+ self.conv1 = DetrConvBlock(input_channels, input_channels, activation_function)
781
+ self.conv2 = DetrConvBlock(input_channels, hidden_size // 2, activation_function)
782
+
783
+ # Progressive channel reduction: /2 -> /4 -> /8 -> /16
784
+ self.fpn_stages = nn.ModuleList(
785
+ [
786
+ DetrFPNFusionStage(fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function),
787
+ DetrFPNFusionStage(fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function),
788
+ DetrFPNFusionStage(fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function),
789
+ ]
790
+ )
791
+
792
+ self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)
793
+
794
+ def forward(
795
+ self,
796
+ features: torch.Tensor,
797
+ attention_masks: torch.Tensor,
798
+ fpn_features: list[torch.Tensor],
799
+ ) -> torch.Tensor:
800
+ """
801
+ Args:
802
+ features: Encoder output features, shape (batch_size, hidden_size, H, W)
803
+ attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
804
+ fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)
805
+
806
+ Returns:
807
+ Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
808
+ """
809
+ num_queries = attention_masks.shape[1]
810
+
811
+ # Expand to (batch_size * num_queries) dimension
812
+ features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
813
+ attention_masks = attention_masks.flatten(0, 1)
814
+ fpn_features = [
815
+ fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
816
+ ]
817
+
818
+ hidden_states = torch.cat([features, attention_masks], dim=1)
819
+ hidden_states = self.conv1(hidden_states)
820
+ hidden_states = self.conv2(hidden_states)
821
+
822
+ for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
823
+ hidden_states = fpn_stage(hidden_states, fpn_feat)
824
+
825
+ return self.output_conv(hidden_states)
826
+
827
+
828
+ class DetrMHAttentionMap(nn.Module):
829
+ """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
830
+
831
+ def __init__(
832
+ self,
833
+ hidden_size: int,
834
+ num_attention_heads: int,
835
+ dropout: float = 0.0,
836
+ bias: bool = True,
837
+ ):
838
+ super().__init__()
839
+ self.head_dim = hidden_size // num_attention_heads
840
+ self.scaling = self.head_dim**-0.5
841
+ self.attention_dropout = dropout
842
+
843
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
844
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
845
+
846
+ def forward(
847
+ self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
848
+ ):
849
+ query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
850
+ key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
851
+
852
+ query_states = self.q_proj(query_states).view(query_hidden_shape)
853
+ key_states = nn.functional.conv2d(
854
+ key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
855
+ ).view(key_hidden_shape)
856
+
857
+ batch_size, num_queries, num_heads, head_dim = query_states.shape
858
+ _, _, _, height, width = key_states.shape
859
+ query_shape = (batch_size * num_heads, num_queries, head_dim)
860
+ key_shape = (batch_size * num_heads, height * width, head_dim)
861
+ attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
862
+
863
+ query = query_states.transpose(1, 2).contiguous().view(query_shape)
864
+ key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
865
+
866
+ attn_weights = (
867
+ (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
868
+ )
869
+
870
+ if attention_mask is not None:
871
+ attn_weights = attn_weights + attention_mask
872
+
873
+ attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
874
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
875
+
876
+ return attn_weights
721
877
 
722
878
 
723
879
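A rough sketch of what this attention-map head computes: per-head scores between query vectors and every spatial location of a feature map, jointly softmaxed and returned without any value projection (shapes are illustrative):

```python
import torch
from torch import nn

hidden_size, heads, head_dim = 256, 8, 32
queries = torch.randn(2, 100, hidden_size)       # (batch, num_queries, hidden)
feature_map = torch.randn(2, hidden_size, 8, 8)  # encoder features

q = nn.Linear(hidden_size, hidden_size)(queries).view(2, 100, heads, head_dim)
k = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)(feature_map).view(2, heads, head_dim, 8, 8)

# One score per (batch, query, head, y, x), then softmax over heads and positions.
scores = torch.einsum("bqnc,bnchw->bqnhw", q * head_dim**-0.5, k)
weights = scores.flatten(2).softmax(-1).view(scores.shape)
print(weights.shape)  # torch.Size([2, 100, 8, 8, 8])
```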
  @auto_docstring
@@ -727,21 +883,36 @@ class DetrPreTrainedModel(PreTrainedModel):
  main_input_name = "pixel_values"
  input_modalities = ("image",)
  _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
+ supports_gradient_checkpointing = True
+ _supports_sdpa = True
+ _supports_flash_attn = True
+ _supports_attention_backend = True
+ _supports_flex_attn = True  # Uses create_bidirectional_masks for attention masking
+ _keys_to_ignore_on_load_unexpected = [
+ r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
+ ]

  @torch.no_grad()
  def _init_weights(self, module):
  std = self.config.init_std
  xavier_std = self.config.init_xavier_std

- if isinstance(module, DetrMHAttentionMap):
- init.zeros_(module.k_linear.bias)
- init.zeros_(module.q_linear.bias)
- init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
- init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
+ if isinstance(module, DetrMaskHeadSmallConv):
+ # DetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
+ for m in module.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_uniform_(m.weight, a=1)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(module, DetrMHAttentionMap):
+ init.zeros_(module.k_proj.bias)
+ init.zeros_(module.q_proj.bias)
+ init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
+ init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
  elif isinstance(module, DetrLearnedPositionEmbedding):
  init.uniform_(module.row_embeddings.weight)
  init.uniform_(module.column_embeddings.weight)
- if isinstance(module, (nn.Linear, nn.Conv2d)):
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
  init.normal_(module.weight, mean=0.0, std=std)
  if module.bias is not None:
  init.zeros_(module.bias)
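The block above is the standard `_init_weights` pattern: inspect each submodule's type and apply the matching initializer. A minimal sketch of the same dispatch (the 0.02 std is illustrative, standing in for `config.init_std`):

```python
import torch
from torch import nn
from torch.nn import init

std = 0.02  # stands in for config.init_std


def init_weights(module: nn.Module) -> None:
    # Type-based dispatch, as in _init_weights above.
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            init.zeros_(module.bias)


model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.Conv2d(8, 4, kernel_size=1))
model.apply(init_weights)  # nn.Module.apply visits every submodule recursively
print(round(model[0].weight.std().item(), 3))  # close to 0.02
```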
@@ -757,47 +928,36 @@ class DetrPreTrainedModel(PreTrainedModel):

  class DetrEncoder(DetrPreTrainedModel):
  """
- Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
- [`DetrEncoderLayer`].
-
- The encoder updates the flattened feature map through multiple self-attention layers.
-
- Small tweak for DETR:
-
- - object_queries are added to the forward pass.
+ Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
+ [`DetrEncoderLayer`] modules.

  Args:
- config: DetrConfig
+ config (`DetrConfig`): Model configuration object.
  """

+ _can_record_outputs = {"hidden_states": DetrEncoderLayer, "attentions": DetrSelfAttention}
+
  def __init__(self, config: DetrConfig):
  super().__init__(config)

  self.dropout = config.dropout
- self.layerdrop = config.encoder_layerdrop
-
  self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)])

- # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default
-
  # Initialize weights and apply final processing
  self.post_init()

+ @check_model_inputs()
  def forward(
  self,
  inputs_embeds=None,
  attention_mask=None,
- object_queries=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- **kwargs,
- ):
+ spatial_position_embeddings=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutput:
  r"""
  Args:
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
-
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:

@@ -805,112 +965,67 @@ class DetrEncoder(DetrPreTrainedModel):
  - 0 for pixel features that are padding (i.e. **masked**).

  [What are attention masks?](../glossary#attention-mask)
-
- object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Object queries that are added to the queries in each self-attention layer.
-
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
  """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
  hidden_states = inputs_embeds
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

  # expand attention_mask
  if attention_mask is not None:
  # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
-
- encoder_states = () if output_hidden_states else None
- all_attentions = () if output_attentions else None
- for i, encoder_layer in enumerate(self.layers):
- if output_hidden_states:
- encoder_states = encoder_states + (hidden_states,)
- # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
- to_drop = False
- if self.training:
- dropout_probability = torch.rand([])
- if dropout_probability < self.layerdrop:  # skip the layer
- to_drop = True
-
- if to_drop:
- layer_outputs = (None, None)
- else:
- # we add object_queries as extra input to the encoder_layer
- layer_outputs = encoder_layer(
- hidden_states,
- attention_mask,
- object_queries=object_queries,
- output_attentions=output_attentions,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
-
- if output_hidden_states:
- encoder_states = encoder_states + (hidden_states,)
-
- if not return_dict:
- return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
- return BaseModelOutput(
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
- )
-
+ attention_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ )

- class DetrDecoder(DetrPreTrainedModel):
- """
- Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].
+ for encoder_layer in self.layers:
+ # we add spatial_position_embeddings as extra input to the encoder_layer
+ hidden_states = encoder_layer(
+ hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
+ )

- The decoder updates the query embeddings through multiple self-attention and cross-attention layers.
+ return BaseModelOutput(last_hidden_state=hidden_states)

- Some small tweaks for DETR:

- - object_queries and query_position_embeddings are added to the forward pass.
- - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
+ class DetrDecoder(DetrPreTrainedModel):
+ """
+ Transformer decoder that refines a set of object queries. It is composed of a stack of [`DetrDecoderLayer`] modules,
+ which apply self-attention to the queries and cross-attention to the encoder's outputs.

  Args:
- config: DetrConfig
+ config (`DetrConfig`): Model configuration object.
  """

+ _can_record_outputs = {
+ "hidden_states": DetrDecoderLayer,
+ "attentions": DetrSelfAttention,
+ "cross_attentions": DetrCrossAttention,
+ }
+
  def __init__(self, config: DetrConfig):
  super().__init__(config)
  self.dropout = config.dropout
- self.layerdrop = config.decoder_layerdrop

  self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
  # in DETR, the decoder uses layernorm after the last decoder layer output
  self.layernorm = nn.LayerNorm(config.d_model)

- self.gradient_checkpointing = False
  # Initialize weights and apply final processing
  self.post_init()

+ @check_model_inputs()
  def forward(
  self,
  inputs_embeds=None,
  attention_mask=None,
  encoder_hidden_states=None,
  encoder_attention_mask=None,
- object_queries=None,
- query_position_embeddings=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- **kwargs,
- ):
+ spatial_position_embeddings=None,
+ object_queries_position_embeddings=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> DetrDecoderOutput:
  r"""
  Args:
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -933,108 +1048,62 @@ class DetrDecoder(DetrPreTrainedModel):
  - 1 for pixels that are real (i.e. **not masked**),
  - 0 for pixels that are padding (i.e. **masked**).

- object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Object queries that are added to the queries and keys in each cross-attention layer.
- query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
- , *optional*): Position embeddings that are added to the values and keys in each self-attention layer.
-
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer.
+ object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+ Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.
  """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict

  if inputs_embeds is not None:
  hidden_states = inputs_embeds
- input_shape = inputs_embeds.size()[:-1]
-
- combined_attention_mask = None

- if attention_mask is not None and combined_attention_mask is not None:
- # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
- combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask(
- attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ # expand decoder attention mask (for self-attention on object queries)
+ if attention_mask is not None:
+ # [batch_size, num_queries] -> [batch_size, 1, num_queries, num_queries]
+ attention_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
  )

- # expand encoder attention mask
+ # expand encoder attention mask (for cross-attention on encoder outputs)
  if encoder_hidden_states is not None and encoder_attention_mask is not None:
  # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ encoder_attention_mask = create_bidirectional_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=encoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
  )

  # optional intermediate hidden states
  intermediate = () if self.config.auxiliary_loss else None

  # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

  for idx, decoder_layer in enumerate(self.layers):
- # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- if self.training:
- dropout_probability = torch.rand([])
- if dropout_probability < self.layerdrop:
- continue
-
- layer_outputs = decoder_layer(
+ hidden_states = decoder_layer(
  hidden_states,
- combined_attention_mask,
- object_queries,
- query_position_embeddings,
+ attention_mask,
+ spatial_position_embeddings,
+ object_queries_position_embeddings,
  encoder_hidden_states,  # as a positional argument for gradient checkpointing
  encoder_attention_mask=encoder_attention_mask,
- output_attentions=output_attentions,
+ **kwargs,
  )

- hidden_states = layer_outputs[0]
-
  if self.config.auxiliary_loss:
  hidden_states = self.layernorm(hidden_states)
  intermediate += (hidden_states,)

- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- if encoder_hidden_states is not None:
- all_cross_attentions += (layer_outputs[2],)
-
  # finally, apply layernorm
  hidden_states = self.layernorm(hidden_states)

- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
  # stack intermediate decoder activations
  if self.config.auxiliary_loss:
  intermediate = torch.stack(intermediate)

- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
- if v is not None
- )
- return DetrDecoderOutput(
- last_hidden_state=hidden_states,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- cross_attentions=all_cross_attentions,
- intermediate_hidden_states=intermediate,
- )
+ return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate)


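`create_bidirectional_mask` is an internal transformers helper, so its exact signature is not shown here; as a rough sketch of the expansion it takes over from `_prepare_4d_attention_mask`, a `(batch, seq)` padding mask becomes a `(batch, 1, tgt_len, src_len)` additive float mask with large negative values where keys are padding:

```python
import torch


def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None) -> torch.Tensor:
    # mask: (batch, src_len) with 1 = attend, 0 = padding.
    batch_size, src_len = mask.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
    # 0 where attending, dtype-min where padded: adding this to attention scores masks padded keys.
    return (1.0 - expanded) * torch.finfo(dtype).min


mask = torch.tensor([[1, 1, 0]])  # last key position is padding
print(expand_mask(mask, torch.float32, tgt_len=2).shape)  # torch.Size([1, 1, 2, 3])
```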
  @auto_docstring(
@@ -1047,15 +1116,16 @@ class DetrModel(DetrPreTrainedModel):
     def __init__(self, config: DetrConfig):
         super().__init__(config)

-        # Create backbone + positional encoding
-        backbone = DetrConvEncoder(config)
-        object_queries = build_position_encoding(config)
-        self.backbone = DetrConvModel(backbone, object_queries)
-
-        # Create projection layer
-        self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
+        self.backbone = DetrConvEncoder(config)

+        if config.position_embedding_type == "sine":
+            self.position_embedding = DetrSinePositionEmbedding(config.d_model // 2, normalize=True)
+        elif config.position_embedding_type == "learned":
+            self.position_embedding = DetrLearnedPositionEmbedding(config.d_model // 2)
+        else:
+            raise ValueError(f"Not supported {config.position_embedding_type}")
         self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
+        self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)

         self.encoder = DetrEncoder(config)
         self.decoder = DetrDecoder(config)
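Note: `DetrModel` now instantiates its position embedding directly instead of wrapping the backbone in `DetrConvModel`. For readers unfamiliar with the sine variant selected above, here is a minimal sketch of what a 2D sine position embedding computes; the function name and signature are illustrative, not the library's exact `DetrSinePositionEmbedding` API:

```python
import torch

def sine_position_embedding(mask: torch.Tensor, embedding_dim: int = 128, temperature: float = 10000.0) -> torch.Tensor:
    """mask: (batch, height, width) bool, True where pixels are valid (not padding)."""
    y_embed = mask.cumsum(1, dtype=torch.float32)  # running row index over valid pixels
    x_embed = mask.cumsum(2, dtype=torch.float32)  # running column index over valid pixels
    # scale positions to [0, 2*pi], mirroring the normalize=True behaviour
    eps, scale = 1e-6, 2 * torch.pi
    y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale
    x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale
    dim_t = temperature ** (2 * (torch.arange(embedding_dim) // 2) / embedding_dim)
    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    # interleave sin/cos over the channel dimension
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)  # (batch, 2 * embedding_dim, H, W)

pos = sine_position_embedding(torch.ones(1, 25, 34, dtype=torch.bool), embedding_dim=128)
assert pos.shape == (1, 256, 25, 34)
```

With `embedding_dim = config.d_model // 2`, the concatenated y/x halves yield `d_model` channels, matching what the encoder consumes.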
@@ -1064,46 +1134,49 @@ class DetrModel(DetrPreTrainedModel):
         self.post_init()

     def freeze_backbone(self):
-        for name, param in self.backbone.conv_encoder.model.named_parameters():
+        for _, param in self.backbone.model.named_parameters():
             param.requires_grad_(False)

     def unfreeze_backbone(self):
-        for name, param in self.backbone.conv_encoder.model.named_parameters():
+        for _, param in self.backbone.model.named_parameters():
             param.requires_grad_(True)

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
-        pixel_values: torch.FloatTensor,
-        pixel_mask: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_outputs: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.FloatTensor], DetrModelOutput]:
+        pixel_values: torch.FloatTensor | None = None,
+        pixel_mask: torch.LongTensor | None = None,
+        decoder_attention_mask: torch.FloatTensor | None = None,
+        encoder_outputs: torch.FloatTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        decoder_inputs_embeds: torch.FloatTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.FloatTensor] | DetrModelOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-            Not used by default. Can be used to mask object queries.
+            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for queries that are **not masked**,
+            - 0 for queries that are **masked**.
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
-            can choose to directly pass a flattened representation of an image.
+            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
         decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
             Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
+            embedded representation. Useful for tasks that require custom query initialization.

         Examples:

         ```python
         >>> from transformers import AutoImageProcessor, DetrModel
         >>> from PIL import Image
-        >>> import requests
+        >>> import httpx
+        >>> from io import BytesIO

         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

         >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
         >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")
@@ -1120,79 +1193,77 @@ class DetrModel(DetrPreTrainedModel):
         >>> list(last_hidden_states.shape)
         [1, 100, 256]
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        batch_size, num_channels, height, width = pixel_values.shape
-        device = pixel_values.device
-
-        if pixel_mask is None:
-            pixel_mask = torch.ones(((batch_size, height, width)), device=device)
-
-        # First, sent pixel_values + pixel_mask through Backbone to obtain the features
-        # pixel_values should be of shape (batch_size, num_channels, height, width)
-        # pixel_mask should be of shape (batch_size, height, width)
-        features, object_queries_list = self.backbone(pixel_values, pixel_mask)
-
-        # get final feature map and downsampled mask
-        feature_map, mask = features[-1]
-
-        if mask is None:
-            raise ValueError("Backbone does not return downsampled pixel mask")
-
-        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
-        projected_feature_map = self.input_projection(feature_map)
-
-        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
-        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
-        flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
-        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
-
-        flattened_mask = mask.flatten(1)
+        if pixel_values is None and inputs_embeds is None:
+            raise ValueError("You have to specify either pixel_values or inputs_embeds")
+
+        if inputs_embeds is None:
+            batch_size, num_channels, height, width = pixel_values.shape
+            device = pixel_values.device
+
+            if pixel_mask is None:
+                pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+            vision_features = self.backbone(pixel_values, pixel_mask)
+            feature_map, mask = vision_features[-1]
+
+            # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
+            # Position embeddings are already flattened to (batch_size, sequence_length, hidden_size) format
+            projected_feature_map = self.input_projection(feature_map)
+            flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
+            spatial_position_embeddings = self.position_embedding(
+                shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
+            )
+            flattened_mask = mask.flatten(1)
+        else:
+            batch_size = inputs_embeds.shape[0]
+            device = inputs_embeds.device
+            flattened_features = inputs_embeds
+            # When using inputs_embeds, we need to infer spatial dimensions for position embeddings
+            # Assume square feature map
+            seq_len = inputs_embeds.shape[1]
+            feat_dim = int(seq_len**0.5)
+            # Create position embeddings for the inferred spatial size
+            spatial_position_embeddings = self.position_embedding(
+                shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]),
+                device=device,
+                dtype=inputs_embeds.dtype,
+            )
+            # If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten.
+            if pixel_mask is not None:
+                mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
+                flattened_mask = mask.flatten(1)
+            else:
+                # If no mask provided, assume all positions are valid
+                flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)

-        # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
-        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
-        # flattened_mask is a Tensor of shape (batch_size, height*width)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 inputs_embeds=flattened_features,
                 attention_mask=flattened_mask,
-                object_queries=object_queries,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                spatial_position_embeddings=spatial_position_embeddings,
+                **kwargs,
             )

-        # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
-        query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
-        queries = torch.zeros_like(query_position_embeddings)
+        object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
+            batch_size, 1, 1
+        )
+
+        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
+        if decoder_inputs_embeds is not None:
+            queries = decoder_inputs_embeds
+        else:
+            queries = torch.zeros_like(object_queries_position_embeddings)

         # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
         decoder_outputs = self.decoder(
             inputs_embeds=queries,
-            attention_mask=None,
-            object_queries=object_queries,
-            query_position_embeddings=query_position_embeddings,
-            encoder_hidden_states=encoder_outputs[0],
+            attention_mask=decoder_attention_mask,
+            spatial_position_embeddings=spatial_position_embeddings,
+            object_queries_position_embeddings=object_queries_position_embeddings,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
             encoder_attention_mask=flattened_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        if not return_dict:
-            return decoder_outputs + encoder_outputs
-
         return DetrModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             decoder_hidden_states=decoder_outputs.hidden_states,
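The new `inputs_embeds` path infers a square feature map (`feat_dim = int(seq_len**0.5)`), so a perfect-square sequence length is the safe choice when bypassing the backbone. A minimal usage sketch (shapes illustrative):

```python
import torch
from transformers import DetrModel

model = DetrModel.from_pretrained("facebook/detr-resnet-50")

# 625 = 25 * 25, so the inferred feature map is square (25 x 25).
inputs_embeds = torch.randn(1, 625, model.config.d_model)
outputs = model(inputs_embeds=inputs_embeds)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 100, 256])
```

Since `forward` is now wrapped with `@can_return_tuple`, callers that need the legacy tuple format can still pass `return_dict=False`, replacing the hand-rolled `if not return_dict` branches removed throughout this file.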
@@ -1205,14 +1276,11 @@ class DetrModel(DetrPreTrainedModel):
         )


-# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
 class DetrMLPPredictionHead(nn.Module):
     """
     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
     height and width of a bounding box w.r.t. an image.

-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-
     """

     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
@@ -1252,29 +1320,30 @@ class DetrForObjectDetection(DetrPreTrainedModel):
         self.post_init()

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
-        pixel_mask: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_outputs: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[list[dict]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.FloatTensor], DetrObjectDetectionOutput]:
+        pixel_mask: torch.LongTensor | None = None,
+        decoder_attention_mask: torch.FloatTensor | None = None,
+        encoder_outputs: torch.FloatTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        decoder_inputs_embeds: torch.FloatTensor | None = None,
+        labels: list[dict] | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.FloatTensor] | DetrObjectDetectionOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-            Not used by default. Can be used to mask object queries.
+            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for queries that are **not masked**,
+            - 0 for queries that are **masked**.
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
-            can choose to directly pass a flattened representation of an image.
+            can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
         decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
             Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
+            embedded representation. Useful for tasks that require custom query initialization.
         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
             following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
@@ -1287,10 +1356,12 @@ class DetrForObjectDetection(DetrPreTrainedModel):
         >>> from transformers import AutoImageProcessor, DetrForObjectDetection
         >>> import torch
         >>> from PIL import Image
-        >>> import requests
+        >>> import httpx
+        >>> from io import BytesIO

         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

         >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
         >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
@@ -1316,7 +1387,6 @@ class DetrForObjectDetection(DetrPreTrainedModel):
         Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
         Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
         ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # First, sent images through DETR base model to obtain encoder + decoder outputs
         outputs = self.model(
@@ -1326,9 +1396,7 @@ class DetrForObjectDetection(DetrPreTrainedModel):
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

         sequence_output = outputs[0]
@@ -1341,20 +1409,13 @@ class DetrForObjectDetection(DetrPreTrainedModel):
         if labels is not None:
             outputs_class, outputs_coord = None, None
             if self.config.auxiliary_loss:
-                intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
+                intermediate = outputs.intermediate_hidden_states
                 outputs_class = self.class_labels_classifier(intermediate)
                 outputs_coord = self.bbox_predictor(intermediate).sigmoid()
             loss, loss_dict, auxiliary_outputs = self.loss_function(
                 logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
             )

-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes) + auxiliary_outputs + outputs
-            else:
-                output = (logits, pred_boxes) + outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
         return DetrObjectDetectionOutput(
             loss=loss,
             loss_dict=loss_dict,
@@ -1378,6 +1439,26 @@ class DetrForObjectDetection(DetrPreTrainedModel):
     """
 )
 class DetrForSegmentation(DetrPreTrainedModel):
+    _checkpoint_conversion_mapping = {
+        "bbox_attention.q_linear": "bbox_attention.q_proj",
+        "bbox_attention.k_linear": "bbox_attention.k_proj",
+        # Mask head refactor
+        "mask_head.lay1": "mask_head.conv1.conv",
+        "mask_head.gn1": "mask_head.conv1.norm",
+        "mask_head.lay2": "mask_head.conv2.conv",
+        "mask_head.gn2": "mask_head.conv2.norm",
+        "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter",
+        "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv",
+        "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm",
+        "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
+        "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv",
+        "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm",
+        "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter",
+        "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv",
+        "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm",
+        "mask_head.out_lay": "mask_head.output_conv",
+    }
+
     def __init__(self, config: DetrConfig):
         super().__init__(config)

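`_checkpoint_conversion_mapping` lets checkpoints saved against the old module tree load into the refactored one by renaming parameter keys. The library's actual loader is more involved; a simplified sketch of just the substring renaming that the mapping above implies:

```python
import torch

def convert_old_keys(state_dict: dict[str, torch.Tensor], mapping: dict[str, str]) -> dict[str, torch.Tensor]:
    """Illustrative helper (not the library API): rewrite old checkpoint keys."""
    converted = {}
    for key, tensor in state_dict.items():
        for old, new in mapping.items():
            if old in key:
                key = key.replace(old, new)
        converted[key] = tensor
    return converted

# e.g. "mask_head.lay3.weight" -> "mask_head.fpn_stages.0.refine.conv.weight"
```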
@@ -1386,42 +1467,44 @@ class DetrForSegmentation(DetrPreTrainedModel):

         # segmentation head
         hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
-        intermediate_channel_sizes = self.detr.model.backbone.conv_encoder.intermediate_channel_sizes
+        intermediate_channel_sizes = self.detr.model.backbone.intermediate_channel_sizes

         self.mask_head = DetrMaskHeadSmallConv(
-            hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
+            input_channels=hidden_size + number_of_heads,
+            fpn_channels=intermediate_channel_sizes[::-1][-3:],
+            hidden_size=hidden_size,
+            activation_function=config.activation_function,
         )

-        self.bbox_attention = DetrMHAttentionMap(
-            hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
-        )
+        self.bbox_attention = DetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)
         # Initialize weights and apply final processing
         self.post_init()

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
-        pixel_mask: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_outputs: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[list[dict]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.FloatTensor], DetrSegmentationOutput]:
+        pixel_mask: torch.LongTensor | None = None,
+        decoder_attention_mask: torch.FloatTensor | None = None,
+        encoder_outputs: torch.FloatTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        decoder_inputs_embeds: torch.FloatTensor | None = None,
+        labels: list[dict] | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.FloatTensor] | DetrSegmentationOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-            Not used by default. Can be used to mask object queries.
+            Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for queries that are **not masked**,
+            - 0 for queries that are **masked**.
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
-            can choose to directly pass a flattened representation of an image.
+            Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
+            multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
         decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
             Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
+            embedded representation. Useful for tasks that require custom query initialization.
         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
             dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
@@ -1434,7 +1517,8 @@ class DetrForSegmentation(DetrPreTrainedModel):

         ```python
         >>> import io
-        >>> import requests
+        >>> import httpx
+        >>> from io import BytesIO
         >>> from PIL import Image
         >>> import torch
         >>> import numpy
@@ -1443,7 +1527,8 @@ class DetrForSegmentation(DetrPreTrainedModel):
         >>> from transformers.image_transforms import rgb_to_id

         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

         >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
         >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
@@ -1468,83 +1553,77 @@ class DetrForSegmentation(DetrPreTrainedModel):
         5
         ```"""

-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         batch_size, num_channels, height, width = pixel_values.shape
         device = pixel_values.device

         if pixel_mask is None:
             pixel_mask = torch.ones((batch_size, height, width), device=device)

-        # First, get list of feature maps and position embeddings
-        features, object_queries_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
+        vision_features = self.detr.model.backbone(pixel_values, pixel_mask)
+        feature_map, mask = vision_features[-1]

-        # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
-        feature_map, mask = features[-1]
-        batch_size, num_channels, height, width = feature_map.shape
+        # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
         projected_feature_map = self.detr.model.input_projection(feature_map)
-
-        # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
-        # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
         flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
-        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
-
+        spatial_position_embeddings = self.detr.model.position_embedding(
+            shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
+        )
         flattened_mask = mask.flatten(1)

-        # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
-        # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
-        # flattened_mask is a Tensor of shape (batch_size, height*width)
         if encoder_outputs is None:
             encoder_outputs = self.detr.model.encoder(
                 inputs_embeds=flattened_features,
                 attention_mask=flattened_mask,
-                object_queries=object_queries,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                spatial_position_embeddings=spatial_position_embeddings,
+                **kwargs,
             )

-        # Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
-        query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
+        object_queries_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
             batch_size, 1, 1
         )
-        queries = torch.zeros_like(query_position_embeddings)

-        # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
+        # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
+        if decoder_inputs_embeds is not None:
+            queries = decoder_inputs_embeds
+        else:
+            queries = torch.zeros_like(object_queries_position_embeddings)
+
         decoder_outputs = self.detr.model.decoder(
             inputs_embeds=queries,
-            attention_mask=None,
-            object_queries=object_queries,
-            query_position_embeddings=query_position_embeddings,
-            encoder_hidden_states=encoder_outputs[0],
+            attention_mask=decoder_attention_mask,
+            spatial_position_embeddings=spatial_position_embeddings,
+            object_queries_position_embeddings=object_queries_position_embeddings,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
             encoder_attention_mask=flattened_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

         sequence_output = decoder_outputs[0]

-        # Sixth, compute logits, pred_boxes and pred_masks
         logits = self.detr.class_labels_classifier(sequence_output)
         pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()

-        memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
-        mask = flattened_mask.view(batch_size, height, width)
+        height, width = feature_map.shape[-2:]
+        memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
+            batch_size, self.config.d_model, height, width
+        )
+        attention_mask = flattened_mask.view(batch_size, height, width)

-        # FIXME h_boxes takes the last one computed, keep this in mind
-        # important: we need to reverse the mask, since in the original implementation the mask works reversed
-        # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
-        bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)
+        if attention_mask is not None:
+            min_dtype = torch.finfo(memory.dtype).min
+            attention_mask = torch.where(
+                attention_mask.unsqueeze(1).unsqueeze(1),
+                torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
+                min_dtype,
+            )

-        seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])
+        bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)
+
+        seg_masks = self.mask_head(
+            features=projected_feature_map,
+            attention_masks=bbox_mask,
+            fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
+        )

         pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

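The refactored `bbox_attention` consumes an additive float mask (0 where attention is allowed, a large negative value elsewhere) instead of the old inverted boolean mask. A standalone sketch of that conversion, equivalent in effect to the `torch.where` above (helper name is illustrative):

```python
import torch

def to_additive_attention_mask(valid: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """valid: (batch, H, W), nonzero where pixels are real (not padding)."""
    additive = torch.zeros(valid.shape, dtype=dtype)
    additive.masked_fill_(~valid.bool(), torch.finfo(dtype).min)
    # unsqueeze twice so the mask broadcasts over (batch, num_queries, num_heads, H, W)
    return additive.unsqueeze(1).unsqueeze(1)

mask = torch.tensor([[[1, 1], [1, 0]]])  # one padded pixel
print(to_additive_attention_mask(mask, torch.float32).shape)  # torch.Size([1, 1, 1, 2, 2])
```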
@@ -1552,20 +1631,13 @@ class DetrForSegmentation(DetrPreTrainedModel):
         if labels is not None:
             outputs_class, outputs_coord = None, None
             if self.config.auxiliary_loss:
-                intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
+                intermediate = decoder_outputs.intermediate_hidden_states
                 outputs_class = self.detr.class_labels_classifier(intermediate)
                 outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
             loss, loss_dict, auxiliary_outputs = self.loss_function(
                 logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
             )

-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
-            else:
-                output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
         return DetrSegmentationOutput(
             loss=loss,
             loss_dict=loss_dict,
@@ -1583,119 +1655,6 @@ class DetrForSegmentation(DetrPreTrainedModel):
         )


-def _expand(tensor, length: int):
-    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
-
-
-# taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
-class DetrMaskHeadSmallConv(nn.Module):
-    """
-    Simple convolutional head, using group norm. Upsampling is done using a FPN approach
-    """
-
-    def __init__(self, dim, fpn_dims, context_dim):
-        super().__init__()
-
-        if dim % 8 != 0:
-            raise ValueError(
-                "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
-                " GroupNorm is set to 8"
-            )
-
-        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
-
-        self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
-        self.gn1 = nn.GroupNorm(8, dim)
-        self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
-        self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
-        self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
-        self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
-        self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
-        self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
-        self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
-        self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
-        self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
-
-        self.dim = dim
-
-        self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
-        self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
-        self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
-
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                init.kaiming_uniform_(m.weight, a=1)
-                init.constant_(m.bias, 0)
-
-    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]):
-        # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
-        # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
-        # We expand the projected feature map to match the number of heads.
-        x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
-
-        x = self.lay1(x)
-        x = self.gn1(x)
-        x = nn.functional.relu(x)
-        x = self.lay2(x)
-        x = self.gn2(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter1(fpns[0])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay3(x)
-        x = self.gn3(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter2(fpns[1])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay4(x)
-        x = self.gn4(x)
-        x = nn.functional.relu(x)
-
-        cur_fpn = self.adapter3(fpns[2])
-        if cur_fpn.size(0) != x.size(0):
-            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
-        x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
-        x = self.lay5(x)
-        x = self.gn5(x)
-        x = nn.functional.relu(x)
-
-        x = self.out_lay(x)
-        return x
-
-
-class DetrMHAttentionMap(nn.Module):
-    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
-
-    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
-        super().__init__()
-        self.num_heads = num_heads
-        self.hidden_dim = hidden_dim
-        self.dropout = nn.Dropout(dropout)
-
-        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
-        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
-
-        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
-
-    def forward(self, q, k, mask: Optional[Tensor] = None):
-        q = self.q_linear(q)
-        k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
-        queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
-        keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
-        weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
-
-        if mask is not None:
-            weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
-        weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
-        weights = self.dropout(weights)
-        return weights
-
-
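For readers tracking the removed `DetrMHAttentionMap`: its core einsum computed one attention map per head between each object query and every spatial position of the encoder memory. A shape sanity check of that contraction:

```python
import torch

# (batch, queries, heads, head_dim) x (batch, heads, head_dim, H, W)
#   -> (batch, queries, heads, H, W)
q = torch.randn(2, 100, 8, 32)
k = torch.randn(2, 8, 32, 20, 20)
weights = torch.einsum("bqnc,bnchw->bqnhw", q, k)
assert weights.shape == (2, 100, 8, 20, 20)
```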
 __all__ = [
     "DetrForObjectDetection",
     "DetrForSegmentation",