transformers 5.0.0rc3-py3-none-any.whl → 5.1.0-py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. The information it contains is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
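For context, a listing like the one below can be reproduced locally. The following is a minimal sketch, assuming the transformers wheels for both versions are still downloadable from PyPI and that a naive per-file unified diff is an acceptable approximation; the directory names (`old`, `new`) and the output format are illustrative, and the line counts it prints may differ from the registry's own diff algorithm.

```python
# Sketch: download both published wheels, unpack them, and count added/removed
# lines per .py file, roughly mirroring the "+N -M" columns in the listing below.
import difflib
import subprocess
import sys
import zipfile
from pathlib import Path

OLD, NEW = "5.0.0rc3", "5.1.0"  # versions taken from the header above


def fetch_wheel(version: str, dest: Path) -> Path:
    """Download the transformers wheel for the given version into dest."""
    dest.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        [sys.executable, "-m", "pip", "download", f"transformers=={version}",
         "--no-deps", "--only-binary", ":all:", "-d", str(dest)],
        check=True,
    )
    return next(dest.glob("transformers-*.whl"))


def read_sources(wheel: Path) -> dict[str, list[str]]:
    """Map each .py path inside the wheel to its lines of text."""
    sources: dict[str, list[str]] = {}
    with zipfile.ZipFile(wheel) as zf:
        for name in zf.namelist():
            if name.endswith(".py"):
                sources[name] = zf.read(name).decode("utf-8", "replace").splitlines()
    return sources


old = read_sources(fetch_wheel(OLD, Path("old")))
new = read_sources(fetch_wheel(NEW, Path("new")))

for path in sorted(old.keys() | new.keys()):
    diff = difflib.unified_diff(old.get(path, []), new.get(path, []), lineterm="")
    added = removed = 0
    for line in diff:
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    if added or removed:
        print(f"{path} +{added} -{removed}")
```

Non-Python files inside the wheels (metadata, JSON assets, and so on) are skipped in this sketch, which is one reason its totals will not exactly match the 1021 changed files reported below.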
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -14,32 +14,35 @@
 """PyTorch DETR model."""
 
 import math
+from collections.abc import Callable
 from dataclasses import dataclass
 
 import torch
-from torch import Tensor, nn
+import torch.nn as nn
 
 from ... import initialization as init
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...backbone_utils import load_backbone
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithCrossAttentions,
+    Seq2SeqModelOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import (
     ModelOutput,
+    TransformersKwargs,
     auto_docstring,
-    is_timm_available,
     logging,
-    requires_backends,
 )
-from ...utils.backbone_utils import load_backbone
+from ...utils.generic import can_return_tuple, check_model_inputs
 from .configuration_detr import DetrConfig
 
 
-if is_timm_available():
-    from timm import create_model
-
-
 logger = logging.get_logger(__name__)
 
 
@@ -178,8 +181,6 @@ class DetrSegmentationOutput(ModelOutput):
     encoder_attentions: tuple[torch.FloatTensor] | None = None
 
 
-# BELOW: utilities copied from
-# https://github.com/facebookresearch/detr/blob/master/backbone.py
 class DetrFrozenBatchNorm2d(nn.Module):
     """
     BatchNorm2d where the batch statistics and the affine parameters are fixed.
@@ -256,47 +257,25 @@ class DetrConvEncoder(nn.Module):
 
         self.config = config
 
-        # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
-        if config.use_timm_backbone:
-            # We default to values which were previously hard-coded. This enables configurability from the config
-            # using backbone arguments, while keeping the default behavior the same.
-            requires_backends(self, ["timm"])
-            kwargs = getattr(config, "backbone_kwargs", {})
-            kwargs = {} if kwargs is None else kwargs.copy()
-            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
-            num_channels = kwargs.pop("in_chans", config.num_channels)
-            if config.dilation:
-                kwargs["output_stride"] = kwargs.get("output_stride", 16)
-            backbone = create_model(
-                config.backbone,
-                pretrained=config.use_pretrained_backbone,
-                features_only=True,
-                out_indices=out_indices,
-                in_chans=num_channels,
-                **kwargs,
-            )
-        else:
-            backbone = load_backbone(config)
+        backbone = load_backbone(config)
+        self.intermediate_channel_sizes = backbone.channels
 
         # replace batch norm by frozen batch norm
         with torch.no_grad():
             replace_batch_norm(backbone)
-        self.model = backbone
-        self.intermediate_channel_sizes = (
-            self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
-        )
 
-        backbone_model_type = None
-        if config.backbone is not None:
-            backbone_model_type = config.backbone
-        elif config.backbone_config is not None:
-            backbone_model_type = config.backbone_config.model_type
-        else:
-            raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
+        # We used to load with timm library directly instead of the AutoBackbone API
+        # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
+        is_timm_model = False
+        if hasattr(backbone, "_backbone"):
+            backbone = backbone._backbone
+            is_timm_model = True
+        self.model = backbone
 
+        backbone_model_type = config.backbone_config.model_type
         if "resnet" in backbone_model_type:
             for name, parameter in self.model.named_parameters():
-                if config.use_timm_backbone:
+                if is_timm_model:
                     if "layer2" not in name and "layer3" not in name and "layer4" not in name:
                         parameter.requires_grad_(False)
                 else:
@@ -305,7 +284,9 @@ class DetrConvEncoder(nn.Module):
 
     def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
         # send pixel_values through the model to get list of feature maps
-        features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
+        features = self.model(pixel_values)
+        if isinstance(features, dict):
+            features = features.feature_maps
 
         out = []
         for feature_map in features:
@@ -315,61 +296,55 @@ class DetrConvEncoder(nn.Module):
         return out
 
 
-class DetrConvModel(nn.Module):
-    """
-    This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
-    """
-
-    def __init__(self, conv_encoder, position_embedding):
-        super().__init__()
-        self.conv_encoder = conv_encoder
-        self.position_embedding = position_embedding
-
-    def forward(self, pixel_values, pixel_mask):
-        # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
-        out = self.conv_encoder(pixel_values, pixel_mask)
-        pos = []
-        for feature_map, mask in out:
-            # position encoding
-            pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
-
-        return out, pos
-
-
 class DetrSinePositionEmbedding(nn.Module):
     """
     This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
     need paper, generalized to work on images.
     """
 
-    def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
+    def __init__(
+        self,
+        num_position_features: int = 64,
+        temperature: int = 10000,
+        normalize: bool = False,
+        scale: float | None = None,
+    ):
         super().__init__()
-        self.embedding_dim = embedding_dim
-        self.temperature = temperature
-        self.normalize = normalize
         if scale is not None and normalize is False:
             raise ValueError("normalize should be True if scale is passed")
-        if scale is None:
-            scale = 2 * math.pi
-        self.scale = scale
+        self.num_position_features = num_position_features
+        self.temperature = temperature
+        self.normalize = normalize
+        self.scale = 2 * math.pi if scale is None else scale
 
-    def forward(self, pixel_values, pixel_mask):
-        if pixel_mask is None:
-            raise ValueError("No pixel mask provided")
-        y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
-        x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if mask is None:
+            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+        y_embed = mask.cumsum(1, dtype=dtype)
+        x_embed = mask.cumsum(2, dtype=dtype)
         if self.normalize:
-            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
-            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
 
-        dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
-        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
+        dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
+        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)
 
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos
 
 
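The refactored `DetrSinePositionEmbedding.forward` above no longer consumes the pixel tensor itself: it takes the feature-map shape, device, and dtype (plus an optional padding mask) so the result can be cached by `compile_compatible_method_lru_cache`, and it returns embeddings already flattened to `(batch_size, sequence_length, hidden_size)`. A minimal sketch of the new calling convention, assuming the import path introduced in this diff and an arbitrary 25x34 backbone feature map:

import torch
from transformers.models.detr.modeling_detr import DetrSinePositionEmbedding

# 128 sine/cosine features per axis -> 256-dim embeddings, matching DETR's usual d_model (assumed setup)
position_embedding = DetrSinePositionEmbedding(num_position_features=128, normalize=True)

feature_map = torch.randn(2, 256, 25, 34)             # (batch, channels, height, width) from the backbone
pixel_mask = torch.ones(2, 25, 34, dtype=torch.bool)  # 1 = real pixel, 0 = padding
pos = position_embedding(feature_map.shape, feature_map.device, feature_map.dtype, mask=pixel_mask)
# pos has shape (2, 25 * 34, 256): already flattened and permuted, ready to be added to queries and keys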
@@ -383,207 +358,260 @@ class DetrLearnedPositionEmbedding(nn.Module):
         self.row_embeddings = nn.Embedding(50, embedding_dim)
         self.column_embeddings = nn.Embedding(50, embedding_dim)
 
-    def forward(self, pixel_values, pixel_mask=None):
-        height, width = pixel_values.shape[-2:]
-        width_values = torch.arange(width, device=pixel_values.device)
-        height_values = torch.arange(height, device=pixel_values.device)
+    @compile_compatible_method_lru_cache(maxsize=1)
+    def forward(
+        self,
+        shape: torch.Size,
+        device: torch.device | str,
+        dtype: torch.dtype,
+        mask: torch.Tensor | None = None,
+    ):
+        height, width = shape[-2:]
+        width_values = torch.arange(width, device=device)
+        height_values = torch.arange(height, device=device)
         x_emb = self.column_embeddings(width_values)
         y_emb = self.row_embeddings(height_values)
         pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
         pos = pos.permute(2, 0, 1)
         pos = pos.unsqueeze(0)
-        pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
+        pos = pos.repeat(shape[0], 1, 1, 1)
+        # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+        # expected by the encoder
+        pos = pos.flatten(2).permute(0, 2, 1)
         return pos
 
 
-def build_position_encoding(config):
-    n_steps = config.d_model // 2
-    if config.position_embedding_type == "sine":
-        # TODO find a better way of exposing other arguments
-        position_embedding = DetrSinePositionEmbedding(n_steps, normalize=True)
-    elif config.position_embedding_type == "learned":
-        position_embedding = DetrLearnedPositionEmbedding(n_steps)
-    else:
-        raise ValueError(f"Not supported {config.position_embedding_type}")
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    scaling: float | None = None,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
 
-    return position_embedding
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
 
+    return attn_output, attn_weights
 
-class DetrAttention(nn.Module):
+
+class DetrSelfAttention(nn.Module):
     """
-    Multi-headed attention from 'Attention Is All You Need' paper.
+    Multi-headed self-attention from 'Attention Is All You Need' paper.
 
-    Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
+    In DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
     """
 
     def __init__(
         self,
-        embed_dim: int,
-        num_heads: int,
+        config: DetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
         dropout: float = 0.0,
         bias: bool = True,
     ):
         super().__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        if self.head_dim * num_heads != self.embed_dim:
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {num_heads})."
-            )
+        self.config = config
+        self.head_dim = hidden_size // num_attention_heads
         self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False
 
-        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def with_pos_embed(self, tensor: torch.Tensor, object_queries: Tensor | None):
-        return tensor if object_queries is None else tensor + object_queries
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-        object_queries: torch.Tensor | None = None,
-        key_value_states: torch.Tensor | None = None,
-        spatial_position_embeddings: torch.Tensor | None = None,
-        output_attentions: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
-        """Input shape: Batch x Time x Channel"""
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-        batch_size, target_len, embed_dim = hidden_states.size()
-
-        # add position embeddings to the hidden states before projecting to queries and keys
-        if object_queries is not None:
-            hidden_states_original = hidden_states
-            hidden_states = self.with_pos_embed(hidden_states, object_queries)
-
-        # add key-value position embeddings to the key value states
-        if spatial_position_embeddings is not None:
-            key_value_states_original = key_value_states
-            key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
-
-        # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
-            value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
-            value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
+        position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Position embeddings are added to both queries and keys (but not values).
+        """
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
 
-        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
 
-        source_len = key_states.size(1)
+        query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
 
-        if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
-            raise ValueError(
-                f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
-                f" {attn_weights.size()}"
-            )
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
 
-        if attention_mask is not None:
-            if attention_mask.size() != (batch_size, 1, target_len, source_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
-                    f" {attention_mask.size()}"
-                )
-            if attention_mask.dtype == torch.bool:
-                attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
-                    attention_mask, -torch.inf
-                )
-            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
-            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
-        else:
-            attn_weights_reshaped = None
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
 
-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
-        attn_output = torch.bmm(attn_probs, value_states)
+class DetrCrossAttention(nn.Module):
+    """
+    Multi-headed cross-attention from 'Attention Is All You Need' paper.
 
-        if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
+    In DETR, queries get their own position embeddings, while keys get encoder position embeddings.
+    Values don't get any position embeddings.
+    """
+
+    def __init__(
+        self,
+        config: DetrConfig,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.config = config
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+        self.is_causal = False
+
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        position_embeddings: torch.Tensor | None = None,
+        encoder_position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Position embeddings logic:
+        - Queries get position_embeddings
+        - Keys get encoder_position_embeddings
+        - Values don't get any position embeddings
+        """
+        query_input_shape = hidden_states.shape[:-1]
+        query_hidden_shape = (*query_input_shape, -1, self.head_dim)
 
-        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+        kv_input_shape = key_value_states.shape[:-1]
+        kv_hidden_shape = (*kv_input_shape, -1, self.head_dim)
 
-        attn_output = self.out_proj(attn_output)
+        query_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
+        key_input = (
+            key_value_states + encoder_position_embeddings
+            if encoder_position_embeddings is not None
+            else key_value_states
+        )
+
+        query_states = self.q_proj(query_input).view(query_hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(key_input).view(kv_hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(key_value_states).view(kv_hidden_shape).transpose(1, 2)
+
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class DetrMLP(nn.Module):
+    def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int):
+        super().__init__()
+        self.fc1 = nn.Linear(hidden_size, intermediate_size)
+        self.fc2 = nn.Linear(intermediate_size, hidden_size)
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.dropout = config.dropout
 
-        return attn_output, attn_weights_reshaped
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        return hidden_states
 
 
-class DetrEncoderLayer(nn.Module):
+class DetrEncoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DetrConfig):
         super().__init__()
-        self.embed_dim = config.d_model
-        self.self_attn = DetrAttention(
-            embed_dim=self.embed_dim,
-            num_heads=config.encoder_attention_heads,
+        self.hidden_size = config.d_model
+        self.self_attn = DetrSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
         )
-        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
         self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.activation_function]
-        self.activation_dropout = config.activation_dropout
-        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
-        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.mlp = DetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-        object_queries: torch.Tensor | None = None,
-        output_attentions: bool = False,
-    ):
+        spatial_position_embeddings: torch.Tensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-            object_queries (`torch.FloatTensor`, *optional*):
-                Object queries (also called content embeddings), to be added to the hidden states.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
+            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+                Spatial position embeddings (2D positional encodings of image locations), to be added to both
+                the queries and keys in self-attention (but not to values).
         """
         residual = hidden_states
-        hidden_states, attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            object_queries=object_queries,
-            output_attentions=output_attentions,
+            position_embeddings=spatial_position_embeddings,
+            **kwargs,
         )
 
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
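For orientation on the block above: the hand-rolled `DetrAttention` is split into `DetrSelfAttention` and `DetrCrossAttention`, which add position embeddings to queries and keys only (never to values) and dispatch to whichever backend `config._attn_implementation` selects (eager, SDPA, flash attention, ...). A minimal self-attention sketch under the eager fallback, assuming a stock `DetrConfig` (d_model 256, 8 encoder heads) and the module path introduced in this diff:

import torch
from transformers import DetrConfig
from transformers.models.detr.modeling_detr import DetrSelfAttention

config = DetrConfig()  # assumed defaults: d_model=256, encoder_attention_heads=8, eager attention
attn = DetrSelfAttention(
    config=config,
    hidden_size=config.d_model,
    num_attention_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
)

hidden_states = torch.randn(1, 850, config.d_model)  # flattened feature-map tokens
spatial_pos = torch.randn(1, 850, config.d_model)    # stand-in for the 2D sine embeddings
attn_output, attn_weights = attn(hidden_states, position_embeddings=spatial_pos)
# spatial_pos was added to the query/key projections only; the value projection saw the raw hidden_states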
@@ -591,12 +619,7 @@ class DetrEncoderLayer(nn.Module):
         hidden_states = self.self_attn_layer_norm(hidden_states)
 
         residual = hidden_states
-        hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
 
@@ -605,78 +628,69 @@ class DetrEncoderLayer(nn.Module):
                 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
 
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
+        return hidden_states
 
 
 class DetrDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: DetrConfig):
         super().__init__()
-        self.embed_dim = config.d_model
+        self.hidden_size = config.d_model
 
-        self.self_attn = DetrAttention(
-            embed_dim=self.embed_dim,
-            num_heads=config.decoder_attention_heads,
+        self.self_attn = DetrSelfAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
         )
         self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.activation_function]
-        self.activation_dropout = config.activation_dropout
 
-        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.encoder_attn = DetrAttention(
-            self.embed_dim,
-            config.decoder_attention_heads,
+        self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.encoder_attn = DetrCrossAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_attention_heads=config.decoder_attention_heads,
             dropout=config.attention_dropout,
         )
-        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
-        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+        self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-        object_queries: torch.Tensor | None = None,
-        query_position_embeddings: torch.Tensor | None = None,
+        spatial_position_embeddings: torch.Tensor | None = None,
+        object_queries_position_embeddings: torch.Tensor | None = None,
         encoder_hidden_states: torch.Tensor | None = None,
         encoder_attention_mask: torch.Tensor | None = None,
-        output_attentions: bool | None = False,
-    ):
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
         """
         Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-            object_queries (`torch.FloatTensor`, *optional*):
-                object_queries that are added to the hidden states
-                in the cross-attention layer.
-            query_position_embeddings (`torch.FloatTensor`, *optional*):
-                position embeddings that are added to the queries and keys
-                in the self-attention layer.
+            spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+                Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only
+                in the cross-attention layer (not to values).
+            object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
+                Position embeddings for the object query slots. In self-attention, these are added to both queries
+                and keys (not values). In cross-attention, these are added to queries only (not to keys or values).
             encoder_hidden_states (`torch.FloatTensor`):
-                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+                cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                 values.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
         """
         residual = hidden_states
 
         # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
-            object_queries=query_position_embeddings,
+            position_embeddings=object_queries_position_embeddings,
             attention_mask=attention_mask,
-            output_attentions=output_attentions,
+            **kwargs,
         )
 
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -684,17 +698,16 @@ class DetrDecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.self_attn_layer_norm(hidden_states)
 
         # Cross-Attention Block
-        cross_attn_weights = None
         if encoder_hidden_states is not None:
             residual = hidden_states
 
-            hidden_states, cross_attn_weights = self.encoder_attn(
+            hidden_states, _ = self.encoder_attn(
                 hidden_states=hidden_states,
-                object_queries=query_position_embeddings,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
-                spatial_position_embeddings=object_queries,
-                output_attentions=output_attentions,
+                position_embeddings=object_queries_position_embeddings,
+                encoder_position_embeddings=spatial_position_embeddings,
+                **kwargs,
             )
 
             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -703,19 +716,164 @@ class DetrDecoderLayer(GradientCheckpointingLayer):
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
 
-        outputs = (hidden_states,)
+        return hidden_states
+
+
+class DetrConvBlock(nn.Module):
+    """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""
+
+    def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+        self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
+        self.activation = ACT2FN[activation]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.activation(self.norm(self.conv(x)))
+
+
+class DetrFPNFusionStage(nn.Module):
+    """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""
+
+    def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
+        super().__init__()
+        self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
+        self.refine = DetrConvBlock(current_channels, output_channels, activation)
+
+    def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
+            fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)
+
+        Returns:
+            Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
+        """
+        fpn_features = self.fpn_adapter(fpn_features)
+        features = nn.functional.interpolate(features, size=fpn_features.shape[-2:], mode="nearest")
+        return self.refine(fpn_features + features)
+
 
-        if output_attentions:
-            outputs += (self_attn_weights, cross_attn_weights)
+class DetrMaskHeadSmallConv(nn.Module):
+    """
+    Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.
 
-        return outputs
+    Combines attention maps (spatial localization) with encoder features (semantics) and progressively
+    upsamples through multiple scales, fusing with FPN features for high-resolution detail.
+    """
+
+    def __init__(
+        self,
+        input_channels: int,
+        fpn_channels: list[int],
+        hidden_size: int,
+        activation_function: str = "relu",
+    ):
+        super().__init__()
+        if input_channels % 8 != 0:
+            raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")
+
+        self.conv1 = DetrConvBlock(input_channels, input_channels, activation_function)
+        self.conv2 = DetrConvBlock(input_channels, hidden_size // 2, activation_function)
+
+        # Progressive channel reduction: /2 -> /4 -> /8 -> /16
+        self.fpn_stages = nn.ModuleList(
+            [
+                DetrFPNFusionStage(fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function),
+                DetrFPNFusionStage(fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function),
+                DetrFPNFusionStage(fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function),
+            ]
+        )
+
+        self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)
+
+    def forward(
+        self,
+        features: torch.Tensor,
+        attention_masks: torch.Tensor,
+        fpn_features: list[torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Args:
+            features: Encoder output features, shape (batch_size, hidden_size, H, W)
+            attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
+            fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)
+
+        Returns:
+            Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
+        """
+        num_queries = attention_masks.shape[1]
+
+        # Expand to (batch_size * num_queries) dimension
+        features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
+        attention_masks = attention_masks.flatten(0, 1)
+        fpn_features = [
+            fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
+        ]
+
+        hidden_states = torch.cat([features, attention_masks], dim=1)
+        hidden_states = self.conv1(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
+            hidden_states = fpn_stage(hidden_states, fpn_feat)
+
+        return self.output_conv(hidden_states)
+
+
+class DetrMHAttentionMap(nn.Module):
+    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = dropout
+
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+    def forward(
+        self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
+    ):
+        query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
+        key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
+
+        query_states = self.q_proj(query_states).view(query_hidden_shape)
+        key_states = nn.functional.conv2d(
+            key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
+        ).view(key_hidden_shape)
+
+        batch_size, num_queries, num_heads, head_dim = query_states.shape
+        _, _, _, height, width = key_states.shape
+        query_shape = (batch_size * num_heads, num_queries, head_dim)
+        key_shape = (batch_size * num_heads, height * width, head_dim)
+        attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
+
+        query = query_states.transpose(1, 2).contiguous().view(query_shape)
+        key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
+
+        attn_weights = (
+            (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
+        )
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        return attn_weights
 
 
 @auto_docstring
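To make the shape flow of the rewritten mask head above concrete: encoder features are broadcast per query, concatenated with the cross-attention maps, then pushed through the three FPN fusion stages, each of which upsamples to the next FPN resolution while reducing the channel count. The sketch below uses assumed, ResNet-like FPN channel sizes (1024/512/256) and small spatial sizes; in the real model these tensors would come from the backbone, the projected encoder output, and `DetrMHAttentionMap`:

import torch
from transformers.models.detr.modeling_detr import DetrMaskHeadSmallConv

batch_size, num_queries, hidden_size, num_heads = 1, 4, 256, 8
mask_head = DetrMaskHeadSmallConv(
    input_channels=hidden_size + num_heads,  # 264 = encoder channels + attention-map channels, divisible by 8
    fpn_channels=[1024, 512, 256],           # assumed channel sizes of the three FPN feature maps
    hidden_size=hidden_size,
)

features = torch.randn(batch_size, hidden_size, 16, 16)                     # stand-in for projected encoder output
attention_masks = torch.randn(batch_size, num_queries, num_heads, 16, 16)   # stand-in for per-query attention maps
fpn_features = [
    torch.randn(batch_size, 1024, 32, 32),
    torch.randn(batch_size, 512, 64, 64),
    torch.randn(batch_size, 256, 128, 128),
]
masks = mask_head(features, attention_masks, fpn_features)
# masks: (batch_size * num_queries, 1, 128, 128) -> one mask-logit map per object query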
@@ -725,21 +883,36 @@ class DetrPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     input_modalities = ("image",)
     _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
+    supports_gradient_checkpointing = True
+    _supports_sdpa = True
+    _supports_flash_attn = True
+    _supports_attention_backend = True
+    _supports_flex_attn = True  # Uses create_bidirectional_masks for attention masking
+    _keys_to_ignore_on_load_unexpected = [
+        r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
+    ]
 
     @torch.no_grad()
     def _init_weights(self, module):
         std = self.config.init_std
         xavier_std = self.config.init_xavier_std
 
-        if isinstance(module, DetrMHAttentionMap):
-            init.zeros_(module.k_linear.bias)
-            init.zeros_(module.q_linear.bias)
-            init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
-            init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
+        if isinstance(module, DetrMaskHeadSmallConv):
+            # DetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
+            for m in module.modules():
+                if isinstance(m, nn.Conv2d):
+                    init.kaiming_uniform_(m.weight, a=1)
+                    if m.bias is not None:
+                        init.constant_(m.bias, 0)
+        elif isinstance(module, DetrMHAttentionMap):
+            init.zeros_(module.k_proj.bias)
+            init.zeros_(module.q_proj.bias)
+            init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
+            init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
         elif isinstance(module, DetrLearnedPositionEmbedding):
             init.uniform_(module.row_embeddings.weight)
             init.uniform_(module.column_embeddings.weight)
-        if isinstance(module, (nn.Linear, nn.Conv2d)):
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
@@ -755,47 +928,36 @@ class DetrPreTrainedModel(PreTrainedModel):
 
 class DetrEncoder(DetrPreTrainedModel):
     """
-    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
-    [`DetrEncoderLayer`].
-
-    The encoder updates the flattened feature map through multiple self-attention layers.
-
-    Small tweak for DETR:
-
-    - object_queries are added to the forward pass.
+    Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
+    [`DetrEncoderLayer`] modules.
 
     Args:
-        config: DetrConfig
+        config (`DetrConfig`): Model configuration object.
     """
 
+    _can_record_outputs = {"hidden_states": DetrEncoderLayer, "attentions": DetrSelfAttention}
+
     def __init__(self, config: DetrConfig):
         super().__init__(config)
 
         self.dropout = config.dropout
-        self.layerdrop = config.encoder_layerdrop
-
         self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)])
 
-        # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default
-
         # Initialize weights and apply final processing
         self.post_init()
 
+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         attention_mask=None,
-        object_queries=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
-    ):
+        spatial_position_embeddings=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
-
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
 
@@ -803,112 +965,67 @@ class DetrEncoder(DetrPreTrainedModel):
803
965
  - 0 for pixel features that are padding (i.e. **masked**).
804
966
 
805
967
  [What are attention masks?](../glossary#attention-mask)
806
-
807
- object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
808
- Object queries that are added to the queries in each self-attention layer.
809
-
810
- output_attentions (`bool`, *optional*):
811
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
812
- returned tensors for more detail.
813
- output_hidden_states (`bool`, *optional*):
814
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
815
- for more detail.
816
- return_dict (`bool`, *optional*):
817
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
968
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
969
+ Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
818
970
  """
819
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
820
- output_hidden_states = (
821
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
822
- )
823
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
824
-
825
971
  hidden_states = inputs_embeds
826
972
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
827
973
 
828
974
  # expand attention_mask
829
975
  if attention_mask is not None:
830
976
  # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
831
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
832
-
833
- encoder_states = () if output_hidden_states else None
834
- all_attentions = () if output_attentions else None
835
- for i, encoder_layer in enumerate(self.layers):
836
- if output_hidden_states:
837
- encoder_states = encoder_states + (hidden_states,)
838
- # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
839
- to_drop = False
840
- if self.training:
841
- dropout_probability = torch.rand([])
842
- if dropout_probability < self.layerdrop: # skip the layer
843
- to_drop = True
844
-
845
- if to_drop:
846
- layer_outputs = (None, None)
847
- else:
848
- # we add object_queries as extra input to the encoder_layer
849
- layer_outputs = encoder_layer(
850
- hidden_states,
851
- attention_mask,
852
- object_queries=object_queries,
853
- output_attentions=output_attentions,
854
- )
855
-
856
- hidden_states = layer_outputs[0]
857
-
858
- if output_attentions:
859
- all_attentions = all_attentions + (layer_outputs[1],)
860
-
861
- if output_hidden_states:
862
- encoder_states = encoder_states + (hidden_states,)
863
-
864
- if not return_dict:
865
- return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
866
- return BaseModelOutput(
867
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
868
- )
869
-
977
+ attention_mask = create_bidirectional_mask(
978
+ config=self.config,
979
+ input_embeds=inputs_embeds,
980
+ attention_mask=attention_mask,
981
+ )
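For readers unfamiliar with this helper, here is a minimal pure-PyTorch sketch of the padding-mask expansion performed above. The `expand_padding_mask` function is a hypothetical re-implementation for illustration, not the library's `create_bidirectional_mask`, whose dtype and broadcasting details may differ.

```python
import torch

# Hedged sketch: expand a [batch, seq_len] padding mask (1 = keep, 0 = pad) into the
# additive [batch, 1, seq_len, seq_len] bias that bidirectional self-attention expects.
def expand_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    batch_size, seq_len = attention_mask.shape
    # [batch, seq_len] -> [batch, 1, seq_len, seq_len]: every query row sees the same key mask
    expanded = attention_mask[:, None, None, :].to(dtype).expand(batch_size, 1, seq_len, seq_len)
    # kept positions contribute 0.0, padded positions the most negative representable value
    return (1.0 - expanded) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])  # one padded pixel feature
print(expand_padding_mask(mask, torch.float32).shape)  # torch.Size([1, 1, 4, 4])
```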
870
982
 
871
- class DetrDecoder(DetrPreTrainedModel):
872
- """
873
- Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].
983
+ for encoder_layer in self.layers:
984
+ # we add spatial_position_embeddings as extra input to the encoder_layer
985
+ hidden_states = encoder_layer(
986
+ hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
987
+ )
874
988
 
875
- The decoder updates the query embeddings through multiple self-attention and cross-attention layers.
989
+ return BaseModelOutput(last_hidden_state=hidden_states)
876
990
 
877
- Some small tweaks for DETR:
878
991
 
879
- - object_queries and query_position_embeddings are added to the forward pass.
880
- - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
992
+ class DetrDecoder(DetrPreTrainedModel):
993
+ """
994
+ Transformer decoder that refines a set of object queries. It is composed of a stack of [`DetrDecoderLayer`] modules,
995
+ which apply self-attention to the queries and cross-attention to the encoder's outputs.
881
996
 
882
997
  Args:
883
- config: DetrConfig
998
+ config (`DetrConfig`): Model configuration object.
884
999
  """
885
1000
 
1001
+ _can_record_outputs = {
1002
+ "hidden_states": DetrDecoderLayer,
1003
+ "attentions": DetrSelfAttention,
1004
+ "cross_attentions": DetrCrossAttention,
1005
+ }
1006
+
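As with the encoder, `_can_record_outputs` tells the `check_model_inputs` machinery which sub-module classes to hook when per-layer tensors are requested. Assuming the usual transformers calling convention still applies (with the flags now routed through `**kwargs`), usage would look roughly like the sketch below; the exact number of recorded entries depends on the hooked layer classes.

```python
import torch
from transformers import DetrModel

# Hedged usage sketch (requires network access and the timm ResNet-50 backbone).
model = DetrModel.from_pretrained("facebook/detr-resnet-50")
pixel_values = torch.randn(1, 3, 800, 800)
outputs = model(pixel_values, output_hidden_states=True, output_attentions=True)
print(len(outputs.decoder_hidden_states), outputs.decoder_attentions[0].shape)
```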
886
1007
  def __init__(self, config: DetrConfig):
887
1008
  super().__init__(config)
888
1009
  self.dropout = config.dropout
889
- self.layerdrop = config.decoder_layerdrop
890
1010
 
891
1011
  self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
892
1012
  # in DETR, the decoder uses layernorm after the last decoder layer output
893
1013
  self.layernorm = nn.LayerNorm(config.d_model)
894
1014
 
895
- self.gradient_checkpointing = False
896
1015
  # Initialize weights and apply final processing
897
1016
  self.post_init()
898
1017
 
1018
+ @check_model_inputs()
899
1019
  def forward(
900
1020
  self,
901
1021
  inputs_embeds=None,
902
1022
  attention_mask=None,
903
1023
  encoder_hidden_states=None,
904
1024
  encoder_attention_mask=None,
905
- object_queries=None,
906
- query_position_embeddings=None,
907
- output_attentions=None,
908
- output_hidden_states=None,
909
- return_dict=None,
910
- **kwargs,
911
- ):
1025
+ spatial_position_embeddings=None,
1026
+ object_queries_position_embeddings=None,
1027
+ **kwargs: Unpack[TransformersKwargs],
1028
+ ) -> DetrDecoderOutput:
912
1029
  r"""
913
1030
  Args:
914
1031
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -931,108 +1048,62 @@ class DetrDecoder(DetrPreTrainedModel):
931
1048
  - 1 for pixels that are real (i.e. **not masked**),
932
1049
  - 0 for pixels that are padding (i.e. **masked**).
933
1050
 
934
- object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
935
- Object queries that are added to the queries and keys in each cross-attention layer.
936
- query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
937
- , *optional*): Position embeddings that are added to the values and keys in each self-attention layer.
938
-
939
- output_attentions (`bool`, *optional*):
940
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
941
- returned tensors for more detail.
942
- output_hidden_states (`bool`, *optional*):
943
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
944
- for more detail.
945
- return_dict (`bool`, *optional*):
946
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1051
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1052
+ Spatial position embeddings (2D positional encodings from the encoder) that are added to the keys in each cross-attention layer.
1053
+ object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1054
+ Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.
947
1055
  """
948
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
949
- output_hidden_states = (
950
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
951
- )
952
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
953
1056
 
954
1057
  if inputs_embeds is not None:
955
1058
  hidden_states = inputs_embeds
956
- input_shape = inputs_embeds.size()[:-1]
957
-
958
- combined_attention_mask = None
959
1059
 
960
- if attention_mask is not None and combined_attention_mask is not None:
961
- # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
962
- combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask(
963
- attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1060
+ # expand decoder attention mask (for self-attention on object queries)
1061
+ if attention_mask is not None:
1062
+ # [batch_size, num_queries] -> [batch_size, 1, num_queries, num_queries]
1063
+ attention_mask = create_bidirectional_mask(
1064
+ config=self.config,
1065
+ input_embeds=inputs_embeds,
1066
+ attention_mask=attention_mask,
964
1067
  )
965
1068
 
966
- # expand encoder attention mask
1069
+ # expand encoder attention mask (for cross-attention on encoder outputs)
967
1070
  if encoder_hidden_states is not None and encoder_attention_mask is not None:
968
1071
  # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
969
- encoder_attention_mask = _prepare_4d_attention_mask(
970
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1072
+ encoder_attention_mask = create_bidirectional_mask(
1073
+ config=self.config,
1074
+ input_embeds=inputs_embeds,
1075
+ attention_mask=encoder_attention_mask,
1076
+ encoder_hidden_states=encoder_hidden_states,
971
1077
  )
972
1078
 
973
1079
  # optional intermediate hidden states
974
1080
  intermediate = () if self.config.auxiliary_loss else None
975
1081
 
976
1082
  # decoder layers
977
- all_hidden_states = () if output_hidden_states else None
978
- all_self_attns = () if output_attentions else None
979
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
980
1083
 
981
1084
  for idx, decoder_layer in enumerate(self.layers):
982
- # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
983
- if output_hidden_states:
984
- all_hidden_states += (hidden_states,)
985
- if self.training:
986
- dropout_probability = torch.rand([])
987
- if dropout_probability < self.layerdrop:
988
- continue
989
-
990
- layer_outputs = decoder_layer(
1085
+ hidden_states = decoder_layer(
991
1086
  hidden_states,
992
- combined_attention_mask,
993
- object_queries,
994
- query_position_embeddings,
1087
+ attention_mask,
1088
+ spatial_position_embeddings,
1089
+ object_queries_position_embeddings,
995
1090
  encoder_hidden_states, # as a positional argument for gradient checkpointing
996
1091
  encoder_attention_mask=encoder_attention_mask,
997
- output_attentions=output_attentions,
1092
+ **kwargs,
998
1093
  )
999
1094
 
1000
- hidden_states = layer_outputs[0]
1001
-
1002
1095
  if self.config.auxiliary_loss:
1003
1096
  hidden_states = self.layernorm(hidden_states)
1004
1097
  intermediate += (hidden_states,)
1005
1098
 
1006
- if output_attentions:
1007
- all_self_attns += (layer_outputs[1],)
1008
-
1009
- if encoder_hidden_states is not None:
1010
- all_cross_attentions += (layer_outputs[2],)
1011
-
1012
1099
  # finally, apply layernorm
1013
1100
  hidden_states = self.layernorm(hidden_states)
1014
1101
 
1015
- # add hidden states from the last decoder layer
1016
- if output_hidden_states:
1017
- all_hidden_states += (hidden_states,)
1018
-
1019
1102
  # stack intermediate decoder activations
1020
1103
  if self.config.auxiliary_loss:
1021
1104
  intermediate = torch.stack(intermediate)
1022
1105
 
1023
- if not return_dict:
1024
- return tuple(
1025
- v
1026
- for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
1027
- if v is not None
1028
- )
1029
- return DetrDecoderOutput(
1030
- last_hidden_state=hidden_states,
1031
- hidden_states=all_hidden_states,
1032
- attentions=all_self_attns,
1033
- cross_attentions=all_cross_attentions,
1034
- intermediate_hidden_states=intermediate,
1035
- )
1106
+ return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate)
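A small worked illustration of the auxiliary-loss bookkeeping above: with `config.auxiliary_loss` enabled, each decoder layer's layer-normalized output is collected in `intermediate` and stacked, yielding a tensor of shape `(num_decoder_layers, batch_size, num_queries, d_model)` that the auxiliary prediction heads can consume. The sizes below are hypothetical defaults.

```python
import torch

# Hedged shape check, assuming the DETR defaults of 6 decoder layers, 100 queries, d_model=256.
num_layers, batch_size, num_queries, d_model = 6, 2, 100, 256
intermediate = tuple(torch.randn(batch_size, num_queries, d_model) for _ in range(num_layers))
stacked = torch.stack(intermediate)
print(stacked.shape)  # torch.Size([6, 2, 100, 256])
```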
1036
1107
 
1037
1108
 
1038
1109
  @auto_docstring(
@@ -1045,15 +1116,16 @@ class DetrModel(DetrPreTrainedModel):
1045
1116
  def __init__(self, config: DetrConfig):
1046
1117
  super().__init__(config)
1047
1118
 
1048
- # Create backbone + positional encoding
1049
- backbone = DetrConvEncoder(config)
1050
- object_queries = build_position_encoding(config)
1051
- self.backbone = DetrConvModel(backbone, object_queries)
1052
-
1053
- # Create projection layer
1054
- self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
1119
+ self.backbone = DetrConvEncoder(config)
1055
1120
 
1121
+ if config.position_embedding_type == "sine":
1122
+ self.position_embedding = DetrSinePositionEmbedding(config.d_model // 2, normalize=True)
1123
+ elif config.position_embedding_type == "learned":
1124
+ self.position_embedding = DetrLearnedPositionEmbedding(config.d_model // 2)
1125
+ else:
1126
+ raise ValueError(f"Unsupported position_embedding_type: {config.position_embedding_type}")
1056
1127
  self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
1128
+ self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
1057
1129
 
1058
1130
  self.encoder = DetrEncoder(config)
1059
1131
  self.decoder = DetrDecoder(config)
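The `position_embedding_type` branch above is driven entirely by the config. A minimal sketch of selecting the two supported flavours follows; only the config is constructed here, since building the full model also requires the timm backbone.

```python
from transformers import DetrConfig

# Hedged sketch: "sine" is the default; "learned" switches to DetrLearnedPositionEmbedding.
sine_config = DetrConfig()
learned_config = DetrConfig(position_embedding_type="learned")
print(sine_config.position_embedding_type, learned_config.position_embedding_type)
```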
@@ -1062,46 +1134,49 @@ class DetrModel(DetrPreTrainedModel):
1062
1134
  self.post_init()
1063
1135
 
1064
1136
  def freeze_backbone(self):
1065
- for name, param in self.backbone.conv_encoder.model.named_parameters():
1137
+ for _, param in self.backbone.model.named_parameters():
1066
1138
  param.requires_grad_(False)
1067
1139
 
1068
1140
  def unfreeze_backbone(self):
1069
- for name, param in self.backbone.conv_encoder.model.named_parameters():
1141
+ for _, param in self.backbone.model.named_parameters():
1070
1142
  param.requires_grad_(True)
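A brief, hedged usage sketch of the two helpers above: freezing keeps the convolutional backbone fixed so that only the transformer and prediction heads receive gradients during fine-tuning (loading the checkpoint requires timm and network access).

```python
from transformers import DetrModel

model = DetrModel.from_pretrained("facebook/detr-resnet-50")
model.freeze_backbone()
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters with frozen backbone: {trainable}")
model.unfreeze_backbone()  # restore full fine-tuning
```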
1071
1143
 
1072
1144
  @auto_docstring
1145
+ @can_return_tuple
1073
1146
  def forward(
1074
1147
  self,
1075
- pixel_values: torch.FloatTensor,
1148
+ pixel_values: torch.FloatTensor | None = None,
1076
1149
  pixel_mask: torch.LongTensor | None = None,
1077
1150
  decoder_attention_mask: torch.FloatTensor | None = None,
1078
1151
  encoder_outputs: torch.FloatTensor | None = None,
1079
1152
  inputs_embeds: torch.FloatTensor | None = None,
1080
1153
  decoder_inputs_embeds: torch.FloatTensor | None = None,
1081
- output_attentions: bool | None = None,
1082
- output_hidden_states: bool | None = None,
1083
- return_dict: bool | None = None,
1084
- **kwargs,
1154
+ **kwargs: Unpack[TransformersKwargs],
1085
1155
  ) -> tuple[torch.FloatTensor] | DetrModelOutput:
1086
1156
  r"""
1087
1157
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1088
- Not used by default. Can be used to mask object queries.
1158
+ Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
1159
+
1160
+ - 1 for queries that are **not masked**,
1161
+ - 0 for queries that are **masked**.
1089
1162
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1090
1163
  Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1091
- can choose to directly pass a flattened representation of an image.
1164
+ can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
1092
1165
  decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1093
1166
  Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1094
- embedded representation.
1167
+ embedded representation. Useful for tasks that require custom query initialization.
1095
1168
 
1096
1169
  Examples:
1097
1170
 
1098
1171
  ```python
1099
1172
  >>> from transformers import AutoImageProcessor, DetrModel
1100
1173
  >>> from PIL import Image
1101
- >>> import requests
1174
+ >>> import httpx
1175
+ >>> from io import BytesIO
1102
1176
 
1103
1177
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1104
- >>> image = Image.open(requests.get(url, stream=True).raw)
1178
+ >>> with httpx.stream("GET", url) as response:
1179
+ ... image = Image.open(BytesIO(response.read()))
1105
1180
 
1106
1181
  >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
1107
1182
  >>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")
@@ -1118,79 +1193,77 @@ class DetrModel(DetrPreTrainedModel):
1118
1193
  >>> list(last_hidden_states.shape)
1119
1194
  [1, 100, 256]
1120
1195
  ```"""
1121
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1122
- output_hidden_states = (
1123
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1124
- )
1125
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1126
-
1127
- batch_size, num_channels, height, width = pixel_values.shape
1128
- device = pixel_values.device
1129
-
1130
- if pixel_mask is None:
1131
- pixel_mask = torch.ones(((batch_size, height, width)), device=device)
1132
-
1133
- # First, sent pixel_values + pixel_mask through Backbone to obtain the features
1134
- # pixel_values should be of shape (batch_size, num_channels, height, width)
1135
- # pixel_mask should be of shape (batch_size, height, width)
1136
- features, object_queries_list = self.backbone(pixel_values, pixel_mask)
1137
-
1138
- # get final feature map and downsampled mask
1139
- feature_map, mask = features[-1]
1140
-
1141
- if mask is None:
1142
- raise ValueError("Backbone does not return downsampled pixel mask")
1143
-
1144
- # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
1145
- projected_feature_map = self.input_projection(feature_map)
1146
-
1147
- # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1148
- # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
1149
- flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
1150
- object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
1151
-
1152
- flattened_mask = mask.flatten(1)
1196
+ if pixel_values is None and inputs_embeds is None:
1197
+ raise ValueError("You have to specify either pixel_values or inputs_embeds")
1198
+
1199
+ if inputs_embeds is None:
1200
+ batch_size, num_channels, height, width = pixel_values.shape
1201
+ device = pixel_values.device
1202
+
1203
+ if pixel_mask is None:
1204
+ pixel_mask = torch.ones((batch_size, height, width), device=device)
1205
+ vision_features = self.backbone(pixel_values, pixel_mask)
1206
+ feature_map, mask = vision_features[-1]
1207
+
1208
+ # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
1209
+ # Position embeddings are already flattened to (batch_size, sequence_length, hidden_size) format
1210
+ projected_feature_map = self.input_projection(feature_map)
1211
+ flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
1212
+ spatial_position_embeddings = self.position_embedding(
1213
+ shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
1214
+ )
1215
+ flattened_mask = mask.flatten(1)
1216
+ else:
1217
+ batch_size = inputs_embeds.shape[0]
1218
+ device = inputs_embeds.device
1219
+ flattened_features = inputs_embeds
1220
+ # When using inputs_embeds, we need to infer spatial dimensions for position embeddings
1221
+ # Assume square feature map
1222
+ seq_len = inputs_embeds.shape[1]
1223
+ feat_dim = int(seq_len**0.5)
1224
+ # Create position embeddings for the inferred spatial size
1225
+ spatial_position_embeddings = self.position_embedding(
1226
+ shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]),
1227
+ device=device,
1228
+ dtype=inputs_embeds.dtype,
1229
+ )
1230
+ # If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten.
1231
+ if pixel_mask is not None:
1232
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
1233
+ flattened_mask = mask.flatten(1)
1234
+ else:
1235
+ # If no mask provided, assume all positions are valid
1236
+ flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)
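A worked example of the shape inference in the `inputs_embeds` branch above, with hypothetical sizes: 1444 flattened tokens are assumed to come from a square 38x38 feature map, and any provided `pixel_mask` is down-sampled to that grid before flattening.

```python
import torch

batch_size, seq_len, d_model = 1, 1444, 256
inputs_embeds = torch.randn(batch_size, seq_len, d_model)
feat_dim = int(seq_len**0.5)                    # 38, assuming a square feature map
pixel_mask = torch.ones(batch_size, 800, 800)   # image-resolution mask (1 = real pixel)
mask = torch.nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
print(feat_dim, mask.flatten(1).shape)          # 38 torch.Size([1, 1444])
```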
1153
1237
 
1154
- # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
1155
- # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
1156
- # flattened_mask is a Tensor of shape (batch_size, height*width)
1157
1238
  if encoder_outputs is None:
1158
1239
  encoder_outputs = self.encoder(
1159
1240
  inputs_embeds=flattened_features,
1160
1241
  attention_mask=flattened_mask,
1161
- object_queries=object_queries,
1162
- output_attentions=output_attentions,
1163
- output_hidden_states=output_hidden_states,
1164
- return_dict=return_dict,
1165
- )
1166
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1167
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1168
- encoder_outputs = BaseModelOutput(
1169
- last_hidden_state=encoder_outputs[0],
1170
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1171
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1242
+ spatial_position_embeddings=spatial_position_embeddings,
1243
+ **kwargs,
1172
1244
  )
1173
1245
 
1174
- # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
1175
- query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
1176
- queries = torch.zeros_like(query_position_embeddings)
1246
+ object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
1247
+ batch_size, 1, 1
1248
+ )
1249
+
1250
+ # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
1251
+ if decoder_inputs_embeds is not None:
1252
+ queries = decoder_inputs_embeds
1253
+ else:
1254
+ queries = torch.zeros_like(object_queries_position_embeddings)
1177
1255
 
1178
1256
  # decoder outputs consist of (dec_features, dec_hidden, dec_attn)
1179
1257
  decoder_outputs = self.decoder(
1180
1258
  inputs_embeds=queries,
1181
- attention_mask=None,
1182
- object_queries=object_queries,
1183
- query_position_embeddings=query_position_embeddings,
1184
- encoder_hidden_states=encoder_outputs[0],
1259
+ attention_mask=decoder_attention_mask,
1260
+ spatial_position_embeddings=spatial_position_embeddings,
1261
+ object_queries_position_embeddings=object_queries_position_embeddings,
1262
+ encoder_hidden_states=encoder_outputs.last_hidden_state,
1185
1263
  encoder_attention_mask=flattened_mask,
1186
- output_attentions=output_attentions,
1187
- output_hidden_states=output_hidden_states,
1188
- return_dict=return_dict,
1264
+ **kwargs,
1189
1265
  )
1190
1266
 
1191
- if not return_dict:
1192
- return decoder_outputs + encoder_outputs
1193
-
1194
1267
  return DetrModelOutput(
1195
1268
  last_hidden_state=decoder_outputs.last_hidden_state,
1196
1269
  decoder_hidden_states=decoder_outputs.hidden_states,
@@ -1203,14 +1276,11 @@ class DetrModel(DetrPreTrainedModel):
1203
1276
  )
1204
1277
 
1205
1278
 
1206
- # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1207
1279
  class DetrMLPPredictionHead(nn.Module):
1208
1280
  """
1209
1281
  Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
1210
1282
  height and width of a bounding box w.r.t. an image.
1211
1283
 
1212
- Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1213
-
1214
1284
  """
1215
1285
 
1216
1286
  def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
@@ -1250,6 +1320,7 @@ class DetrForObjectDetection(DetrPreTrainedModel):
1250
1320
  self.post_init()
1251
1321
 
1252
1322
  @auto_docstring
1323
+ @can_return_tuple
1253
1324
  def forward(
1254
1325
  self,
1255
1326
  pixel_values: torch.FloatTensor,
@@ -1259,20 +1330,20 @@ class DetrForObjectDetection(DetrPreTrainedModel):
1259
1330
  inputs_embeds: torch.FloatTensor | None = None,
1260
1331
  decoder_inputs_embeds: torch.FloatTensor | None = None,
1261
1332
  labels: list[dict] | None = None,
1262
- output_attentions: bool | None = None,
1263
- output_hidden_states: bool | None = None,
1264
- return_dict: bool | None = None,
1265
- **kwargs,
1333
+ **kwargs: Unpack[TransformersKwargs],
1266
1334
  ) -> tuple[torch.FloatTensor] | DetrObjectDetectionOutput:
1267
1335
  r"""
1268
1336
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1269
- Not used by default. Can be used to mask object queries.
1337
+ Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
1338
+
1339
+ - 1 for queries that are **not masked**,
1340
+ - 0 for queries that are **masked**.
1270
1341
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1271
1342
  Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1272
- can choose to directly pass a flattened representation of an image.
1343
+ can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone.
1273
1344
  decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1274
1345
  Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1275
- embedded representation.
1346
+ embedded representation. Useful for tasks that require custom query initialization.
1276
1347
  labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1277
1348
  Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
1278
1349
  following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
@@ -1285,10 +1356,12 @@ class DetrForObjectDetection(DetrPreTrainedModel):
1285
1356
  >>> from transformers import AutoImageProcessor, DetrForObjectDetection
1286
1357
  >>> import torch
1287
1358
  >>> from PIL import Image
1288
- >>> import requests
1359
+ >>> import httpx
1360
+ >>> from io import BytesIO
1289
1361
 
1290
1362
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1291
- >>> image = Image.open(requests.get(url, stream=True).raw)
1363
+ >>> with httpx.stream("GET", url) as response:
1364
+ ... image = Image.open(BytesIO(response.read()))
1292
1365
 
1293
1366
  >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
1294
1367
  >>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
@@ -1314,7 +1387,6 @@ class DetrForObjectDetection(DetrPreTrainedModel):
1314
1387
  Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]
1315
1388
  Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]
1316
1389
  ```"""
1317
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1318
1390
 
1319
1391
  # First, send images through the DETR base model to obtain encoder + decoder outputs
1320
1392
  outputs = self.model(
@@ -1324,9 +1396,7 @@ class DetrForObjectDetection(DetrPreTrainedModel):
1324
1396
  encoder_outputs=encoder_outputs,
1325
1397
  inputs_embeds=inputs_embeds,
1326
1398
  decoder_inputs_embeds=decoder_inputs_embeds,
1327
- output_attentions=output_attentions,
1328
- output_hidden_states=output_hidden_states,
1329
- return_dict=return_dict,
1399
+ **kwargs,
1330
1400
  )
1331
1401
 
1332
1402
  sequence_output = outputs[0]
@@ -1339,20 +1409,13 @@ class DetrForObjectDetection(DetrPreTrainedModel):
1339
1409
  if labels is not None:
1340
1410
  outputs_class, outputs_coord = None, None
1341
1411
  if self.config.auxiliary_loss:
1342
- intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
1412
+ intermediate = outputs.intermediate_hidden_states
1343
1413
  outputs_class = self.class_labels_classifier(intermediate)
1344
1414
  outputs_coord = self.bbox_predictor(intermediate).sigmoid()
1345
1415
  loss, loss_dict, auxiliary_outputs = self.loss_function(
1346
1416
  logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
1347
1417
  )
1348
1418
 
1349
- if not return_dict:
1350
- if auxiliary_outputs is not None:
1351
- output = (logits, pred_boxes) + auxiliary_outputs + outputs
1352
- else:
1353
- output = (logits, pred_boxes) + outputs
1354
- return ((loss, loss_dict) + output) if loss is not None else output
1355
-
1356
1419
  return DetrObjectDetectionOutput(
1357
1420
  loss=loss,
1358
1421
  loss_dict=loss_dict,
@@ -1376,6 +1439,26 @@ class DetrForObjectDetection(DetrPreTrainedModel):
1376
1439
  """
1377
1440
  )
1378
1441
  class DetrForSegmentation(DetrPreTrainedModel):
1442
+ _checkpoint_conversion_mapping = {
1443
+ "bbox_attention.q_linear": "bbox_attention.q_proj",
1444
+ "bbox_attention.k_linear": "bbox_attention.k_proj",
1445
+ # Mask head refactor
1446
+ "mask_head.lay1": "mask_head.conv1.conv",
1447
+ "mask_head.gn1": "mask_head.conv1.norm",
1448
+ "mask_head.lay2": "mask_head.conv2.conv",
1449
+ "mask_head.gn2": "mask_head.conv2.norm",
1450
+ "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter",
1451
+ "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv",
1452
+ "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm",
1453
+ "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
1454
+ "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv",
1455
+ "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm",
1456
+ "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter",
1457
+ "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv",
1458
+ "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm",
1459
+ "mask_head.out_lay": "mask_head.output_conv",
1460
+ }
1461
+
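For orientation, a hedged sketch of what a mapping like `_checkpoint_conversion_mapping` is for: old checkpoint keys are rewritten to the refactored module names before loading. The `rename_checkpoint_keys` helper below is illustrative only; the actual conversion is performed by the library's loading machinery and may handle additional cases.

```python
import torch

def rename_checkpoint_keys(state_dict: dict, mapping: dict) -> dict:
    """Apply substring-based key renaming, e.g. old DETR mask-head names -> new module paths."""
    renamed = {}
    for key, tensor in state_dict.items():
        new_key = key
        for old, new in mapping.items():
            if old in new_key:
                new_key = new_key.replace(old, new)
        renamed[new_key] = tensor
    return renamed

old_state = {"mask_head.lay1.weight": torch.zeros(1), "mask_head.gn1.bias": torch.zeros(1)}
mapping = {"mask_head.lay1": "mask_head.conv1.conv", "mask_head.gn1": "mask_head.conv1.norm"}
print(sorted(rename_checkpoint_keys(old_state, mapping)))
# ['mask_head.conv1.conv.weight', 'mask_head.conv1.norm.bias']
```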
1379
1462
  def __init__(self, config: DetrConfig):
1380
1463
  super().__init__(config)
1381
1464
 
@@ -1384,19 +1467,21 @@ class DetrForSegmentation(DetrPreTrainedModel):
1384
1467
 
1385
1468
  # segmentation head
1386
1469
  hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
1387
- intermediate_channel_sizes = self.detr.model.backbone.conv_encoder.intermediate_channel_sizes
1470
+ intermediate_channel_sizes = self.detr.model.backbone.intermediate_channel_sizes
1388
1471
 
1389
1472
  self.mask_head = DetrMaskHeadSmallConv(
1390
- hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
1473
+ input_channels=hidden_size + number_of_heads,
1474
+ fpn_channels=intermediate_channel_sizes[::-1][-3:],
1475
+ hidden_size=hidden_size,
1476
+ activation_function=config.activation_function,
1391
1477
  )
1392
1478
 
1393
- self.bbox_attention = DetrMHAttentionMap(
1394
- hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
1395
- )
1479
+ self.bbox_attention = DetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)
1396
1480
  # Initialize weights and apply final processing
1397
1481
  self.post_init()
1398
1482
 
1399
1483
  @auto_docstring
1484
+ @can_return_tuple
1400
1485
  def forward(
1401
1486
  self,
1402
1487
  pixel_values: torch.FloatTensor,
@@ -1406,20 +1491,20 @@ class DetrForSegmentation(DetrPreTrainedModel):
1406
1491
  inputs_embeds: torch.FloatTensor | None = None,
1407
1492
  decoder_inputs_embeds: torch.FloatTensor | None = None,
1408
1493
  labels: list[dict] | None = None,
1409
- output_attentions: bool | None = None,
1410
- output_hidden_states: bool | None = None,
1411
- return_dict: bool | None = None,
1412
- **kwargs,
1494
+ **kwargs: Unpack[TransformersKwargs],
1413
1495
  ) -> tuple[torch.FloatTensor] | DetrSegmentationOutput:
1414
1496
  r"""
1415
1497
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1416
- Not used by default. Can be used to mask object queries.
1498
+ Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
1499
+
1500
+ - 1 for queries that are **not masked**,
1501
+ - 0 for queries that are **masked**.
1417
1502
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1418
- Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1419
- can choose to directly pass a flattened representation of an image.
1503
+ Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
1504
+ multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
1420
1505
  decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1421
1506
  Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1422
- embedded representation.
1507
+ embedded representation. Useful for tasks that require custom query initialization.
1423
1508
  labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1424
1509
  Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
1425
1510
  dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
@@ -1432,7 +1517,8 @@ class DetrForSegmentation(DetrPreTrainedModel):
1432
1517
 
1433
1518
  ```python
1434
1519
  >>> import io
1435
- >>> import requests
1520
+ >>> import httpx
1521
+ >>> from io import BytesIO
1436
1522
  >>> from PIL import Image
1437
1523
  >>> import torch
1438
1524
  >>> import numpy
@@ -1441,7 +1527,8 @@ class DetrForSegmentation(DetrPreTrainedModel):
1441
1527
  >>> from transformers.image_transforms import rgb_to_id
1442
1528
 
1443
1529
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1444
- >>> image = Image.open(requests.get(url, stream=True).raw)
1530
+ >>> with httpx.stream("GET", url) as response:
1531
+ ... image = Image.open(BytesIO(response.read()))
1445
1532
 
1446
1533
  >>> image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
1447
1534
  >>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
@@ -1466,83 +1553,77 @@ class DetrForSegmentation(DetrPreTrainedModel):
1466
1553
  5
1467
1554
  ```"""
1468
1555
 
1469
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1470
-
1471
1556
  batch_size, num_channels, height, width = pixel_values.shape
1472
1557
  device = pixel_values.device
1473
1558
 
1474
1559
  if pixel_mask is None:
1475
1560
  pixel_mask = torch.ones((batch_size, height, width), device=device)
1476
1561
 
1477
- # First, get list of feature maps and position embeddings
1478
- features, object_queries_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
1562
+ vision_features = self.detr.model.backbone(pixel_values, pixel_mask)
1563
+ feature_map, mask = vision_features[-1]
1479
1564
 
1480
- # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
1481
- feature_map, mask = features[-1]
1482
- batch_size, num_channels, height, width = feature_map.shape
1565
+ # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
1483
1566
  projected_feature_map = self.detr.model.input_projection(feature_map)
1484
-
1485
- # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1486
- # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
1487
1567
  flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
1488
- object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
1489
-
1568
+ spatial_position_embeddings = self.detr.model.position_embedding(
1569
+ shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
1570
+ )
1490
1571
  flattened_mask = mask.flatten(1)
1491
1572
 
1492
- # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
1493
- # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
1494
- # flattened_mask is a Tensor of shape (batch_size, height*width)
1495
1573
  if encoder_outputs is None:
1496
1574
  encoder_outputs = self.detr.model.encoder(
1497
1575
  inputs_embeds=flattened_features,
1498
1576
  attention_mask=flattened_mask,
1499
- object_queries=object_queries,
1500
- output_attentions=output_attentions,
1501
- output_hidden_states=output_hidden_states,
1502
- return_dict=return_dict,
1503
- )
1504
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1505
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1506
- encoder_outputs = BaseModelOutput(
1507
- last_hidden_state=encoder_outputs[0],
1508
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1509
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1577
+ spatial_position_embeddings=spatial_position_embeddings,
1578
+ **kwargs,
1510
1579
  )
1511
1580
 
1512
- # Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
1513
- query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
1581
+ object_queries_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
1514
1582
  batch_size, 1, 1
1515
1583
  )
1516
- queries = torch.zeros_like(query_position_embeddings)
1517
1584
 
1518
- # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
1585
+ # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
1586
+ if decoder_inputs_embeds is not None:
1587
+ queries = decoder_inputs_embeds
1588
+ else:
1589
+ queries = torch.zeros_like(object_queries_position_embeddings)
1590
+
1519
1591
  decoder_outputs = self.detr.model.decoder(
1520
1592
  inputs_embeds=queries,
1521
- attention_mask=None,
1522
- object_queries=object_queries,
1523
- query_position_embeddings=query_position_embeddings,
1524
- encoder_hidden_states=encoder_outputs[0],
1593
+ attention_mask=decoder_attention_mask,
1594
+ spatial_position_embeddings=spatial_position_embeddings,
1595
+ object_queries_position_embeddings=object_queries_position_embeddings,
1596
+ encoder_hidden_states=encoder_outputs.last_hidden_state,
1525
1597
  encoder_attention_mask=flattened_mask,
1526
- output_attentions=output_attentions,
1527
- output_hidden_states=output_hidden_states,
1528
- return_dict=return_dict,
1598
+ **kwargs,
1529
1599
  )
1530
1600
 
1531
1601
  sequence_output = decoder_outputs[0]
1532
1602
 
1533
- # Sixth, compute logits, pred_boxes and pred_masks
1534
1603
  logits = self.detr.class_labels_classifier(sequence_output)
1535
1604
  pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid()
1536
1605
 
1537
- memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
1538
- mask = flattened_mask.view(batch_size, height, width)
1606
+ height, width = feature_map.shape[-2:]
1607
+ memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
1608
+ batch_size, self.config.d_model, height, width
1609
+ )
1610
+ attention_mask = flattened_mask.view(batch_size, height, width)
1539
1611
 
1540
- # FIXME h_boxes takes the last one computed, keep this in mind
1541
- # important: we need to reverse the mask, since in the original implementation the mask works reversed
1542
- # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
1543
- bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)
1612
+ if attention_mask is not None:
1613
+ min_dtype = torch.finfo(memory.dtype).min
1614
+ attention_mask = torch.where(
1615
+ attention_mask.unsqueeze(1).unsqueeze(1),
1616
+ torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
1617
+ min_dtype,
1618
+ )
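A tiny, hedged illustration of the conversion above: a boolean keep/pad mask becomes an additive bias, 0.0 where attention is allowed and the most negative finite value where it is masked, so softmax drives masked positions to (numerically) zero weight.

```python
import torch

keep = torch.tensor([[True, True, False]])                 # last key position is padding
min_dtype = torch.finfo(torch.float32).min
bias = torch.where(keep.unsqueeze(1).unsqueeze(1), torch.tensor(0.0), torch.tensor(min_dtype))
print(bias.shape)      # torch.Size([1, 1, 1, 3])
print(bias[0, 0, 0])   # tensor([ 0.0000e+00,  0.0000e+00, -3.4028e+38])
```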
1544
1619
 
1545
- seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])
1620
+ bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)
1621
+
1622
+ seg_masks = self.mask_head(
1623
+ features=projected_feature_map,
1624
+ attention_masks=bbox_mask,
1625
+ fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
1626
+ )
1546
1627
 
1547
1628
  pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
1548
1629
 
@@ -1550,20 +1631,13 @@ class DetrForSegmentation(DetrPreTrainedModel):
1550
1631
  if labels is not None:
1551
1632
  outputs_class, outputs_coord = None, None
1552
1633
  if self.config.auxiliary_loss:
1553
- intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
1634
+ intermediate = decoder_outputs.intermediate_hidden_states
1554
1635
  outputs_class = self.detr.class_labels_classifier(intermediate)
1555
1636
  outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid()
1556
1637
  loss, loss_dict, auxiliary_outputs = self.loss_function(
1557
1638
  logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
1558
1639
  )
1559
1640
 
1560
- if not return_dict:
1561
- if auxiliary_outputs is not None:
1562
- output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
1563
- else:
1564
- output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
1565
- return ((loss, loss_dict) + output) if loss is not None else output
1566
-
1567
1641
  return DetrSegmentationOutput(
1568
1642
  loss=loss,
1569
1643
  loss_dict=loss_dict,
@@ -1581,119 +1655,6 @@ class DetrForSegmentation(DetrPreTrainedModel):
1581
1655
  )
1582
1656
 
1583
1657
 
1584
- def _expand(tensor, length: int):
1585
- return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
1586
-
1587
-
1588
- # taken from https://github.com/facebookresearch/detr/blob/master/models/segmentation.py
1589
- class DetrMaskHeadSmallConv(nn.Module):
1590
- """
1591
- Simple convolutional head, using group norm. Upsampling is done using a FPN approach
1592
- """
1593
-
1594
- def __init__(self, dim, fpn_dims, context_dim):
1595
- super().__init__()
1596
-
1597
- if dim % 8 != 0:
1598
- raise ValueError(
1599
- "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
1600
- " GroupNorm is set to 8"
1601
- )
1602
-
1603
- inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
1604
-
1605
- self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
1606
- self.gn1 = nn.GroupNorm(8, dim)
1607
- self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
1608
- self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
1609
- self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
1610
- self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
1611
- self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
1612
- self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
1613
- self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
1614
- self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
1615
- self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
1616
-
1617
- self.dim = dim
1618
-
1619
- self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
1620
- self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
1621
- self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
1622
-
1623
- for m in self.modules():
1624
- if isinstance(m, nn.Conv2d):
1625
- init.kaiming_uniform_(m.weight, a=1)
1626
- init.constant_(m.bias, 0)
1627
-
1628
- def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]):
1629
- # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
1630
- # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
1631
- # We expand the projected feature map to match the number of heads.
1632
- x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
1633
-
1634
- x = self.lay1(x)
1635
- x = self.gn1(x)
1636
- x = nn.functional.relu(x)
1637
- x = self.lay2(x)
1638
- x = self.gn2(x)
1639
- x = nn.functional.relu(x)
1640
-
1641
- cur_fpn = self.adapter1(fpns[0])
1642
- if cur_fpn.size(0) != x.size(0):
1643
- cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1644
- x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1645
- x = self.lay3(x)
1646
- x = self.gn3(x)
1647
- x = nn.functional.relu(x)
1648
-
1649
- cur_fpn = self.adapter2(fpns[1])
1650
- if cur_fpn.size(0) != x.size(0):
1651
- cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1652
- x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1653
- x = self.lay4(x)
1654
- x = self.gn4(x)
1655
- x = nn.functional.relu(x)
1656
-
1657
- cur_fpn = self.adapter3(fpns[2])
1658
- if cur_fpn.size(0) != x.size(0):
1659
- cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1660
- x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1661
- x = self.lay5(x)
1662
- x = self.gn5(x)
1663
- x = nn.functional.relu(x)
1664
-
1665
- x = self.out_lay(x)
1666
- return x
1667
-
1668
-
1669
- class DetrMHAttentionMap(nn.Module):
1670
- """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
1671
-
1672
- def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
1673
- super().__init__()
1674
- self.num_heads = num_heads
1675
- self.hidden_dim = hidden_dim
1676
- self.dropout = nn.Dropout(dropout)
1677
-
1678
- self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
1679
- self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
1680
-
1681
- self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
1682
-
1683
- def forward(self, q, k, mask: Tensor | None = None):
1684
- q = self.q_linear(q)
1685
- k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
1686
- queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
1687
- keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
1688
- weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
1689
-
1690
- if mask is not None:
1691
- weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
1692
- weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
1693
- weights = self.dropout(weights)
1694
- return weights
1695
-
1696
-
1697
1658
  __all__ = [
1698
1659
  "DetrForObjectDetection",
1699
1660
  "DetrForSegmentation",