transformers 5.0.0rc3-py3-none-any.whl → 5.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -1,3 +1,9 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/deformable_detr/modular_deformable_detr.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_deformable_detr.py file directly. One of our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  # Copyright 2022 SenseTime and The HuggingFace Inc. team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,128 +17,54 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """PyTorch Deformable DETR model."""
-
  import math
  import warnings
+ from collections.abc import Callable
  from dataclasses import dataclass
- from typing import Any

  import torch
+ import torch.nn as nn
  import torch.nn.functional as F
- from torch import Tensor, nn
+ from torch import Tensor

  from ... import initialization as init
  from ...activations import ACT2FN
+ from ...backbone_utils import load_backbone
  from ...integrations import use_kernel_forward_from_hub
- from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutput
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import meshgrid
- from ...utils import (
-     ModelOutput,
-     auto_docstring,
-     is_timm_available,
-     logging,
-     requires_backends,
- )
- from ...utils.backbone_utils import load_backbone
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from ...processing_utils import Unpack
+ from ...pytorch_utils import compile_compatible_method_lru_cache, meshgrid
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check
+ from ...utils.generic import OutputRecorder, can_return_tuple, check_model_inputs
  from .configuration_deformable_detr import DeformableDetrConfig


- logger = logging.get_logger(__name__)
-
-
- if is_timm_available():
-     from timm import create_model
-
-
- logger = logging.get_logger(__name__)
-
-
- @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
- class MultiScaleDeformableAttention(nn.Module):
-     def forward(
-         self,
-         value: Tensor,
-         value_spatial_shapes: Tensor,
-         value_spatial_shapes_list: list[tuple],
-         level_start_index: Tensor,
-         sampling_locations: Tensor,
-         attention_weights: Tensor,
-         im2col_step: int,
-     ):
-         batch_size, _, num_heads, hidden_dim = value.shape
-         _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-         value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
-         sampling_grids = 2 * sampling_locations - 1
-         sampling_value_list = []
-         for level_id, (height, width) in enumerate(value_spatial_shapes_list):
-             # batch_size, height*width, num_heads, hidden_dim
-             # -> batch_size, height*width, num_heads*hidden_dim
-             # -> batch_size, num_heads*hidden_dim, height*width
-             # -> batch_size*num_heads, hidden_dim, height, width
-             value_l_ = (
-                 value_list[level_id]
-                 .flatten(2)
-                 .transpose(1, 2)
-                 .reshape(batch_size * num_heads, hidden_dim, height, width)
-             )
-             # batch_size, num_queries, num_heads, num_points, 2
-             # -> batch_size, num_heads, num_queries, num_points, 2
-             # -> batch_size*num_heads, num_queries, num_points, 2
-             sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-             # batch_size*num_heads, hidden_dim, num_queries, num_points
-             sampling_value_l_ = nn.functional.grid_sample(
-                 value_l_,
-                 sampling_grid_l_,
-                 mode="bilinear",
-                 padding_mode="zeros",
-                 align_corners=False,
-             )
-             sampling_value_list.append(sampling_value_l_)
-         # (batch_size, num_queries, num_heads, num_levels, num_points)
-         # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-         # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
-         attention_weights = attention_weights.transpose(1, 2).reshape(
-             batch_size * num_heads, 1, num_queries, num_levels * num_points
-         )
-         output = (
-             (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-             .sum(-1)
-             .view(batch_size, num_heads * hidden_dim, num_queries)
-         )
-         return output.transpose(1, 2).contiguous()
-
-
  @dataclass
  @auto_docstring(
      custom_intro="""
-     Base class for outputs of the DeformableDetrDecoder. This class adds two attributes to
-     BaseModelOutputWithCrossAttentions, namely:
-     - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
-     - a stacked tensor of intermediate reference points.
+     Base class for outputs of the DEFORMABLE_DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
+     namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
+     gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
      """
  )
- class DeformableDetrDecoderOutput(ModelOutput):
+ class DeformableDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
      r"""
-     intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
-         Stacked intermediate hidden states (output of each layer of the decoder).
-     intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
-         Stacked intermediate reference points (reference points of each layer of the decoder).
      cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
          used to compute the weighted average in the cross-attention heads.
+     intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
+         Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
+         layernorm.
+     intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
+         Stacked intermediate reference points (reference points of each layer of the decoder).
      """

-     last_hidden_state: torch.FloatTensor | None = None
      intermediate_hidden_states: torch.FloatTensor | None = None
+
      intermediate_reference_points: torch.FloatTensor | None = None
-     hidden_states: tuple[torch.FloatTensor] | None = None
-     attentions: tuple[torch.FloatTensor] | None = None
-     cross_attentions: tuple[torch.FloatTensor] | None = None


  @dataclass
@@ -198,10 +130,10 @@ class DeformableDetrObjectDetectionOutput(ModelOutput):
          Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
          and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
          `pred_boxes`) for each decoder layer.
-     init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
-         Initial reference points sent through the Transformer decoder.
      last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
          Sequence of hidden-states at the output of the last layer of the decoder of the model.
+     init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+         Initial reference points sent through the Transformer decoder.
      intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
          Stacked intermediate hidden states (output of each layer of the decoder).
      intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
@@ -219,28 +151,76 @@ class DeformableDetrObjectDetectionOutput(ModelOutput):
      logits: torch.FloatTensor | None = None
      pred_boxes: torch.FloatTensor | None = None
      auxiliary_outputs: list[dict] | None = None
-     init_reference_points: torch.FloatTensor | None = None
      last_hidden_state: torch.FloatTensor | None = None
-     intermediate_hidden_states: torch.FloatTensor | None = None
-     intermediate_reference_points: torch.FloatTensor | None = None
      decoder_hidden_states: tuple[torch.FloatTensor] | None = None
      decoder_attentions: tuple[torch.FloatTensor] | None = None
      cross_attentions: tuple[torch.FloatTensor] | None = None
      encoder_last_hidden_state: torch.FloatTensor | None = None
      encoder_hidden_states: tuple[torch.FloatTensor] | None = None
      encoder_attentions: tuple[torch.FloatTensor] | None = None
-     enc_outputs_class: Any = None
+
+     init_reference_points: torch.FloatTensor | None = None
+     intermediate_hidden_states: torch.FloatTensor | None = None
+     intermediate_reference_points: torch.FloatTensor | None = None
+     enc_outputs_class: torch.FloatTensor | None = None
      enc_outputs_coord_logits: torch.FloatTensor | None = None


- def inverse_sigmoid(x, eps=1e-5):
-     x = x.clamp(min=0, max=1)
-     x1 = x.clamp(min=eps)
-     x2 = (1 - x).clamp(min=eps)
-     return torch.log(x1 / x2)
+ @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+ class MultiScaleDeformableAttention(nn.Module):
+     def forward(
+         self,
+         value: Tensor,
+         value_spatial_shapes: Tensor,
+         value_spatial_shapes_list: list[tuple],
+         level_start_index: Tensor,
+         sampling_locations: Tensor,
+         attention_weights: Tensor,
+         im2col_step: int,
+     ):
+         batch_size, _, num_heads, hidden_dim = value.shape
+         _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+         value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
+         sampling_grids = 2 * sampling_locations - 1
+         sampling_value_list = []
+         for level_id, (height, width) in enumerate(value_spatial_shapes_list):
+             # batch_size, height*width, num_heads, hidden_dim
+             # -> batch_size, height*width, num_heads*hidden_dim
+             # -> batch_size, num_heads*hidden_dim, height*width
+             # -> batch_size*num_heads, hidden_dim, height, width
+             value_l_ = (
+                 value_list[level_id]
+                 .flatten(2)
+                 .transpose(1, 2)
+                 .reshape(batch_size * num_heads, hidden_dim, height, width)
+             )
+             # batch_size, num_queries, num_heads, num_points, 2
+             # -> batch_size, num_heads, num_queries, num_points, 2
+             # -> batch_size*num_heads, num_queries, num_points, 2
+             sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+             # batch_size*num_heads, hidden_dim, num_queries, num_points
+             sampling_value_l_ = nn.functional.grid_sample(
+                 value_l_,
+                 sampling_grid_l_,
+                 mode="bilinear",
+                 padding_mode="zeros",
+                 align_corners=False,
+             )
+             sampling_value_list.append(sampling_value_l_)
+         # (batch_size, num_queries, num_heads, num_levels, num_points)
+         # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+         # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+         attention_weights = attention_weights.transpose(1, 2).reshape(
+             batch_size * num_heads, 1, num_queries, num_levels * num_points
+         )
+         output = (
+             (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+             .sum(-1)
+             .view(batch_size, num_heads * hidden_dim, num_queries)
+         )
+         return output.transpose(1, 2).contiguous()


- # Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->DeformableDetr
  class DeformableDetrFrozenBatchNorm2d(nn.Module):
      """
      BatchNorm2d where the batch statistics and the affine parameters are fixed.
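
Note: the module-level `inverse_sigmoid` helper removed in the hunk above is the numerically stabilized logit function (the inverse of `torch.sigmoid`); where it moved to is not shown in this section. For reference, a standalone sketch of what it computed, with a round-trip check:

    import torch

    def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # logit(x) = log(x / (1 - x)), clamped so the log never sees zero
        x = x.clamp(min=0, max=1)
        numerator = x.clamp(min=eps)
        denominator = (1 - x).clamp(min=eps)
        return torch.log(numerator / denominator)

    x = torch.tensor([0.1, 0.5, 0.9])
    assert torch.allclose(torch.sigmoid(inverse_sigmoid(x)), x, atol=1e-4)
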
@@ -280,7 +260,6 @@ class DeformableDetrFrozenBatchNorm2d(nn.Module):
          return x * scale + bias


- # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DeformableDetr
  def replace_batch_norm(model):
      r"""
      Recursively replace all `torch.nn.BatchNorm2d` with `DeformableDetrFrozenBatchNorm2d`.
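
Note: the `return x * scale + bias` context line above is the whole trick behind `DeformableDetrFrozenBatchNorm2d`: with frozen statistics, BatchNorm2d collapses into a fixed per-channel affine transform. A minimal sketch of that folding, assuming the standard BatchNorm buffers and an illustrative `eps` of 1e-5 (function and argument names here are not from the shipped module):

    import torch

    def frozen_batch_norm_2d(x, weight, bias, running_mean, running_var, eps=1e-5):
        # Fold the frozen statistics into one per-channel scale and shift,
        # then broadcast over (batch, channels, height, width).
        scale = weight * (running_var + eps).rsqrt()
        shift = bias - running_mean * scale
        return x * scale[None, :, None, None] + shift[None, :, None, None]
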
@@ -318,57 +297,36 @@ class DeformableDetrConvEncoder(nn.Module):

          self.config = config

-         # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
-         if config.use_timm_backbone:
-             # We default to values which were previously hard-coded. This enables configurability from the config
-             # using backbone arguments, while keeping the default behavior the same.
-             requires_backends(self, ["timm"])
-             kwargs = getattr(config, "backbone_kwargs", {})
-             kwargs = {} if kwargs is None else kwargs.copy()
-             out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
-             num_channels = kwargs.pop("in_chans", config.num_channels)
-             if config.dilation:
-                 kwargs["output_stride"] = kwargs.get("output_stride", 16)
-             backbone = create_model(
-                 config.backbone,
-                 pretrained=config.use_pretrained_backbone,
-                 features_only=True,
-                 out_indices=out_indices,
-                 in_chans=num_channels,
-                 **kwargs,
-             )
-         else:
-             backbone = load_backbone(config)
+         backbone = load_backbone(config)
+         self.intermediate_channel_sizes = backbone.channels

          # replace batch norm by frozen batch norm
          with torch.no_grad():
              replace_batch_norm(backbone)
-         self.model = backbone
-         self.intermediate_channel_sizes = (
-             self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
-         )

-         backbone_model_type = None
-         if config.backbone is not None:
-             backbone_model_type = config.backbone
-         elif config.backbone_config is not None:
-             backbone_model_type = config.backbone_config.model_type
-         else:
-             raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
+         # We used to load with timm library directly instead of the AutoBackbone API
+         # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
+         is_timm_model = False
+         if hasattr(backbone, "_backbone"):
+             backbone = backbone._backbone
+             is_timm_model = True
+         self.model = backbone

+         backbone_model_type = config.backbone_config.model_type
          if "resnet" in backbone_model_type:
              for name, parameter in self.model.named_parameters():
-                 if config.use_timm_backbone:
+                 if is_timm_model:
                      if "layer2" not in name and "layer3" not in name and "layer4" not in name:
                          parameter.requires_grad_(False)
                  else:
                      if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
                          parameter.requires_grad_(False)

-     # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder.forward with Detr->DeformableDetr
      def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
          # send pixel_values through the model to get list of feature maps
-         features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
+         features = self.model(pixel_values)
+         if isinstance(features, dict):
+             features = features.feature_maps

          out = []
          for feature_map in features:
@@ -378,67 +336,58 @@ class DeformableDetrConvEncoder(nn.Module):
          return out


- # Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->DeformableDetr
- class DeformableDetrConvModel(nn.Module):
-     """
-     This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
-     """
-
-     def __init__(self, conv_encoder, position_embedding):
-         super().__init__()
-         self.conv_encoder = conv_encoder
-         self.position_embedding = position_embedding
-
-     def forward(self, pixel_values, pixel_mask):
-         # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
-         out = self.conv_encoder(pixel_values, pixel_mask)
-         pos = []
-         for feature_map, mask in out:
-             # position encoding
-             pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
-
-         return out, pos
-
-
  class DeformableDetrSinePositionEmbedding(nn.Module):
      """
      This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
      need paper, generalized to work on images.
      """

-     def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
+     def __init__(
+         self,
+         num_position_features: int = 64,
+         temperature: int = 10000,
+         normalize: bool = False,
+         scale: float | None = None,
+     ):
          super().__init__()
-         self.embedding_dim = embedding_dim
-         self.temperature = temperature
-         self.normalize = normalize
          if scale is not None and normalize is False:
              raise ValueError("normalize should be True if scale is passed")
-         if scale is None:
-             scale = 2 * math.pi
-         self.scale = scale
+         self.num_position_features = num_position_features
+         self.temperature = temperature
+         self.normalize = normalize
+         self.scale = 2 * math.pi if scale is None else scale

-     def forward(self, pixel_values, pixel_mask):
-         if pixel_mask is None:
-             raise ValueError("No pixel mask provided")
-         y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
-         x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype)
+     @compile_compatible_method_lru_cache(maxsize=1)
+     def forward(
+         self,
+         shape: torch.Size,
+         device: torch.device | str,
+         dtype: torch.dtype,
+         mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         if mask is None:
+             mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+         y_embed = mask.cumsum(1, dtype=dtype)
+         x_embed = mask.cumsum(2, dtype=dtype)
          if self.normalize:
              eps = 1e-6
              y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
              x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

-         dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
-         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
+         dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
+         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)

          pos_x = x_embed[:, :, :, None] / dim_t
          pos_y = y_embed[:, :, :, None] / dim_t
          pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
          pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
          pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+         # expected by the encoder
+         pos = pos.flatten(2).permute(0, 2, 1)
          return pos


- # Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding
  class DeformableDetrLearnedPositionEmbedding(nn.Module):
      """
      This module learns positional embeddings up to a fixed maximum size.
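
Note: the refactored sine embedding above now takes `(shape, device, dtype, mask)` instead of the pixel tensors, so the result can be memoized with `compile_compatible_method_lru_cache`. As a rough sketch of the underlying math only (illustrative; the shipped module derives coordinates from the mask's cumulative sums so padded rows and columns are handled, and optionally normalizes):

    import torch

    def sine_position_embedding(height, width, num_position_features=64, temperature=10000):
        # Encode y/x cell coordinates with interleaved sin/cos at geometrically
        # spaced frequencies, then concatenate the two axes.
        y = torch.arange(1, height + 1, dtype=torch.float32)[:, None, None].expand(height, width, 1)
        x = torch.arange(1, width + 1, dtype=torch.float32)[None, :, None].expand(height, width, 1)
        dim_t = torch.arange(num_position_features, dtype=torch.float32)
        dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_position_features)
        pos_x, pos_y = x / dim_t, y / dim_t
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=3).flatten(2)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=3).flatten(2)
        return torch.cat((pos_y, pos_x), dim=2)  # (height, width, 2 * num_position_features)
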
@@ -449,31 +398,122 @@ class DeformableDetrLearnedPositionEmbedding(nn.Module):
          self.row_embeddings = nn.Embedding(50, embedding_dim)
          self.column_embeddings = nn.Embedding(50, embedding_dim)

-     def forward(self, pixel_values, pixel_mask=None):
-         height, width = pixel_values.shape[-2:]
-         width_values = torch.arange(width, device=pixel_values.device)
-         height_values = torch.arange(height, device=pixel_values.device)
+     @compile_compatible_method_lru_cache(maxsize=1)
+     def forward(
+         self,
+         shape: torch.Size,
+         device: torch.device | str,
+         dtype: torch.dtype,
+         mask: torch.Tensor | None = None,
+     ):
+         height, width = shape[-2:]
+         width_values = torch.arange(width, device=device)
+         height_values = torch.arange(height, device=device)
          x_emb = self.column_embeddings(width_values)
          y_emb = self.row_embeddings(height_values)
          pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
          pos = pos.permute(2, 0, 1)
          pos = pos.unsqueeze(0)
-         pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
+         pos = pos.repeat(shape[0], 1, 1, 1)
+         # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+         # expected by the encoder
+         pos = pos.flatten(2).permute(0, 2, 1)
          return pos


- # Copied from transformers.models.detr.modeling_detr.build_position_encoding with Detr->DeformableDetr
- def build_position_encoding(config):
-     n_steps = config.d_model // 2
-     if config.position_embedding_type == "sine":
-         # TODO find a better way of exposing other arguments
-         position_embedding = DeformableDetrSinePositionEmbedding(n_steps, normalize=True)
-     elif config.position_embedding_type == "learned":
-         position_embedding = DeformableDetrLearnedPositionEmbedding(n_steps)
-     else:
-         raise ValueError(f"Not supported {config.position_embedding_type}")
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor | None,
+     scaling: float | None = None,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     if scaling is None:
+         scaling = query.size(-1) ** -0.5
+
+     # Take the dot product between "query" and "key" to get the raw attention scores.
+     attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+     if attention_mask is not None:
+         attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+         attn_weights = attn_weights + attention_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+     attn_output = torch.matmul(attn_weights, value)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ class DeformableDetrSelfAttention(nn.Module):
+     """
+     Multi-headed self-attention from 'Attention Is All You Need' paper.
+
+     In DEFORMABLE_DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
+     """
+
+     def __init__(
+         self,
+         config: DeformableDetrConfig,
+         hidden_size: int,
+         num_attention_heads: int,
+         dropout: float = 0.0,
+         bias: bool = True,
+     ):
+         super().__init__()
+         self.config = config
+         self.head_dim = hidden_size // num_attention_heads
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = dropout
+         self.is_causal = False
+
+         self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+         self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+         self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+         self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         position_embeddings: torch.Tensor | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Position embeddings are added to both queries and keys (but not values).
+         """
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
+
+         query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+         key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+             self.config._attn_implementation, eager_attention_forward
+         )
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             **kwargs,
+         )

-     return position_embedding
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights


  class DeformableDetrMultiscaleDeformableAttention(nn.Module):
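
Note: with `ALL_ATTENTION_FUNCTIONS.get_interface` in the new `DeformableDetrSelfAttention`, the decoder's self-attention dispatches to the model-wide attention backend instead of a hand-rolled `bmm` path, which is what the `_supports_sdpa`/`_supports_flash_attn` flags further down advertise. A usage sketch (the checkpoint name is shown for illustration; `attn_implementation` is the standard `from_pretrained` switch):

    from transformers import DeformableDetrForObjectDetection

    model = DeformableDetrForObjectDetection.from_pretrained(
        "SenseTime/deformable-detr",
        attn_implementation="sdpa",  # or "eager", "flash_attention_2", ...
    )
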
@@ -513,9 +553,6 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):

          self.disable_custom_kernels = config.disable_custom_kernels

-     def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
-         return tensor if position_embeddings is None else tensor + position_embeddings
-
      def forward(
          self,
          hidden_states: torch.Tensor,
@@ -527,19 +564,19 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
          spatial_shapes=None,
          spatial_shapes_list=None,
          level_start_index=None,
-         output_attentions: bool = False,
-     ):
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
          # add position embeddings to the hidden states before projecting to queries and keys
          if position_embeddings is not None:
-             hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+             hidden_states = hidden_states + position_embeddings

          batch_size, num_queries, _ = hidden_states.shape
          batch_size, sequence_length, _ = encoder_hidden_states.shape
          total_elements = sum(height * width for height, width in spatial_shapes_list)
-         if total_elements != sequence_length:
-             raise ValueError(
-                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
-             )
+         torch_compilable_check(
+             total_elements == sequence_length,
+             "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
+         )

          value = self.value_proj(encoder_hidden_states)
          if attention_mask is not None:
@@ -586,159 +623,48 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
          return output, attention_weights


- class DeformableDetrMultiheadAttention(nn.Module):
-     """
-     Multi-headed attention from 'Attention Is All You Need' paper.
-
-     Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
-     """
-
-     def __init__(
-         self,
-         embed_dim: int,
-         num_heads: int,
-         dropout: float = 0.0,
-         bias: bool = True,
-     ):
+ class DeformableDetrMLP(nn.Module):
+     def __init__(self, config: DeformableDetrConfig, hidden_size: int, intermediate_size: int):
          super().__init__()
-         self.embed_dim = embed_dim
-         self.num_heads = num_heads
-         self.dropout = dropout
-         self.head_dim = embed_dim // num_heads
-         if self.head_dim * num_heads != self.embed_dim:
-             raise ValueError(
-                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                 f" {num_heads})."
-             )
-         self.scaling = self.head_dim**-0.5
-
-         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-
-     def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-         return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-     def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
-         return tensor if position_embeddings is None else tensor + position_embeddings
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: torch.Tensor | None = None,
-         position_embeddings: torch.Tensor | None = None,
-         output_attentions: bool = False,
-     ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
-         """Input shape: Batch x Time x Channel"""
-
-         batch_size, target_len, embed_dim = hidden_states.size()
-         # add position embeddings to the hidden states before projecting to queries and keys
-         if position_embeddings is not None:
-             hidden_states_original = hidden_states
-             hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
-
-         # get queries, keys and values
-         query_states = self.q_proj(hidden_states) * self.scaling
-         key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
-         value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
-
-         proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
-         query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
-         key_states = key_states.view(*proj_shape)
-         value_states = value_states.view(*proj_shape)
-
-         source_len = key_states.size(1)
-
-         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
-         if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
-             raise ValueError(
-                 f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
-                 f" {attn_weights.size()}"
-             )
-
-         # expand attention_mask
-         if attention_mask is not None:
-             # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-             attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-
-         if attention_mask is not None:
-             if attention_mask.size() != (batch_size, 1, target_len, source_len):
-                 raise ValueError(
-                     f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
-                     f" {attention_mask.size()}"
-                 )
-             if attention_mask.dtype == torch.bool:
-                 attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
-                     attention_mask, -torch.inf
-                 )
-             attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
-             attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
-
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-         if output_attentions:
-             # this operation is a bit awkward, but it's required to
-             # make sure that attn_weights keeps its gradient.
-             # In order to do so, attn_weights have to reshaped
-             # twice and have to be reused in the following
-             attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
-             attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
-         else:
-             attn_weights_reshaped = None
-
-         attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
-         attn_output = torch.bmm(attn_probs, value_states)
-
-         if attn_output.size() != (
-             batch_size * self.num_heads,
-             target_len,
-             self.head_dim,
-         ):
-             raise ValueError(
-                 f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
-                 f" {attn_output.size()}"
-             )
-
-         attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
-         attn_output = attn_output.transpose(1, 2)
-         attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
-
-         attn_output = self.out_proj(attn_output)
+         self.fc1 = nn.Linear(hidden_size, intermediate_size)
+         self.fc2 = nn.Linear(intermediate_size, hidden_size)
+         self.activation_fn = ACT2FN[config.activation_function]
+         self.activation_dropout = config.activation_dropout
+         self.dropout = config.dropout

-         return attn_output, attn_weights_reshaped
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+         return hidden_states


  class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
      def __init__(self, config: DeformableDetrConfig):
          super().__init__()
-         self.embed_dim = config.d_model
+         self.hidden_size = config.d_model
          self.self_attn = DeformableDetrMultiscaleDeformableAttention(
              config,
              num_heads=config.encoder_attention_heads,
              n_points=config.encoder_n_points,
          )
-         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
          self.dropout = config.dropout
-         self.activation_fn = ACT2FN[config.activation_function]
-         self.activation_dropout = config.activation_dropout
-         self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
-         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.mlp = DeformableDetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
+         self.final_layer_norm = nn.LayerNorm(self.hidden_size)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor,
-         position_embeddings: torch.Tensor | None = None,
+         spatial_position_embeddings: torch.Tensor | None = None,
          reference_points=None,
          spatial_shapes=None,
          spatial_shapes_list=None,
          level_start_index=None,
-         output_attentions: bool = False,
-     ):
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> torch.Tensor:
          """
          Args:
              hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -753,24 +679,18 @@ class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
                  Spatial shapes of the backbone feature maps.
              level_start_index (`torch.LongTensor`, *optional*):
                  Level start index.
-             output_attentions (`bool`, *optional*):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
          """
          residual = hidden_states
-
-         # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps.
-         hidden_states, attn_weights = self.self_attn(
+         hidden_states, _ = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              encoder_hidden_states=hidden_states,
              encoder_attention_mask=attention_mask,
-             position_embeddings=position_embeddings,
+             position_embeddings=spatial_position_embeddings,
              reference_points=reference_points,
              spatial_shapes=spatial_shapes,
              spatial_shapes_list=spatial_shapes_list,
              level_start_index=level_start_index,
-             output_attentions=output_attentions,
          )

          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -778,12 +698,7 @@ class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
          hidden_states = self.self_attn_layer_norm(hidden_states)

          residual = hidden_states
-         hidden_states = self.activation_fn(self.fc1(hidden_states))
-         hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
-         hidden_states = self.fc2(hidden_states)
-         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
+         hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
          hidden_states = self.final_layer_norm(hidden_states)

@@ -792,54 +707,44 @@ class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
              clamp_value = torch.finfo(hidden_states.dtype).max - 1000
              hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (attn_weights,)
-
-         return outputs
+         return hidden_states


  class DeformableDetrDecoderLayer(GradientCheckpointingLayer):
      def __init__(self, config: DeformableDetrConfig):
          super().__init__()
-         self.embed_dim = config.d_model
+         self.hidden_size = config.d_model

-         # self-attention
-         self.self_attn = DeformableDetrMultiheadAttention(
-             embed_dim=self.embed_dim,
-             num_heads=config.decoder_attention_heads,
+         self.self_attn = DeformableDetrSelfAttention(
+             config=config,
+             hidden_size=self.hidden_size,
+             num_attention_heads=config.decoder_attention_heads,
              dropout=config.attention_dropout,
          )
          self.dropout = config.dropout
-         self.activation_fn = ACT2FN[config.activation_function]
-         self.activation_dropout = config.activation_dropout

-         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-         # cross-attention
+         self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
          self.encoder_attn = DeformableDetrMultiscaleDeformableAttention(
              config,
              num_heads=config.decoder_attention_heads,
              n_points=config.decoder_n_points,
          )
-         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-         # feedforward neural networks
-         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
-         self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
-         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+         self.mlp = DeformableDetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
+         self.final_layer_norm = nn.LayerNorm(self.hidden_size)

      def forward(
          self,
          hidden_states: torch.Tensor,
-         position_embeddings: torch.Tensor | None = None,
+         object_queries_position_embeddings: torch.Tensor | None = None,
          reference_points=None,
          spatial_shapes=None,
          spatial_shapes_list=None,
          level_start_index=None,
          encoder_hidden_states: torch.Tensor | None = None,
          encoder_attention_mask: torch.Tensor | None = None,
-         output_attentions: bool | None = False,
-     ):
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> torch.Tensor:
          """
          Args:
              hidden_states (`torch.FloatTensor`):
@@ -857,60 +762,47 @@ class DeformableDetrDecoderLayer(GradientCheckpointingLayer):
              encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                  values.
-             output_attentions (`bool`, *optional*):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
          """
          residual = hidden_states

          # Self Attention
-         hidden_states, self_attn_weights = self.self_attn(
+         hidden_states, _ = self.self_attn(
              hidden_states=hidden_states,
-             position_embeddings=position_embeddings,
-             output_attentions=output_attentions,
+             position_embeddings=object_queries_position_embeddings,
+             **kwargs,
          )

          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
          hidden_states = residual + hidden_states
          hidden_states = self.self_attn_layer_norm(hidden_states)

-         second_residual = hidden_states
+         residual = hidden_states

          # Cross-Attention
-         cross_attn_weights = None
-         hidden_states, cross_attn_weights = self.encoder_attn(
+         hidden_states, _ = self.encoder_attn(
              hidden_states=hidden_states,
              attention_mask=encoder_attention_mask,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
-             position_embeddings=position_embeddings,
+             position_embeddings=object_queries_position_embeddings,
              reference_points=reference_points,
              spatial_shapes=spatial_shapes,
              spatial_shapes_list=spatial_shapes_list,
              level_start_index=level_start_index,
-             output_attentions=output_attentions,
          )

          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-         hidden_states = second_residual + hidden_states
+         hidden_states = residual + hidden_states

          hidden_states = self.encoder_attn_layer_norm(hidden_states)

          # Fully Connected
          residual = hidden_states
-         hidden_states = self.activation_fn(self.fc1(hidden_states))
-         hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-         hidden_states = self.fc2(hidden_states)
-         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+         hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
          hidden_states = self.final_layer_norm(hidden_states)

-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (self_attn_weights, cross_attn_weights)
-
-         return outputs
+         return hidden_states


  @auto_docstring
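
Note: with the `output_attentions` plumbing gone, both layer classes return a bare tensor, and every sublayer follows the same post-norm residual pattern (sublayer, then dropout, then residual add, then LayerNorm). A condensed, approximate sketch of that pattern (illustrative only; the real sublayers take the extra attention arguments shown above, and the MLP applies its dropout internally):

    def decoder_layer_step(hidden_states, sublayers, norms, dropout):
        # sublayers = (self_attn, cross_attn, mlp); norms = the matching LayerNorms
        for sublayer, norm in zip(sublayers, norms):
            residual = hidden_states
            hidden_states = dropout(sublayer(hidden_states))
            hidden_states = norm(residual + hidden_states)
        return hidden_states
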
@@ -925,6 +817,13 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
          r"DeformableDetrEncoderLayer",
          r"DeformableDetrDecoderLayer",
      ]
+     _supports_sdpa = True
+     _supports_flash_attn = True
+     _supports_attention_backend = True
+     _supports_flex_attn = True
+     _keys_to_ignore_on_load_unexpected = [
+         r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
+     ]

      @torch.no_grad()
      def _init_weights(self, module):
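
Note: the new `_keys_to_ignore_on_load_unexpected` pattern silently drops the `num_batches_tracked` buffers that plain `BatchNorm2d` checkpoints carry but the frozen replacement does not. A quick check of what the regex matches (the key string here is invented for illustration):

    import re

    pattern = r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
    key = "detr.model.backbone.model.layer2.0.downsample.1.num_batches_tracked"
    assert re.search(pattern, key) is not None
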
@@ -982,9 +881,13 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
      config: DeformableDetrConfig
      """

+     _can_record_outputs = {
+         "hidden_states": DeformableDetrEncoderLayer,
+         "attentions": OutputRecorder(DeformableDetrMultiscaleDeformableAttention, layer_name="self_attn", index=1),
+     }
+
      def __init__(self, config: DeformableDetrConfig):
          super().__init__(config)
-         self.gradient_checkpointing = False

          self.dropout = config.dropout
          self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
@@ -992,51 +895,18 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
          # Initialize weights and apply final processing
          self.post_init()

-     @staticmethod
-     def get_reference_points(spatial_shapes, valid_ratios, device):
-         """
-         Get reference points for each feature map. Used in decoder.
-
-         Args:
-             spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
-                 Spatial shapes of each feature map.
-             valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
-                 Valid ratios of each feature map.
-             device (`torch.device`):
-                 Device on which to create the tensors.
-         Returns:
-             `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
-         """
-         reference_points_list = []
-         for level, (height, width) in enumerate(spatial_shapes):
-             ref_y, ref_x = meshgrid(
-                 torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
-                 torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
-                 indexing="ij",
-             )
-             # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
-             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
-             ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
-             ref = torch.stack((ref_x, ref_y), -1)
-             reference_points_list.append(ref)
-         reference_points = torch.cat(reference_points_list, 1)
-         reference_points = reference_points[:, :, None] * valid_ratios[:, None]
-         return reference_points
-
+     @check_model_inputs()
      def forward(
          self,
          inputs_embeds=None,
          attention_mask=None,
-         position_embeddings=None,
+         spatial_position_embeddings=None,
          spatial_shapes=None,
          spatial_shapes_list=None,
          level_start_index=None,
          valid_ratios=None,
-         output_attentions=None,
-         output_hidden_states=None,
-         return_dict=None,
-         **kwargs,
-     ):
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> BaseModelOutput:
          r"""
          Args:
              inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -1046,66 +916,72 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
1046
916
  - 1 for pixel features that are real (i.e. **not masked**),
1047
917
  - 0 for pixel features that are padding (i.e. **masked**).
1048
918
  [What are attention masks?](../glossary#attention-mask)
1049
- position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1050
- Position embeddings that are added to the queries and keys in each self-attention layer.
919
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
920
+ Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
1051
921
  spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
1052
922
  Spatial shapes of each feature map.
1053
923
  level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
1054
924
  Starting index of each feature map.
1055
925
  valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
1056
926
  Ratio of valid area in each feature level.
1057
- output_attentions (`bool`, *optional*):
1058
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1059
- returned tensors for more detail.
1060
- output_hidden_states (`bool`, *optional*):
1061
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1062
- for more detail.
1063
- return_dict (`bool`, *optional*):
1064
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1065
927
  """
1066
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1067
- output_hidden_states = (
1068
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1069
- )
1070
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1071
-
1072
928
  hidden_states = inputs_embeds
1073
929
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1074
930
 
1075
931
  spatial_shapes_tuple = tuple(spatial_shapes_list)
1076
932
  reference_points = self.get_reference_points(spatial_shapes_tuple, valid_ratios, device=inputs_embeds.device)
1077
933
 
1078
- encoder_states = () if output_hidden_states else None
1079
- all_attentions = () if output_attentions else None
1080
- for i, encoder_layer in enumerate(self.layers):
1081
- if output_hidden_states:
1082
- encoder_states = encoder_states + (hidden_states,)
1083
- layer_outputs = encoder_layer(
934
+ for encoder_layer in self.layers:
935
+ hidden_states = encoder_layer(
1084
936
  hidden_states,
1085
937
  attention_mask,
1086
- position_embeddings=position_embeddings,
938
+ spatial_position_embeddings=spatial_position_embeddings,
1087
939
  reference_points=reference_points,
1088
940
  spatial_shapes=spatial_shapes,
1089
941
  spatial_shapes_list=spatial_shapes_list,
1090
942
  level_start_index=level_start_index,
1091
- output_attentions=output_attentions,
943
+ **kwargs,
1092
944
  )
1093
945
 
1094
- hidden_states = layer_outputs[0]
946
+ return BaseModelOutput(last_hidden_state=hidden_states)
1095
947
 
1096
- if output_attentions:
1097
- all_attentions = all_attentions + (layer_outputs[1],)
948
+ @staticmethod
949
+ def get_reference_points(spatial_shapes_list, valid_ratios, device):
950
+ """
951
+ Get reference points for each feature map. Used in decoder.
952
+
953
+ Args:
954
+ spatial_shapes_list (`list[tuple[int, int]]`):
955
+ Spatial shapes of each feature map.
956
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
957
+ Valid ratios of each feature map.
958
+ device (`torch.device`):
959
+ Device on which to create the tensors.
960
+ Returns:
961
+ `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
962
+ """
963
+ reference_points_list = []
964
+ for level, (height, width) in enumerate(spatial_shapes_list):
965
+ ref_y, ref_x = meshgrid(
966
+ torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
967
+ torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
968
+ indexing="ij",
969
+ )
970
+ # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
971
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
972
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
973
+ ref = torch.stack((ref_x, ref_y), -1)
974
+ reference_points_list.append(ref)
975
+ reference_points = torch.cat(reference_points_list, 1)
976
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
977
+ return reference_points
1098
978
 
1099
- if output_hidden_states:
1100
- encoder_states = encoder_states + (hidden_states,)
1101
979
 
1102
- if not return_dict:
1103
- return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
1104
- return BaseModelOutput(
1105
- last_hidden_state=hidden_states,
1106
- hidden_states=encoder_states,
1107
- attentions=all_attentions,
1108
- )
980
+ def inverse_sigmoid(x, eps=1e-5):
981
+ x = x.clamp(min=0, max=1)
982
+ x1 = x.clamp(min=eps)
983
+ x2 = (1 - x).clamp(min=eps)
984
+ return torch.log(x1 / x2)
1109
985
 
1110
986
 
1111
987
  class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
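The new module-level `inverse_sigmoid` is the logit function with clamping for numerical safety; the decoder uses it for iterative box refinement, predicting a delta in logit space and squashing back through the sigmoid. A quick illustration of both properties (`ref` and `delta` are made-up stand-ins):

```python
import torch


def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


ref = torch.tensor([0.25, 0.5, 0.9])
# Round-trips through sigmoid up to the clamping epsilon:
assert torch.allclose(ref, inverse_sigmoid(ref).sigmoid(), atol=1e-4)

# Refinement pattern used by the decoder: add a predicted delta in logit
# space, then map back to [0, 1] so the reference stays a valid coordinate.
delta = torch.tensor([0.1, -0.2, 0.05])  # stand-in for a bbox_embed output
new_ref = (delta + inverse_sigmoid(ref)).sigmoid()
```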
@@ -1123,12 +999,19 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
     config: DeformableDetrConfig
     """
 
+    _can_record_outputs = {
+        "hidden_states": DeformableDetrDecoderLayer,
+        "attentions": OutputRecorder(DeformableDetrSelfAttention, layer_name="self_attn", index=1),
+        "cross_attentions": OutputRecorder(
+            DeformableDetrMultiscaleDeformableAttention, layer_name="encoder_attn", index=1
+        ),
+    }
+
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)
 
         self.dropout = config.dropout
         self.layers = nn.ModuleList([DeformableDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.gradient_checkpointing = False
 
         # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
         self.bbox_embed = None
@@ -1137,21 +1020,19 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
-        position_embeddings=None,
+        object_queries_position_embeddings=None,
         reference_points=None,
         spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ):
         r"""
         Args:
@@ -1165,8 +1046,8 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                 in `[0, 1]`:
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).
-            position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
-                Position embeddings that are added to the queries and keys in each self-attention layer.
+            object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+                Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.
             reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
                 Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
             spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
@@ -1176,28 +1057,11 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
             valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
                 Ratio of valid area in each feature level.
 
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
 
         # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
         intermediate = ()
         intermediate_reference_points = ()
 
@@ -1212,23 +1076,18 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
             else:
                 raise ValueError("Reference points' last dimension must be of size 2")
 
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
-                position_embeddings,
+                object_queries_position_embeddings,
                 reference_points_input,
                 spatial_shapes,
                 spatial_shapes_list,
                 level_start_index,
                 encoder_hidden_states,  # as a positional argument for gradient checkpointing
                 encoder_attention_mask,
-                output_attentions,
+                **kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
             # hack implementation for iterative bounding box refinement
             if self.bbox_embed is not None:
                 tmp = self.bbox_embed[idx](hidden_states)
@@ -1249,40 +1108,14 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
             intermediate += (hidden_states,)
             intermediate_reference_points += (reference_points,)
 
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
         # Keep batch_size as first dimension
         intermediate = torch.stack(intermediate, dim=1)
         intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
 
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [
-                    hidden_states,
-                    intermediate,
-                    intermediate_reference_points,
-                    all_hidden_states,
-                    all_self_attns,
-                    all_cross_attentions,
-                ]
-                if v is not None
-            )
         return DeformableDetrDecoderOutput(
             last_hidden_state=hidden_states,
             intermediate_hidden_states=intermediate,
             intermediate_reference_points=intermediate_reference_points,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
         )
 
 
@@ -1296,17 +1129,23 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)
 
-        # Create backbone + positional encoding
-        backbone = DeformableDetrConvEncoder(config)
-        position_embeddings = build_position_encoding(config)
-        self.backbone = DeformableDetrConvModel(backbone, position_embeddings)
+        # Create backbone
+        self.backbone = DeformableDetrConvEncoder(config)
+
+        # Create positional encoding
+        if config.position_embedding_type == "sine":
+            self.position_embedding = DeformableDetrSinePositionEmbedding(config.d_model // 2, normalize=True)
+        elif config.position_embedding_type == "learned":
+            self.position_embedding = DeformableDetrLearnedPositionEmbedding(config.d_model // 2)
+        else:
+            raise ValueError(f"Not supported {config.position_embedding_type}")
 
         # Create input projection layers
         if config.num_feature_levels > 1:
-            num_backbone_outs = len(backbone.intermediate_channel_sizes)
+            num_backbone_outs = len(self.backbone.intermediate_channel_sizes)
             input_proj_list = []
             for _ in range(num_backbone_outs):
-                in_channels = backbone.intermediate_channel_sizes[_]
+                in_channels = self.backbone.intermediate_channel_sizes[_]
                 input_proj_list.append(
                     nn.Sequential(
                         nn.Conv2d(in_channels, config.d_model, kernel_size=1),
@@ -1333,7 +1172,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
                 [
                     nn.Sequential(
                         nn.Conv2d(
-                            backbone.intermediate_channel_sizes[-1],
+                            self.backbone.intermediate_channel_sizes[-1],
                             config.d_model,
                             kernel_size=1,
                         ),
@@ -1361,11 +1200,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         self.post_init()
 
     def freeze_backbone(self):
-        for name, param in self.backbone.conv_encoder.model.named_parameters():
+        for name, param in self.backbone.model.named_parameters():
             param.requires_grad_(False)
 
     def unfreeze_backbone(self):
-        for name, param in self.backbone.conv_encoder.model.named_parameters():
+        for name, param in self.backbone.model.named_parameters():
             param.requires_grad_(True)
 
     def get_valid_ratio(self, mask, dtype=torch.float32):
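With the `DeformableDetrConvModel` wrapper gone, the backbone's parameters hang directly off `self.backbone.model`, which is what the freeze helpers now walk. A quick sanity check of the helpers (sketch; assumes a constructed `model`):

```python
# Freeze the convolutional backbone while keeping the transformer trainable.
model.freeze_backbone()
assert not any(p.requires_grad for p in model.backbone.model.parameters())

# Re-enable backbone fine-tuning.
model.unfreeze_backbone()
assert all(p.requires_grad for p in model.backbone.model.parameters())
```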
@@ -1386,15 +1225,18 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         temperature = 10000
         scale = 2 * math.pi
 
-        dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
+        # Compute position embeddings in float32 to avoid overflow with large temperature values in fp16
+        proposals_dtype = proposals.dtype
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
         dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
         # batch_size, num_queries, 4
-        proposals = proposals.sigmoid() * scale
+        proposals = proposals.sigmoid().to(torch.float32) * scale
         # batch_size, num_queries, 4, 128
         pos = proposals[:, :, :, None] / dim_t
         # batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512
         pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
-        return pos
+        # Convert back to target dtype after all computations are done
+        return pos.to(proposals_dtype)
 
     def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
         """Generate the encoder output proposals from encoded enc_output.
@@ -1458,6 +1300,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         return object_query, output_proposals
 
     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
@@ -1466,10 +1309,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         encoder_outputs: torch.FloatTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
         decoder_inputs_embeds: torch.FloatTensor | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | DeformableDetrModelOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1502,12 +1342,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         >>> list(last_hidden_states.shape)
         [1, 300, 256]
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         batch_size, num_channels, height, width = pixel_values.shape
         device = pixel_values.device
 
@@ -1517,16 +1351,22 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         # Extract multi-scale feature maps of same resolution `config.d_model` (cf Figure 4 in paper)
         # First, send pixel_values + pixel_mask through the backbone to obtain the features
         # which is a list of tuples
-        features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)
+        features = self.backbone(pixel_values, pixel_mask)
 
         # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
         sources = []
         masks = []
+        position_embeddings_list = []
         for level, (source, mask) in enumerate(features):
             sources.append(self.input_proj[level](source))
             masks.append(mask)
             if mask is None:
                 raise ValueError("No attention mask was provided")
+            # Generate position embeddings for this feature level
+            pos = self.position_embedding(shape=source.shape, device=device, dtype=pixel_values.dtype, mask=mask).to(
+                source.dtype
+            )
+            position_embeddings_list.append(pos)
 
         # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
         if self.config.num_feature_levels > len(sources):
@@ -1539,7 +1379,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
                 mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to(
                     torch.bool
                 )[0]
-                pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
+                pos_l = self.position_embedding(
+                    shape=source.shape, device=device, dtype=pixel_values.dtype, mask=mask
+                ).to(source.dtype)
                 sources.append(source)
                 masks.append(mask)
                 position_embeddings_list.append(pos_l)
@@ -1560,7 +1402,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             spatial_shapes_list.append(spatial_shape)
             source = source.flatten(2).transpose(1, 2)
             mask = mask.flatten(1)
-            pos_embed = pos_embed.flatten(2).transpose(1, 2)
             lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
             lvl_pos_embed_flatten.append(lvl_pos_embed)
             source_flatten.append(source)
@@ -1578,21 +1419,12 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             encoder_outputs = self.encoder(
                 inputs_embeds=source_flatten,
                 attention_mask=mask_flatten,
-                position_embeddings=lvl_pos_embed_flatten,
+                spatial_position_embeddings=lvl_pos_embed_flatten,
                 spatial_shapes=spatial_shapes,
                 spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 valid_ratios=valid_ratios,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                **kwargs,
             )
 
         # Fifth, prepare decoder inputs
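The `output_attentions`, `output_hidden_states`, and `return_dict` arguments now travel through `**kwargs` as `TransformersKwargs` and are resolved once by the `@check_model_inputs()` / `@can_return_tuple` decorators rather than at every call site. A hedged usage sketch (model and inputs assumed to be already constructed):

```python
# Dict output (default): the recorder hooks fill the attention fields.
outputs = model(pixel_values, pixel_mask=pixel_mask, output_attentions=True, output_hidden_states=True)
print(outputs.encoder_attentions[0].shape)  # recorded in the encoder's deformable self-attention
print(outputs.decoder_attentions[0].shape)  # decoder self-attention over the object queries
print(outputs.cross_attentions[0].shape)    # deformable cross-attention into the image features

# @can_return_tuple keeps the legacy tuple interface available:
outputs_tuple = model(pixel_values, pixel_mask=pixel_mask, return_dict=False)
```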
@@ -1635,7 +1467,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
 
         decoder_outputs = self.decoder(
             inputs_embeds=target,
-            position_embeddings=query_embed,
+            object_queries_position_embeddings=query_embed,
             encoder_hidden_states=encoder_outputs[0],
             encoder_attention_mask=mask_flatten,
             reference_points=reference_points,
@@ -1643,17 +1475,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             valid_ratios=valid_ratios,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
 
-        if not return_dict:
-            enc_outputs = tuple(value for value in [enc_outputs_class, enc_outputs_coord_logits] if value is not None)
-            tuple_outputs = (init_reference_points,) + decoder_outputs + encoder_outputs + enc_outputs
-
-            return tuple_outputs
-
         return DeformableDetrModelOutput(
             init_reference_points=init_reference_points,
             last_hidden_state=decoder_outputs.last_hidden_state,
@@ -1670,14 +1494,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         )
 
 
-# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead
 class DeformableDetrMLPPredictionHead(nn.Module):
     """
     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
     height and width of a bounding box w.r.t. an image.
 
-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-
     """
 
     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
@@ -1726,15 +1547,18 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
                 for _ in range(num_pred)
             ]
         )
+        # Convert to instance attribute before modifying
+        self._tied_weights_keys = self._tied_weights_keys.copy()
         if config.with_box_refine:
             self.model.decoder.bbox_embed = self.bbox_embed
-            self._tied_weights_keys["model.decoder.bbox_embed"] = "bbox_embed"
+            self._tied_weights_keys["bbox_embed"] = "model.decoder.bbox_embed"
         if config.two_stage:
             self.model.decoder.class_embed = self.class_embed
-            self._tied_weights_keys["model.decoder.class_embed"] = "class_embed"
         self.post_init()
+            self._tied_weights_keys["class_embed"] = "model.decoder.class_embed"
+        self.post_init()
 
     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
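Two things change here: the mapping direction of `_tied_weights_keys` is flipped to the `{tied_parameter: source_parameter}` convention, and the dict is copied onto the instance before mutation so the class-level default shared by all instances stays untouched. A small, generic illustration of why the copy matters (sketch, not the transformers implementation):

```python
class Base:
    shared = {}  # class-level dict, shared by every instance


class Safe(Base):
    def __init__(self):
        # Instance-level copy: later mutations stay local to this object.
        self.shared = self.shared.copy()
        self.shared["bbox_embed"] = "model.decoder.bbox_embed"


a, b = Safe(), Safe()
a.shared["extra"] = "x"
assert "extra" not in b.shared  # without .copy(), this entry would leak across instances
```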
@@ -1744,10 +1568,7 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         inputs_embeds: torch.FloatTensor | None = None,
         decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | DeformableDetrObjectDetectionOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1795,8 +1616,6 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
         Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
         ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # First, send images through the DETR base model to obtain encoder + decoder outputs
         outputs = self.model(
             pixel_values,
@@ -1805,14 +1624,12 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
 
-        hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2]
-        init_reference = outputs.init_reference_points if return_dict else outputs[0]
-        inter_references = outputs.intermediate_reference_points if return_dict else outputs[3]
+        hidden_states = outputs.intermediate_hidden_states
+        init_reference = outputs.init_reference_points
+        inter_references = outputs.intermediate_reference_points
 
         # class logits + predicted bounding boxes
         outputs_classes = []
@@ -1853,16 +1670,8 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
                 outputs_class,
                 outputs_coord,
             )
-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes) + auxiliary_outputs + outputs
-            else:
-                output = (logits, pred_boxes) + outputs
-            tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output
-
-            return tuple_outputs
 
-        dict_outputs = DeformableDetrObjectDetectionOutput(
+        return DeformableDetrObjectDetectionOutput(
             loss=loss,
             loss_dict=loss_dict,
             logits=logits,
@@ -1882,11 +1691,5 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
             enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
         )
 
-        return dict_outputs
-
 
-__all__ = [
-    "DeformableDetrForObjectDetection",
-    "DeformableDetrModel",
-    "DeformableDetrPreTrainedModel",
-]
+__all__ = ["DeformableDetrForObjectDetection", "DeformableDetrModel", "DeformableDetrPreTrainedModel"]