transformers 5.0.0rc3__py3-none-any.whl → 5.1.0__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
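Before the file listing, here is a minimal sketch (not part of the published diff) of how a file-level "files changed" view like the one below can be reproduced locally. It assumes both wheels have already been downloaded, for example with `pip download transformers==5.0.0rc3 --no-deps`; the wheel paths are placeholders, and the script only reports added, removed, and modified archive members, not the per-file `+/-` line counts shown below.

```python
# Minimal sketch: compare the contents of two wheel files (wheels are zip archives).
# The wheel paths are assumptions -- substitute the files you actually downloaded.
import zipfile

OLD_WHEEL = "transformers-5.0.0rc3-py3-none-any.whl"  # assumed local path
NEW_WHEEL = "transformers-5.1.0-py3-none-any.whl"     # assumed local path


def wheel_files(path: str) -> dict[str, bytes]:
    """Map each file inside the wheel to its raw bytes."""
    with zipfile.ZipFile(path) as zf:
        return {name: zf.read(name) for name in zf.namelist() if not name.endswith("/")}


old, new = wheel_files(OLD_WHEEL), wheel_files(NEW_WHEEL)

added = sorted(set(new) - set(old))
removed = sorted(set(old) - set(new))
modified = sorted(name for name in set(old) & set(new) if old[name] != new[name])

print(f"Files changed ({len(added) + len(removed) + len(modified)})")
for name in added:
    print(f"  A {name}")
for name in removed:
    print(f"  D {name}")
for name in modified:
    print(f"  M {name}")
```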
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,9 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/conditional_detr/modular_conditional_detr.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_conditional_detr.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
1
7
  # Copyright 2022 Microsoft Research Asia and The HuggingFace Inc. team. All rights reserved.
2
8
  #
3
9
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,39 +17,33 @@
11
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
18
  # See the License for the specific language governing permissions and
13
19
  # limitations under the License.
14
- """PyTorch Conditional DETR model."""
15
-
16
20
  import math
21
+ from collections.abc import Callable
17
22
  from dataclasses import dataclass
18
23
 
19
24
  import torch
20
- from torch import Tensor, nn
25
+ from torch import nn
21
26
 
22
27
  from ... import initialization as init
23
28
  from ...activations import ACT2FN
24
- from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
29
+ from ...backbone_utils import load_backbone
30
+ from ...masking_utils import create_bidirectional_mask
25
31
  from ...modeling_layers import GradientCheckpointingLayer
26
32
  from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
27
- from ...modeling_utils import PreTrainedModel
28
- from ...utils import ModelOutput, auto_docstring, is_timm_available, logging, requires_backends
29
- from ...utils.backbone_utils import load_backbone
33
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
34
+ from ...processing_utils import Unpack
35
+ from ...pytorch_utils import compile_compatible_method_lru_cache
36
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring
37
+ from ...utils.generic import OutputRecorder, can_return_tuple, check_model_inputs
30
38
  from .configuration_conditional_detr import ConditionalDetrConfig
31
39
 
32
40
 
33
- if is_timm_available():
34
- from timm import create_model
35
-
36
-
37
- logger = logging.get_logger(__name__)
38
-
39
-
40
41
  @dataclass
41
42
  @auto_docstring(
42
43
  custom_intro="""
43
- Base class for outputs of the Conditional DETR decoder. This class adds one attribute to
44
- BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder activations, i.e. the output
45
- of each decoder layer, each of them gone through a layernorm. This is useful when training the model with auxiliary
46
- decoding losses.
44
+ Base class for outputs of the Conditional DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
45
+ namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
46
+ gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
47
47
  """
48
48
  )
49
49
  class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
@@ -60,16 +60,16 @@ class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
60
60
  """
61
61
 
62
62
  intermediate_hidden_states: torch.FloatTensor | None = None
63
+
63
64
  reference_points: tuple[torch.FloatTensor] | None = None
64
65
 
65
66
 
66
67
  @dataclass
67
68
  @auto_docstring(
68
69
  custom_intro="""
69
- Base class for outputs of the Conditional DETR encoder-decoder model. This class adds one attribute to
70
- Seq2SeqModelOutput, namely an optional stack of intermediate decoder activations, i.e. the output of each decoder
71
- layer, each of them gone through a layernorm. This is useful when training the model with auxiliary decoding
72
- losses.
70
+ Base class for outputs of the Conditional DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
71
+ namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
72
+ gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
73
73
  """
74
74
  )
75
75
  class ConditionalDetrModelOutput(Seq2SeqModelOutput):
@@ -84,6 +84,7 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
84
84
  """
85
85
 
86
86
  intermediate_hidden_states: torch.FloatTensor | None = None
87
+
87
88
  reference_points: tuple[torch.FloatTensor] | None = None
88
89
 
89
90
 
@@ -93,7 +94,6 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
93
94
  Output type of [`ConditionalDetrForObjectDetection`].
94
95
  """
95
96
  )
96
- # Copied from transformers.models.detr.modeling_detr.DetrObjectDetectionOutput with Detr->ConditionalDetr
97
97
  class ConditionalDetrObjectDetectionOutput(ModelOutput):
98
98
  r"""
99
99
  loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
@@ -137,7 +137,6 @@ class ConditionalDetrObjectDetectionOutput(ModelOutput):
137
137
  Output type of [`ConditionalDetrForSegmentation`].
138
138
  """
139
139
  )
140
- # Copied from transformers.models.detr.modeling_detr.DetrSegmentationOutput with Detr->ConditionalDetr
141
140
  class ConditionalDetrSegmentationOutput(ModelOutput):
142
141
  r"""
143
142
  loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
@@ -182,7 +181,6 @@ class ConditionalDetrSegmentationOutput(ModelOutput):
182
181
  encoder_attentions: tuple[torch.FloatTensor] | None = None
183
182
 
184
183
 
185
- # Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->ConditionalDetr
186
184
  class ConditionalDetrFrozenBatchNorm2d(nn.Module):
187
185
  """
188
186
  BatchNorm2d where the batch statistics and the affine parameters are fixed.
@@ -222,7 +220,6 @@ class ConditionalDetrFrozenBatchNorm2d(nn.Module):
222
220
  return x * scale + bias
223
221
 
224
222
 
225
- # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->ConditionalDetr
226
223
  def replace_batch_norm(model):
227
224
  r"""
228
225
  Recursively replace all `torch.nn.BatchNorm2d` with `ConditionalDetrFrozenBatchNorm2d`.
@@ -247,7 +244,6 @@ def replace_batch_norm(model):
247
244
  replace_batch_norm(module)
248
245
 
249
246
 
250
- # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->ConditionalDetr
251
247
  class ConditionalDetrConvEncoder(nn.Module):
252
248
  """
253
249
  Convolutional backbone, using either the AutoBackbone API or one from the timm library.
@@ -261,47 +257,25 @@ class ConditionalDetrConvEncoder(nn.Module):
261
257
 
262
258
  self.config = config
263
259
 
264
- # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
265
- if config.use_timm_backbone:
266
- # We default to values which were previously hard-coded. This enables configurability from the config
267
- # using backbone arguments, while keeping the default behavior the same.
268
- requires_backends(self, ["timm"])
269
- kwargs = getattr(config, "backbone_kwargs", {})
270
- kwargs = {} if kwargs is None else kwargs.copy()
271
- out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
272
- num_channels = kwargs.pop("in_chans", config.num_channels)
273
- if config.dilation:
274
- kwargs["output_stride"] = kwargs.get("output_stride", 16)
275
- backbone = create_model(
276
- config.backbone,
277
- pretrained=config.use_pretrained_backbone,
278
- features_only=True,
279
- out_indices=out_indices,
280
- in_chans=num_channels,
281
- **kwargs,
282
- )
283
- else:
284
- backbone = load_backbone(config)
260
+ backbone = load_backbone(config)
261
+ self.intermediate_channel_sizes = backbone.channels
285
262
 
286
263
  # replace batch norm by frozen batch norm
287
264
  with torch.no_grad():
288
265
  replace_batch_norm(backbone)
289
- self.model = backbone
290
- self.intermediate_channel_sizes = (
291
- self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
292
- )
293
266
 
294
- backbone_model_type = None
295
- if config.backbone is not None:
296
- backbone_model_type = config.backbone
297
- elif config.backbone_config is not None:
298
- backbone_model_type = config.backbone_config.model_type
299
- else:
300
- raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
267
+ # We used to load the backbone with the timm library directly instead of through the AutoBackbone API,
268
+ # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
269
+ is_timm_model = False
270
+ if hasattr(backbone, "_backbone"):
271
+ backbone = backbone._backbone
272
+ is_timm_model = True
273
+ self.model = backbone
301
274
 
275
+ backbone_model_type = config.backbone_config.model_type
302
276
  if "resnet" in backbone_model_type:
303
277
  for name, parameter in self.model.named_parameters():
304
- if config.use_timm_backbone:
278
+ if is_timm_model:
305
279
  if "layer2" not in name and "layer3" not in name and "layer4" not in name:
306
280
  parameter.requires_grad_(False)
307
281
  else:
@@ -310,7 +284,9 @@ class ConditionalDetrConvEncoder(nn.Module):
310
284
 
311
285
  def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
312
286
  # send pixel_values through the model to get list of feature maps
313
- features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
287
+ features = self.model(pixel_values)
288
+ if isinstance(features, dict):
289
+ features = features.feature_maps
314
290
 
315
291
  out = []
316
292
  for feature_map in features:
@@ -320,66 +296,58 @@ class ConditionalDetrConvEncoder(nn.Module):
320
296
  return out
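
# Illustrative sketch (not part of the diff above): the "frozen batch norm" pattern that the conv
# encoder applies to the loaded backbone via replace_batch_norm. This is a minimal standalone
# version assuming only torch; the names FrozenBN2d and freeze_batch_norm are hypothetical, not
# the library's.
import torch
from torch import nn


class FrozenBN2d(nn.Module):
    """BatchNorm2d where the statistics and affine parameters are fixed (stored as buffers)."""

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fold the frozen statistics into a per-channel scale and bias.
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        bias = self.bias - self.running_mean * scale
        return x * scale[None, :, None, None] + bias[None, :, None, None]


def freeze_batch_norm(module: nn.Module) -> None:
    """Recursively swap every nn.BatchNorm2d child for a FrozenBN2d with the same statistics."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            frozen = FrozenBN2d(child.num_features, eps=child.eps)
            frozen.weight.copy_(child.weight.detach())
            frozen.bias.copy_(child.bias.detach())
            frozen.running_mean.copy_(child.running_mean)
            frozen.running_var.copy_(child.running_var)
            setattr(module, name, frozen)
        else:
            freeze_batch_norm(child)


# Usage: after the call, no nn.BatchNorm2d layers remain in the backbone stub.
backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
with torch.no_grad():
    freeze_batch_norm(backbone)
print(backbone(torch.randn(1, 3, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])
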
321
297
 
322
298
 
323
- # Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->ConditionalDetr
324
- class ConditionalDetrConvModel(nn.Module):
325
- """
326
- This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
327
- """
328
-
329
- def __init__(self, conv_encoder, position_embedding):
330
- super().__init__()
331
- self.conv_encoder = conv_encoder
332
- self.position_embedding = position_embedding
333
-
334
- def forward(self, pixel_values, pixel_mask):
335
- # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
336
- out = self.conv_encoder(pixel_values, pixel_mask)
337
- pos = []
338
- for feature_map, mask in out:
339
- # position encoding
340
- pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
341
-
342
- return out, pos
343
-
344
-
345
299
  class ConditionalDetrSinePositionEmbedding(nn.Module):
346
300
  """
347
301
  This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
348
302
  need paper, generalized to work on images.
349
303
  """
350
304
 
351
- def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
305
+ def __init__(
306
+ self,
307
+ num_position_features: int = 64,
308
+ temperature: int = 10000,
309
+ normalize: bool = False,
310
+ scale: float | None = None,
311
+ ):
352
312
  super().__init__()
353
- self.embedding_dim = embedding_dim
354
- self.temperature = temperature
355
- self.normalize = normalize
356
313
  if scale is not None and normalize is False:
357
314
  raise ValueError("normalize should be True if scale is passed")
358
- if scale is None:
359
- scale = 2 * math.pi
360
- self.scale = scale
315
+ self.num_position_features = num_position_features
316
+ self.temperature = temperature
317
+ self.normalize = normalize
318
+ self.scale = 2 * math.pi if scale is None else scale
361
319
 
362
- def forward(self, pixel_values, pixel_mask):
363
- if pixel_mask is None:
364
- raise ValueError("No pixel mask provided")
365
- y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
366
- x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
320
+ @compile_compatible_method_lru_cache(maxsize=1)
321
+ def forward(
322
+ self,
323
+ shape: torch.Size,
324
+ device: torch.device | str,
325
+ dtype: torch.dtype,
326
+ mask: torch.Tensor | None = None,
327
+ ) -> torch.Tensor:
328
+ if mask is None:
329
+ mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
330
+ y_embed = mask.cumsum(1, dtype=dtype)
331
+ x_embed = mask.cumsum(2, dtype=dtype)
367
332
  if self.normalize:
368
- y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
369
- x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
333
+ eps = 1e-6
334
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
335
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
370
336
 
371
- dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
372
- dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
337
+ dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
338
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)
373
339
 
374
340
  pos_x = x_embed[:, :, :, None] / dim_t
375
341
  pos_y = y_embed[:, :, :, None] / dim_t
376
342
  pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
377
343
  pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
378
344
  pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
345
+ # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
346
+ # expected by the encoder
347
+ pos = pos.flatten(2).permute(0, 2, 1)
379
348
  return pos
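
# Illustrative sketch (not part of the diff above): how a 2D sine position embedding like the one
# returned above can be built and flattened to the (batch_size, sequence_length, hidden_size)
# layout the encoder expects. Standalone and hedged: assumes only torch; sine_pos_2d is a
# hypothetical helper, not the library's API.
import math
import torch


def sine_pos_2d(batch: int, height: int, width: int, num_pos_feats: int = 64, temperature: float = 10000.0):
    # Cumulative sums over an all-ones mask give 1..H and 1..W coordinates per position.
    ones = torch.ones(batch, height, width)
    y_embed = ones.cumsum(1)
    x_embed = ones.cumsum(2)
    # Normalize to [0, 2*pi], mirroring the `normalize=True` branch above.
    eps = 1e-6
    y_embed = y_embed / (y_embed[:, -1:, :] + eps) * (2 * math.pi)
    x_embed = x_embed / (x_embed[:, :, -1:] + eps) * (2 * math.pi)

    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)

    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    pos = torch.cat((pos_y, pos_x), dim=3)  # (batch, H, W, 2 * num_pos_feats)
    return pos.flatten(1, 2)                # (batch, H * W, 2 * num_pos_feats)


print(sine_pos_2d(2, 8, 8).shape)  # torch.Size([2, 64, 128])
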
380
349
 
381
350
 
382
- # Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding with Detr->ConditionalDetr
383
351
  class ConditionalDetrLearnedPositionEmbedding(nn.Module):
384
352
  """
385
353
  This module learns positional embeddings up to a fixed maximum size.
@@ -390,354 +358,385 @@ class ConditionalDetrLearnedPositionEmbedding(nn.Module):
390
358
  self.row_embeddings = nn.Embedding(50, embedding_dim)
391
359
  self.column_embeddings = nn.Embedding(50, embedding_dim)
392
360
 
393
- def forward(self, pixel_values, pixel_mask=None):
394
- height, width = pixel_values.shape[-2:]
395
- width_values = torch.arange(width, device=pixel_values.device)
396
- height_values = torch.arange(height, device=pixel_values.device)
361
+ @compile_compatible_method_lru_cache(maxsize=1)
362
+ def forward(
363
+ self,
364
+ shape: torch.Size,
365
+ device: torch.device | str,
366
+ dtype: torch.dtype,
367
+ mask: torch.Tensor | None = None,
368
+ ):
369
+ height, width = shape[-2:]
370
+ width_values = torch.arange(width, device=device)
371
+ height_values = torch.arange(height, device=device)
397
372
  x_emb = self.column_embeddings(width_values)
398
373
  y_emb = self.row_embeddings(height_values)
399
374
  pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
400
375
  pos = pos.permute(2, 0, 1)
401
376
  pos = pos.unsqueeze(0)
402
- pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
377
+ pos = pos.repeat(shape[0], 1, 1, 1)
378
+ # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
379
+ # expected by the encoder
380
+ pos = pos.flatten(2).permute(0, 2, 1)
403
381
  return pos
404
382
 
405
383
 
406
- # Copied from transformers.models.detr.modeling_detr.build_position_encoding with Detr->ConditionalDetr
407
- def build_position_encoding(config):
408
- n_steps = config.d_model // 2
409
- if config.position_embedding_type == "sine":
410
- # TODO find a better way of exposing other arguments
411
- position_embedding = ConditionalDetrSinePositionEmbedding(n_steps, normalize=True)
412
- elif config.position_embedding_type == "learned":
413
- position_embedding = ConditionalDetrLearnedPositionEmbedding(n_steps)
414
- else:
415
- raise ValueError(f"Not supported {config.position_embedding_type}")
384
+ def eager_attention_forward(
385
+ module: nn.Module,
386
+ query: torch.Tensor,
387
+ key: torch.Tensor,
388
+ value: torch.Tensor,
389
+ attention_mask: torch.Tensor | None,
390
+ scaling: float | None = None,
391
+ dropout: float = 0.0,
392
+ **kwargs: Unpack[TransformersKwargs],
393
+ ):
394
+ if scaling is None:
395
+ scaling = query.size(-1) ** -0.5
416
396
 
417
- return position_embedding
397
+ # Take the dot product between "query" and "key" to get the raw attention scores.
398
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
418
399
 
400
+ if attention_mask is not None:
401
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
402
+ attn_weights = attn_weights + attention_mask
419
403
 
420
- # function to generate sine positional embedding for 2d coordinates
421
- def gen_sine_position_embeddings(pos_tensor, d_model):
422
- scale = 2 * math.pi
423
- dim = d_model // 2
424
- dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
425
- dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
426
- x_embed = pos_tensor[:, :, 0] * scale
427
- y_embed = pos_tensor[:, :, 1] * scale
428
- pos_x = x_embed[:, :, None] / dim_t
429
- pos_y = y_embed[:, :, None] / dim_t
430
- pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
431
- pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
432
- pos = torch.cat((pos_y, pos_x), dim=2)
433
- return pos.to(pos_tensor.dtype)
404
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
405
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
434
406
 
407
+ attn_output = torch.matmul(attn_weights, value)
408
+ attn_output = attn_output.transpose(1, 2).contiguous()
435
409
 
436
- def inverse_sigmoid(x, eps=1e-5):
437
- x = x.clamp(min=0, max=1)
438
- x1 = x.clamp(min=eps)
439
- x2 = (1 - x).clamp(min=eps)
440
- return torch.log(x1 / x2)
410
+ return attn_output, attn_weights
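
# Illustrative sketch (not part of the diff above): the tensor layout and masking convention an
# eager attention function like the one above works with -- queries/keys/values laid out as
# (batch, heads, seq, head_dim) and an *additive* float mask (0 for visible positions, a large
# negative value for masked ones). Standalone, torch only; the shapes below are arbitrary.
import torch
from torch import nn

batch, heads, q_len, kv_len, head_dim = 2, 8, 5, 7, 32
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Additive mask: here the last two key positions are treated as padding.
mask = torch.zeros(batch, 1, q_len, kv_len)
mask[..., -2:] = torch.finfo(mask.dtype).min

scaling = head_dim**-0.5
weights = torch.matmul(query, key.transpose(2, 3)) * scaling + mask
weights = nn.functional.softmax(weights, dim=-1)
out = torch.matmul(weights, value).transpose(1, 2)  # (batch, q_len, heads, head_dim)
print(out.reshape(batch, q_len, -1).shape)          # torch.Size([2, 5, 256])
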
441
411
 
442
412
 
443
- # Copied from transformers.models.detr.modeling_detr.DetrAttention
444
- class DetrAttention(nn.Module):
413
+ class ConditionalDetrSelfAttention(nn.Module):
445
414
  """
446
- Multi-headed attention from 'Attention Is All You Need' paper.
415
+ Multi-headed self-attention from 'Attention Is All You Need' paper.
447
416
 
448
- Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
417
+ In Conditional DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
449
418
  """
450
419
 
451
420
  def __init__(
452
421
  self,
453
- embed_dim: int,
454
- num_heads: int,
422
+ config: ConditionalDetrConfig,
423
+ hidden_size: int,
424
+ num_attention_heads: int,
455
425
  dropout: float = 0.0,
456
426
  bias: bool = True,
457
427
  ):
458
428
  super().__init__()
459
- self.embed_dim = embed_dim
460
- self.num_heads = num_heads
461
- self.dropout = dropout
462
- self.head_dim = embed_dim // num_heads
463
- if self.head_dim * num_heads != self.embed_dim:
464
- raise ValueError(
465
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
466
- f" {num_heads})."
467
- )
429
+ self.config = config
430
+ self.head_dim = hidden_size // num_attention_heads
468
431
  self.scaling = self.head_dim**-0.5
432
+ self.attention_dropout = dropout
433
+ self.is_causal = False
469
434
 
470
- self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
471
- self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
472
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
473
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
474
-
475
- def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
476
- return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
477
-
478
- def with_pos_embed(self, tensor: torch.Tensor, object_queries: Tensor | None):
479
- return tensor if object_queries is None else tensor + object_queries
435
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
436
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
437
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
438
+ self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
480
439
 
481
440
  def forward(
482
441
  self,
483
442
  hidden_states: torch.Tensor,
484
443
  attention_mask: torch.Tensor | None = None,
485
- object_queries: torch.Tensor | None = None,
486
- key_value_states: torch.Tensor | None = None,
487
- spatial_position_embeddings: torch.Tensor | None = None,
488
- output_attentions: bool = False,
489
- ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
490
- """Input shape: Batch x Time x Channel"""
491
- # if key_value_states are provided this layer is used as a cross-attention layer
492
- # for the decoder
493
- is_cross_attention = key_value_states is not None
494
- batch_size, target_len, embed_dim = hidden_states.size()
495
-
496
- # add position embeddings to the hidden states before projecting to queries and keys
497
- if object_queries is not None:
498
- hidden_states_original = hidden_states
499
- hidden_states = self.with_pos_embed(hidden_states, object_queries)
500
-
501
- # add key-value position embeddings to the key value states
502
- if spatial_position_embeddings is not None:
503
- key_value_states_original = key_value_states
504
- key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
505
-
506
- # get query proj
507
- query_states = self.q_proj(hidden_states) * self.scaling
508
- # get key, value proj
509
- if is_cross_attention:
510
- # cross_attentions
511
- key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
512
- value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
513
- else:
514
- # self_attention
515
- key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
516
- value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
444
+ position_embeddings: torch.Tensor | None = None,
445
+ **kwargs: Unpack[TransformersKwargs],
446
+ ) -> tuple[torch.Tensor, torch.Tensor]:
447
+ """
448
+ Position embeddings are added to both queries and keys (but not values).
449
+ """
450
+ input_shape = hidden_states.shape[:-1]
451
+ hidden_shape = (*input_shape, -1, self.head_dim)
517
452
 
518
- proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
519
- query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
520
- key_states = key_states.view(*proj_shape)
521
- value_states = value_states.view(*proj_shape)
453
+ query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
522
454
 
523
- source_len = key_states.size(1)
455
+ query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
456
+ key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
457
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
524
458
 
525
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
459
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
460
+ self.config._attn_implementation, eager_attention_forward
461
+ )
526
462
 
527
- if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
528
- raise ValueError(
529
- f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
530
- f" {attn_weights.size()}"
531
- )
463
+ attn_output, attn_weights = attention_interface(
464
+ self,
465
+ query_states,
466
+ key_states,
467
+ value_states,
468
+ attention_mask,
469
+ dropout=0.0 if not self.training else self.attention_dropout,
470
+ scaling=self.scaling,
471
+ **kwargs,
472
+ )
473
+
474
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
475
+ attn_output = self.o_proj(attn_output)
476
+ return attn_output, attn_weights
532
477
 
533
- if attention_mask is not None:
534
- if attention_mask.size() != (batch_size, 1, target_len, source_len):
535
- raise ValueError(
536
- f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
537
- f" {attention_mask.size()}"
538
- )
539
- if attention_mask.dtype == torch.bool:
540
- attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
541
- attention_mask, -torch.inf
542
- )
543
- attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
544
- attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
545
-
546
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
547
-
548
- if output_attentions:
549
- # this operation is a bit awkward, but it's required to
550
- # make sure that attn_weights keeps its gradient.
551
- # In order to do so, attn_weights have to reshaped
552
- # twice and have to be reused in the following
553
- attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
554
- attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
555
- else:
556
- attn_weights_reshaped = None
557
478
 
558
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
479
+ class ConditionalDetrDecoderSelfAttention(nn.Module):
480
+ """
481
+ Multi-headed self-attention for Conditional DETR decoder layers.
559
482
 
560
- attn_output = torch.bmm(attn_probs, value_states)
483
+ This attention module handles separate content and position projections, which are then combined
484
+ before applying standard self-attention. Position embeddings are added to both queries and keys.
485
+ """
561
486
 
562
- if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
563
- raise ValueError(
564
- f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
565
- f" {attn_output.size()}"
566
- )
487
+ def __init__(
488
+ self,
489
+ config: ConditionalDetrConfig,
490
+ hidden_size: int,
491
+ num_attention_heads: int,
492
+ dropout: float = 0.0,
493
+ ):
494
+ super().__init__()
495
+ self.config = config
496
+ self.hidden_size = hidden_size
497
+ self.head_dim = hidden_size // num_attention_heads
498
+ self.scaling = self.head_dim**-0.5
499
+ self.attention_dropout = dropout
500
+ self.is_causal = False
501
+
502
+ # Content and position projections
503
+ self.q_content_proj = nn.Linear(hidden_size, hidden_size)
504
+ self.q_pos_proj = nn.Linear(hidden_size, hidden_size)
505
+ self.k_content_proj = nn.Linear(hidden_size, hidden_size)
506
+ self.k_pos_proj = nn.Linear(hidden_size, hidden_size)
507
+ self.v_proj = nn.Linear(hidden_size, hidden_size)
508
+ self.o_proj = nn.Linear(hidden_size, hidden_size)
509
+
510
+ def forward(
511
+ self,
512
+ hidden_states: torch.Tensor,
513
+ query_position_embeddings: torch.Tensor,
514
+ attention_mask: torch.Tensor | None = None,
515
+ **kwargs: Unpack[TransformersKwargs],
516
+ ) -> tuple[torch.Tensor, torch.Tensor]:
517
+ """
518
+ Args:
519
+ hidden_states (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
520
+ Input hidden states from the decoder layer.
521
+ query_position_embeddings (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
522
+ Position embeddings for queries and keys. Required (unlike standard attention). Processed through
523
+ separate position projections (`q_pos_proj`, `k_pos_proj`) and added to content projections.
524
+ attention_mask (`torch.Tensor` of shape `(batch_size, 1, num_queries, num_queries)`, *optional*):
525
+ Attention mask to avoid attending to padding tokens.
526
+ """
527
+ input_shape = hidden_states.shape[:-1]
528
+ hidden_shape = (*input_shape, -1, self.head_dim)
529
+
530
+ query_states = (
531
+ (self.q_content_proj(hidden_states) + self.q_pos_proj(query_position_embeddings))
532
+ .view(hidden_shape)
533
+ .transpose(1, 2)
534
+ )
535
+ key_states = (
536
+ (self.k_content_proj(hidden_states) + self.k_pos_proj(query_position_embeddings))
537
+ .view(hidden_shape)
538
+ .transpose(1, 2)
539
+ )
540
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
567
541
 
568
- attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
569
- attn_output = attn_output.transpose(1, 2)
570
- attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
542
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
543
+ self.config._attn_implementation, eager_attention_forward
544
+ )
571
545
 
572
- attn_output = self.out_proj(attn_output)
546
+ attn_output, attn_weights = attention_interface(
547
+ self,
548
+ query_states,
549
+ key_states,
550
+ value_states,
551
+ attention_mask,
552
+ dropout=0.0 if not self.training else self.attention_dropout,
553
+ scaling=self.scaling,
554
+ **kwargs,
555
+ )
573
556
 
574
- return attn_output, attn_weights_reshaped
557
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
558
+ attn_output = self.o_proj(attn_output)
559
+ return attn_output, attn_weights
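
# Illustrative sketch (not part of the diff above): the "separate content and position projections"
# idea used by the decoder self-attention above. Queries/keys are built as content_proj(x) +
# pos_proj(p) before ordinary attention, while values see only content. Standalone sketch; the
# module and parameter names are hypothetical, not the library's.
import torch
from torch import nn
from torch.nn import functional as F


class ContentPosSelfAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.q_content = nn.Linear(hidden_size, hidden_size)
        self.q_pos = nn.Linear(hidden_size, hidden_size)
        self.k_content = nn.Linear(hidden_size, hidden_size)
        self.k_pos = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, hidden_size)
        self.o = nn.Linear(hidden_size, hidden_size)

    def _split(self, t: torch.Tensor) -> torch.Tensor:
        batch, seq, _ = t.shape
        return t.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        query = self._split(self.q_content(x) + self.q_pos(pos))  # position enters queries ...
        key = self._split(self.k_content(x) + self.k_pos(pos))    # ... and keys,
        value = self._split(self.v(x))                            # but not values.
        out = F.scaled_dot_product_attention(query, key, value)
        batch, seq, hidden = x.shape
        return self.o(out.transpose(1, 2).reshape(batch, seq, hidden))


layer = ContentPosSelfAttention(hidden_size=256, num_heads=8)
queries, query_pos = torch.randn(2, 300, 256), torch.randn(2, 300, 256)
print(layer(queries, query_pos).shape)  # torch.Size([2, 300, 256])
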
575
560
 
576
561
 
577
- class ConditionalDetrAttention(nn.Module):
562
+ class ConditionalDetrDecoderCrossAttention(nn.Module):
578
563
  """
579
- Cross-Attention used in Conditional DETR 'Conditional DETR for Fast Training Convergence' paper.
564
+ Multi-headed cross-attention for Conditional DETR decoder layers.
580
565
 
581
- The key q_proj, k_proj, v_proj are defined outside the attention. This attention allows the dim of q, k to be
582
- different to v.
566
+ This attention module handles the special cross-attention logic in Conditional DETR:
567
+ - Separate content and position projections for queries and keys
568
+ - Concatenation of query sine embeddings with queries (doubling query dimension)
569
+ - Concatenation of key position embeddings with keys (doubling key dimension)
570
+ - Output dimension remains hidden_size despite doubled input dimensions
583
571
  """
584
572
 
585
573
  def __init__(
586
574
  self,
587
- embed_dim: int,
588
- out_dim: int,
589
- num_heads: int,
575
+ config: ConditionalDetrConfig,
576
+ hidden_size: int,
577
+ num_attention_heads: int,
590
578
  dropout: float = 0.0,
591
- bias: bool = True,
592
579
  ):
593
580
  super().__init__()
594
- self.embed_dim = embed_dim
595
- self.out_dim = out_dim
596
- self.num_heads = num_heads
597
- self.dropout = dropout
598
- self.head_dim = embed_dim // num_heads
599
- if self.head_dim * num_heads != self.embed_dim:
600
- raise ValueError(
601
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
602
- f" {num_heads})."
603
- )
604
- # head dimension of values
605
- self.v_head_dim = out_dim // num_heads
606
- if self.v_head_dim * num_heads != self.out_dim:
607
- raise ValueError(
608
- f"out_dim must be divisible by num_heads (got `out_dim`: {self.out_dim} and `num_heads`: {num_heads})."
609
- )
610
- self.scaling = self.head_dim**-0.5
611
-
612
- self.out_proj = nn.Linear(out_dim, out_dim, bias=bias)
613
-
614
- def _qk_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
615
- return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
616
-
617
- def _v_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
618
- return tensor.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()
581
+ self.config = config
582
+ self.hidden_size = hidden_size
583
+ self.num_attention_heads = num_attention_heads
584
+ self.head_dim = hidden_size // num_attention_heads
585
+ self.attention_dropout = dropout
586
+ self.is_causal = False
587
+
588
+ # Content and position projections
589
+ self.q_content_proj = nn.Linear(hidden_size, hidden_size)
590
+ self.q_pos_proj = nn.Linear(hidden_size, hidden_size)
591
+ self.k_content_proj = nn.Linear(hidden_size, hidden_size)
592
+ self.k_pos_proj = nn.Linear(hidden_size, hidden_size)
593
+ self.v_proj = nn.Linear(hidden_size, hidden_size)
594
+ self.q_pos_sine_proj = nn.Linear(hidden_size, hidden_size)
595
+
596
+ # Output projection: only queries/keys are doubled by the concatenation; values keep head_dim, so the attention output (and o_proj) stays at hidden_size
597
+ self.o_proj = nn.Linear(hidden_size, hidden_size)
598
+
599
+ # Compute scaling for expanded head_dim (q and k have doubled dimensions after concatenation)
600
+ # This matches the original Conditional DETR implementation where embed_dim * 2 is used
601
+ expanded_head_dim = (hidden_size * 2) // num_attention_heads
602
+ self.scaling = expanded_head_dim**-0.5
619
603
 
620
604
  def forward(
621
605
  self,
622
606
  hidden_states: torch.Tensor,
607
+ encoder_hidden_states: torch.Tensor,
608
+ query_sine_embed: torch.Tensor,
609
+ encoder_position_embeddings: torch.Tensor,
610
+ query_position_embeddings: torch.Tensor | None = None,
623
611
  attention_mask: torch.Tensor | None = None,
624
- key_states: torch.Tensor | None = None,
625
- value_states: torch.Tensor | None = None,
626
- output_attentions: bool = False,
627
- ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
628
- """Input shape: Batch x Time x Channel"""
629
-
630
- batch_size, target_len, _ = hidden_states.size()
631
-
632
- # get query proj
633
- query_states = hidden_states * self.scaling
634
- # get key, value proj
635
- key_states = self._qk_shape(key_states, -1, batch_size)
636
- value_states = self._v_shape(value_states, -1, batch_size)
637
-
638
- proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
639
- v_proj_shape = (batch_size * self.num_heads, -1, self.v_head_dim)
640
- query_states = self._qk_shape(query_states, target_len, batch_size).view(*proj_shape)
641
- key_states = key_states.view(*proj_shape)
642
- value_states = value_states.view(*v_proj_shape)
643
-
644
- source_len = key_states.size(1)
645
-
646
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
647
-
648
- if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
649
- raise ValueError(
650
- f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
651
- f" {attn_weights.size()}"
652
- )
653
-
654
- if attention_mask is not None:
655
- if attention_mask.size() != (batch_size, 1, target_len, source_len):
656
- raise ValueError(
657
- f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
658
- f" {attention_mask.size()}"
659
- )
660
- if attention_mask.dtype == torch.bool:
661
- attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
662
- attention_mask, -torch.inf
663
- )
664
- attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
665
- attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
666
-
667
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
668
-
669
- if output_attentions:
670
- # this operation is a bit awkward, but it's required to
671
- # make sure that attn_weights keeps its gradient.
672
- # In order to do so, attn_weights have to reshaped
673
- # twice and have to be reused in the following
674
- attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
675
- attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
676
- else:
677
- attn_weights_reshaped = None
612
+ **kwargs: Unpack[TransformersKwargs],
613
+ ) -> tuple[torch.Tensor, torch.Tensor]:
614
+ """
615
+ Args:
616
+ hidden_states (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
617
+ Decoder hidden states (queries).
618
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, encoder_seq_len, hidden_size)`):
619
+ Encoder output hidden states (keys and values).
620
+ query_sine_embed (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
621
+ Sine position embeddings for queries. **Concatenated** (not added) with query content,
622
+ doubling the query dimension.
623
+ encoder_position_embeddings (`torch.Tensor` of shape `(batch_size, encoder_seq_len, hidden_size)`):
624
+ Position embeddings for keys. **Concatenated** (not added) with key content, doubling the key dimension.
625
+ query_position_embeddings (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
626
+ Additional position embeddings. When provided (first layer only), **added** to query content
627
+ before concatenation with `query_sine_embed`. Also causes `encoder_position_embeddings` to be
628
+ added to key content before concatenation.
629
+ attention_mask (`torch.Tensor` of shape `(batch_size, 1, num_queries, encoder_seq_len)`, *optional*):
630
+ Attention mask to avoid attending to padding tokens.
631
+ """
632
+ query_input_shape = hidden_states.shape[:-1]
633
+ kv_input_shape = encoder_hidden_states.shape[:-1]
634
+ query_hidden_shape = (*query_input_shape, self.num_attention_heads, self.head_dim)
635
+ kv_hidden_shape = (*kv_input_shape, self.num_attention_heads, self.head_dim)
636
+
637
+ # Apply content and position projections
638
+ query_input = self.q_content_proj(hidden_states)
639
+ key_input = self.k_content_proj(encoder_hidden_states)
640
+ value_states = self.v_proj(encoder_hidden_states)
641
+ key_pos = self.k_pos_proj(encoder_position_embeddings)
642
+
643
+ # Combine content and position embeddings
644
+ if query_position_embeddings is not None:
645
+ query_input = query_input + self.q_pos_proj(query_position_embeddings)
646
+ key_input = key_input + key_pos
647
+
648
+ # Reshape and concatenate position embeddings (doubling head_dim)
649
+ query_input = query_input.view(query_hidden_shape)
650
+ key_input = key_input.view(kv_hidden_shape)
651
+ query_sine_embed = self.q_pos_sine_proj(query_sine_embed).view(query_hidden_shape)
652
+ key_pos = key_pos.view(kv_hidden_shape)
653
+
654
+ query_states = torch.cat([query_input, query_sine_embed], dim=-1).view(*query_input_shape, -1)
655
+ key_states = torch.cat([key_input, key_pos], dim=-1).view(*kv_input_shape, -1)
656
+
657
+ # Reshape for attention computation
658
+ expanded_head_dim = query_states.shape[-1] // self.num_attention_heads
659
+ query_states = query_states.view(*query_input_shape, self.num_attention_heads, expanded_head_dim).transpose(
660
+ 1, 2
661
+ )
662
+ key_states = key_states.view(*kv_input_shape, self.num_attention_heads, expanded_head_dim).transpose(1, 2)
663
+ value_states = value_states.view(kv_hidden_shape).transpose(1, 2)
678
664
 
679
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
665
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
666
+ self.config._attn_implementation, eager_attention_forward
667
+ )
680
668
 
681
- attn_output = torch.bmm(attn_probs, value_states)
669
+ attn_output, attn_weights = attention_interface(
670
+ self,
671
+ query_states,
672
+ key_states,
673
+ value_states,
674
+ attention_mask,
675
+ dropout=0.0 if not self.training else self.attention_dropout,
676
+ scaling=self.scaling,
677
+ **kwargs,
678
+ )
682
679
 
683
- if attn_output.size() != (batch_size * self.num_heads, target_len, self.v_head_dim):
684
- raise ValueError(
685
- f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.v_head_dim)}, but is"
686
- f" {attn_output.size()}"
687
- )
680
+ attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
681
+ attn_output = self.o_proj(attn_output)
682
+ return attn_output, attn_weights
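
# Illustrative sketch (not part of the diff above): the shape bookkeeping in the cross-attention
# above. Query/key content and their position embeddings are *concatenated* per head, doubling the
# q/k head dimension, while values keep the original head dimension, so the attention output (and
# o_proj) stays at hidden_size. Standalone, torch only; all names and sizes are assumptions.
import torch

batch, num_queries, kv_len, hidden_size, num_heads = 2, 300, 950, 256, 8
head_dim = hidden_size // num_heads                           # 32

q_content = torch.randn(batch, num_queries, num_heads, head_dim)
q_sine = torch.randn(batch, num_queries, num_heads, head_dim)
k_content = torch.randn(batch, kv_len, num_heads, head_dim)
k_pos = torch.randn(batch, kv_len, num_heads, head_dim)
value = torch.randn(batch, kv_len, num_heads, head_dim)

# Concatenate along the head dimension: q/k heads become 2 * head_dim wide.
query = torch.cat([q_content, q_sine], dim=-1).transpose(1, 2)   # (2, 8, 300, 64)
key = torch.cat([k_content, k_pos], dim=-1).transpose(1, 2)      # (2, 8, 950, 64)
value = value.transpose(1, 2)                                    # (2, 8, 950, 32)

scaling = (2 * head_dim) ** -0.5                                 # scaling uses the *expanded* head_dim
weights = torch.softmax(torch.matmul(query, key.transpose(2, 3)) * scaling, dim=-1)
out = torch.matmul(weights, value).transpose(1, 2).reshape(batch, num_queries, hidden_size)
print(out.shape)  # torch.Size([2, 300, 256]) -> fed to an o_proj(hidden_size, hidden_size)
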
688
683
 
689
- attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.v_head_dim)
690
- attn_output = attn_output.transpose(1, 2)
691
- attn_output = attn_output.reshape(batch_size, target_len, self.out_dim)
692
684
 
693
- attn_output = self.out_proj(attn_output)
685
+ class ConditionalDetrMLP(nn.Module):
686
+ def __init__(self, config: ConditionalDetrConfig, hidden_size: int, intermediate_size: int):
687
+ super().__init__()
688
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
689
+ self.fc2 = nn.Linear(intermediate_size, hidden_size)
690
+ self.activation_fn = ACT2FN[config.activation_function]
691
+ self.activation_dropout = config.activation_dropout
692
+ self.dropout = config.dropout
694
693
 
695
- return attn_output, attn_weights_reshaped
694
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
695
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
696
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
697
+ hidden_states = self.fc2(hidden_states)
698
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
699
+ return hidden_states
696
700
 
697
701
 
698
- # Copied from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->ConditionalDetrEncoderLayer,DetrConfig->ConditionalDetrConfig
699
- class ConditionalDetrEncoderLayer(nn.Module):
702
+ class ConditionalDetrEncoderLayer(GradientCheckpointingLayer):
700
703
  def __init__(self, config: ConditionalDetrConfig):
701
704
  super().__init__()
702
- self.embed_dim = config.d_model
703
- self.self_attn = DetrAttention(
704
- embed_dim=self.embed_dim,
705
- num_heads=config.encoder_attention_heads,
705
+ self.hidden_size = config.d_model
706
+ self.self_attn = ConditionalDetrSelfAttention(
707
+ config=config,
708
+ hidden_size=self.hidden_size,
709
+ num_attention_heads=config.encoder_attention_heads,
706
710
  dropout=config.attention_dropout,
707
711
  )
708
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
712
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
709
713
  self.dropout = config.dropout
710
- self.activation_fn = ACT2FN[config.activation_function]
711
- self.activation_dropout = config.activation_dropout
712
- self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
713
- self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
714
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
714
+ self.mlp = ConditionalDetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
715
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size)
715
716
 
716
717
  def forward(
717
718
  self,
718
719
  hidden_states: torch.Tensor,
719
720
  attention_mask: torch.Tensor,
720
- object_queries: torch.Tensor | None = None,
721
- output_attentions: bool = False,
722
- ):
721
+ spatial_position_embeddings: torch.Tensor | None = None,
722
+ **kwargs: Unpack[TransformersKwargs],
723
+ ) -> torch.Tensor:
723
724
  """
724
725
  Args:
725
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
726
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
726
727
  attention_mask (`torch.FloatTensor`): attention mask of size
727
728
  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
728
729
  values.
729
- object_queries (`torch.FloatTensor`, *optional*):
730
- Object queries (also called content embeddings), to be added to the hidden states.
731
- output_attentions (`bool`, *optional*):
732
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
733
- returned tensors for more detail.
730
+ spatial_position_embeddings (`torch.FloatTensor`, *optional*):
731
+ Spatial position embeddings (2D positional encodings of image locations), to be added to both
732
+ the queries and keys in self-attention (but not to values).
734
733
  """
735
734
  residual = hidden_states
736
- hidden_states, attn_weights = self.self_attn(
735
+ hidden_states, _ = self.self_attn(
737
736
  hidden_states=hidden_states,
738
737
  attention_mask=attention_mask,
739
- object_queries=object_queries,
740
- output_attentions=output_attentions,
738
+ position_embeddings=spatial_position_embeddings,
739
+ **kwargs,
741
740
  )
742
741
 
743
742
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -745,12 +744,7 @@ class ConditionalDetrEncoderLayer(nn.Module):
745
744
  hidden_states = self.self_attn_layer_norm(hidden_states)
746
745
 
747
746
  residual = hidden_states
748
- hidden_states = self.activation_fn(self.fc1(hidden_states))
749
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
750
-
751
- hidden_states = self.fc2(hidden_states)
752
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
753
-
747
+ hidden_states = self.mlp(hidden_states)
754
748
  hidden_states = residual + hidden_states
755
749
  hidden_states = self.final_layer_norm(hidden_states)
756
750
 
@@ -759,80 +753,55 @@ class ConditionalDetrEncoderLayer(nn.Module):
759
753
  clamp_value = torch.finfo(hidden_states.dtype).max - 1000
760
754
  hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
761
755
 
762
- outputs = (hidden_states,)
763
-
764
- if output_attentions:
765
- outputs += (attn_weights,)
766
-
767
- return outputs
756
+ return hidden_states
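
# Illustrative sketch (not part of the diff above): the half-precision guard kept just before this
# return -- activations are clamped slightly inside the finite float16 range so an overflow cannot
# propagate as inf/nan. Standalone, torch only; the exact guard condition shown here is an
# assumption based on the clamp_value computation above.
import torch

hidden_states = torch.tensor([7.0e4, -7.0e4, 3.0]).to(torch.float16)  # 7e4 overflows float16 -> +/-inf
if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
print(torch.isfinite(hidden_states).all())  # tensor(True)
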
768
757
 
769
758
 
770
759
  class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
771
760
  def __init__(self, config: ConditionalDetrConfig):
772
761
  super().__init__()
773
- self.embed_dim = config.d_model
774
-
775
- d_model = config.d_model
776
- # Decoder Self-Attention projections
777
- self.sa_qcontent_proj = nn.Linear(d_model, d_model)
778
- self.sa_qpos_proj = nn.Linear(d_model, d_model)
779
- self.sa_kcontent_proj = nn.Linear(d_model, d_model)
780
- self.sa_kpos_proj = nn.Linear(d_model, d_model)
781
- self.sa_v_proj = nn.Linear(d_model, d_model)
782
-
783
- self.self_attn = ConditionalDetrAttention(
784
- embed_dim=self.embed_dim,
785
- out_dim=self.embed_dim,
786
- num_heads=config.decoder_attention_heads,
762
+ self.hidden_size = config.d_model
763
+ self.self_attn = ConditionalDetrDecoderSelfAttention(
764
+ config=config,
765
+ hidden_size=self.hidden_size,
766
+ num_attention_heads=config.decoder_attention_heads,
787
767
  dropout=config.attention_dropout,
788
768
  )
789
769
  self.dropout = config.dropout
790
- self.activation_fn = ACT2FN[config.activation_function]
791
- self.activation_dropout = config.activation_dropout
792
770
 
793
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
794
-
795
- # Decoder Cross-Attention projections
796
- self.ca_qcontent_proj = nn.Linear(d_model, d_model)
797
- self.ca_qpos_proj = nn.Linear(d_model, d_model)
798
- self.ca_kcontent_proj = nn.Linear(d_model, d_model)
799
- self.ca_kpos_proj = nn.Linear(d_model, d_model)
800
- self.ca_v_proj = nn.Linear(d_model, d_model)
801
- self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
802
-
803
- self.encoder_attn = ConditionalDetrAttention(
804
- self.embed_dim * 2, self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout
771
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
772
+ self.encoder_attn = ConditionalDetrDecoderCrossAttention(
773
+ config=config,
774
+ hidden_size=self.hidden_size,
775
+ num_attention_heads=config.decoder_attention_heads,
776
+ dropout=config.attention_dropout,
805
777
  )
806
- self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
807
- self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
808
- self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
809
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
810
- self.nhead = config.decoder_attention_heads
778
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
779
+ self.mlp = ConditionalDetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
780
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size)
811
781
 
812
782
  def forward(
813
783
  self,
814
784
  hidden_states: torch.Tensor,
815
785
  attention_mask: torch.Tensor | None = None,
816
- object_queries: torch.Tensor | None = None,
786
+ spatial_position_embeddings: torch.Tensor | None = None,
817
787
  query_position_embeddings: torch.Tensor | None = None,
818
788
  query_sine_embed: torch.Tensor | None = None,
819
789
  encoder_hidden_states: torch.Tensor | None = None,
820
790
  encoder_attention_mask: torch.Tensor | None = None,
821
- output_attentions: bool | None = False,
822
791
  is_first: bool | None = False,
823
- ):
792
+ **kwargs: Unpack[TransformersKwargs],
793
+ ) -> torch.Tensor:
824
794
  """
825
795
  Args:
826
796
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
827
797
  attention_mask (`torch.FloatTensor`): attention mask of size
828
798
  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
829
799
  values.
830
- object_queries (`torch.FloatTensor`, *optional*):
831
- object_queries that are added to the queries and keys
832
- in the cross-attention layer.
800
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
801
+ Spatial position embeddings (2D positional encodings of the image features), combined with the keys in the cross-attention layer.
833
802
  query_position_embeddings (`torch.FloatTensor`, *optional*):
834
803
  Position embeddings of the object queries that are added to the queries and keys
835
- in the self-attention layer.
804
+ in the self-attention layer.
836
805
  encoder_hidden_states (`torch.FloatTensor`):
837
806
  cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
838
807
  encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
@@ -844,108 +813,49 @@ class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
844
813
  """
845
814
  residual = hidden_states
846
815
 
847
- # ========== Begin of Self-Attention =============
848
- # Apply projections here
849
- # shape: num_queries x batch_size x 256
850
- q_content = self.sa_qcontent_proj(
851
- hidden_states
852
- ) # target is the input of the first decoder layer. zero by default.
853
- q_pos = self.sa_qpos_proj(query_position_embeddings)
854
- k_content = self.sa_kcontent_proj(hidden_states)
855
- k_pos = self.sa_kpos_proj(query_position_embeddings)
856
- v = self.sa_v_proj(hidden_states)
857
-
858
- _, num_queries, n_model = q_content.shape
859
-
860
- q = q_content + q_pos
861
- k = k_content + k_pos
862
- hidden_states, self_attn_weights = self.self_attn(
863
- hidden_states=q,
816
+ hidden_states, _ = self.self_attn(
817
+ hidden_states=hidden_states,
818
+ query_position_embeddings=query_position_embeddings,
864
819
  attention_mask=attention_mask,
865
- key_states=k,
866
- value_states=v,
867
- output_attentions=output_attentions,
820
+ **kwargs,
868
821
  )
869
- # ============ End of Self-Attention =============
870
822
 
871
823
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
872
824
  hidden_states = residual + hidden_states
873
825
  hidden_states = self.self_attn_layer_norm(hidden_states)
874
826
 
875
- # ========== Begin of Cross-Attention =============
876
- # Apply projections here
877
- # shape: num_queries x batch_size x 256
878
- q_content = self.ca_qcontent_proj(hidden_states)
879
- k_content = self.ca_kcontent_proj(encoder_hidden_states)
880
- v = self.ca_v_proj(encoder_hidden_states)
881
-
882
- batch_size, num_queries, n_model = q_content.shape
883
- _, source_len, _ = k_content.shape
884
-
885
- k_pos = self.ca_kpos_proj(object_queries)
886
-
887
- # For the first decoder layer, we concatenate the positional embedding predicted from
888
- # the object query (the positional embedding) into the original query (key) in DETR.
889
- if is_first:
890
- q_pos = self.ca_qpos_proj(query_position_embeddings)
891
- q = q_content + q_pos
892
- k = k_content + k_pos
893
- else:
894
- q = q_content
895
- k = k_content
896
-
897
- q = q.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
898
- query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
899
- query_sine_embed = query_sine_embed.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
900
- q = torch.cat([q, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2)
901
- k = k.view(batch_size, source_len, self.nhead, n_model // self.nhead)
902
- k_pos = k_pos.view(batch_size, source_len, self.nhead, n_model // self.nhead)
903
- k = torch.cat([k, k_pos], dim=3).view(batch_size, source_len, n_model * 2)
904
-
905
- # Cross-Attention Block
906
- cross_attn_weights = None
907
827
  if encoder_hidden_states is not None:
908
828
  residual = hidden_states
909
829
 
910
- hidden_states, cross_attn_weights = self.encoder_attn(
911
- hidden_states=q,
830
+ hidden_states, _ = self.encoder_attn(
831
+ hidden_states=hidden_states,
832
+ encoder_hidden_states=encoder_hidden_states,
912
833
  attention_mask=encoder_attention_mask,
913
- key_states=k,
914
- value_states=v,
915
- output_attentions=output_attentions,
834
+ query_sine_embed=query_sine_embed,
835
+ encoder_position_embeddings=spatial_position_embeddings,
836
+ # Only pass query_position_embeddings for the first layer
837
+ query_position_embeddings=query_position_embeddings if is_first else None,
838
+ **kwargs,
916
839
  )
917
840
 
918
841
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
919
842
  hidden_states = residual + hidden_states
920
843
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
921
844
 
922
- # ============ End of Cross-Attention =============
923
-
924
845
  # Fully Connected
925
846
  residual = hidden_states
926
- hidden_states = self.activation_fn(self.fc1(hidden_states))
927
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
928
- hidden_states = self.fc2(hidden_states)
929
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
847
+ hidden_states = self.mlp(hidden_states)
930
848
  hidden_states = residual + hidden_states
931
849
  hidden_states = self.final_layer_norm(hidden_states)
932
850
 
933
- outputs = (hidden_states,)
851
+ return hidden_states
934
852
 
935
- if output_attentions:
936
- outputs += (self_attn_weights, cross_attn_weights)
937
853
 
938
- return outputs
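
The removed inline code above shows how Conditional DETR builds its cross-attention query: the content query and the sine embedding of the reference point are concatenated per attention head, doubling the attention width. A shape-only sketch of that construction (the refactored `ConditionalDetrDecoderCrossAttention` is assumed to encapsulate the same logic):

```python
# Shape-only sketch; not the module's real implementation.
import torch

batch_size, num_queries, d_model, num_heads = 2, 300, 256, 8
head_dim = d_model // num_heads

content_query = torch.randn(batch_size, num_queries, d_model)
query_sine_embed = torch.randn(batch_size, num_queries, d_model)

q = content_query.view(batch_size, num_queries, num_heads, head_dim)
s = query_sine_embed.view(batch_size, num_queries, num_heads, head_dim)
# each head attends over [content ; spatial] features
q = torch.cat([q, s], dim=-1).view(batch_size, num_queries, d_model * 2)
print(q.shape)  # torch.Size([2, 300, 512])
```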
939
-
940
-
941
- # Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with DetrMLPPredictionHead->MLP
942
- class MLP(nn.Module):
854
+ class ConditionalDetrMLPPredictionHead(nn.Module):
943
855
  """
944
856
  Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
945
857
  height and width of a bounding box w.r.t. an image.
946
858
 
947
- Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
948
-
949
859
  """
950
860
 
951
861
  def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
@@ -960,29 +870,202 @@ class MLP(nn.Module):
960
870
  return x
961
871
 
962
872
 
873
+ class ConditionalDetrConvBlock(nn.Module):
874
+ """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""
875
+
876
+ def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
877
+ super().__init__()
878
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
879
+ self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
880
+ self.activation = ACT2FN[activation]
881
+
882
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
883
+ return self.activation(self.norm(self.conv(x)))
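
A quick shape check for the block above, assuming the `ConditionalDetrConvBlock` defined here is in scope:

```python
# The 3x3 convolution keeps the spatial size (padding=1); GroupNorm uses
# min(8, out_channels) groups.
import torch

block = ConditionalDetrConvBlock(in_channels=264, out_channels=128)
print(block(torch.randn(2, 264, 32, 32)).shape)  # torch.Size([2, 128, 32, 32])
```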
884
+
885
+
886
+ class ConditionalDetrFPNFusionStage(nn.Module):
887
+ """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""
888
+
889
+ def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
890
+ super().__init__()
891
+ self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
892
+ self.refine = ConditionalDetrConvBlock(current_channels, output_channels, activation)
893
+
894
+ def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
895
+ """
896
+ Args:
897
+ features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
898
+ fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)
899
+
900
+ Returns:
901
+ Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
902
+ """
903
+ fpn_features = self.fpn_adapter(fpn_features)
904
+ features = nn.functional.interpolate(features, size=fpn_features.shape[-2:], mode="nearest")
905
+ return self.refine(fpn_features + features)
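
A shape walk-through of one fusion stage, assuming the classes above are in scope: the low-resolution features are upsampled to the FPN map's size with nearest interpolation, summed with the 1x1-adapted FPN features, and refined by a conv block.

```python
# Illustrative shapes only.
import torch

stage = ConditionalDetrFPNFusionStage(fpn_channels=512, current_channels=128, output_channels=64)
features = torch.randn(6, 128, 16, 16)       # (batch*num_queries, current_channels, H_in, W_in)
fpn_features = torch.randn(6, 512, 32, 32)   # (batch*num_queries, fpn_channels, H_out, W_out)
print(stage(features, fpn_features).shape)   # torch.Size([6, 64, 32, 32])
```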
906
+
907
+
908
+ class ConditionalDetrMaskHeadSmallConv(nn.Module):
909
+ """
910
+ Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.
911
+
912
+ Combines attention maps (spatial localization) with encoder features (semantics) and progressively
913
+ upsamples through multiple scales, fusing with FPN features for high-resolution detail.
914
+ """
915
+
916
+ def __init__(
917
+ self,
918
+ input_channels: int,
919
+ fpn_channels: list[int],
920
+ hidden_size: int,
921
+ activation_function: str = "relu",
922
+ ):
923
+ super().__init__()
924
+ if input_channels % 8 != 0:
925
+ raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")
926
+
927
+ self.conv1 = ConditionalDetrConvBlock(input_channels, input_channels, activation_function)
928
+ self.conv2 = ConditionalDetrConvBlock(input_channels, hidden_size // 2, activation_function)
929
+
930
+ # Progressive channel reduction: /2 -> /4 -> /8 -> /16
931
+ self.fpn_stages = nn.ModuleList(
932
+ [
933
+ ConditionalDetrFPNFusionStage(
934
+ fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function
935
+ ),
936
+ ConditionalDetrFPNFusionStage(
937
+ fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function
938
+ ),
939
+ ConditionalDetrFPNFusionStage(
940
+ fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function
941
+ ),
942
+ ]
943
+ )
944
+
945
+ self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)
946
+
947
+ def forward(
948
+ self,
949
+ features: torch.Tensor,
950
+ attention_masks: torch.Tensor,
951
+ fpn_features: list[torch.Tensor],
952
+ ) -> torch.Tensor:
953
+ """
954
+ Args:
955
+ features: Encoder output features, shape (batch_size, hidden_size, H, W)
956
+ attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
957
+ fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)
958
+
959
+ Returns:
960
+ Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
961
+ """
962
+ num_queries = attention_masks.shape[1]
963
+
964
+ # Expand to (batch_size * num_queries) dimension
965
+ features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
966
+ attention_masks = attention_masks.flatten(0, 1)
967
+ fpn_features = [
968
+ fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
969
+ ]
970
+
971
+ hidden_states = torch.cat([features, attention_masks], dim=1)
972
+ hidden_states = self.conv1(hidden_states)
973
+ hidden_states = self.conv2(hidden_states)
974
+
975
+ for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
976
+ hidden_states = fpn_stage(hidden_states, fpn_feat)
977
+
978
+ return self.output_conv(hidden_states)
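
An end-to-end shape sketch for the mask head, assuming the classes above are in scope and using small spatial sizes for illustration. With `hidden_size=256` and 8 decoder attention heads, `input_channels` is 256 + 8 = 264.

```python
import torch

batch_size, num_queries, hidden_size, num_heads = 1, 2, 256, 8
mask_head = ConditionalDetrMaskHeadSmallConv(
    input_channels=hidden_size + num_heads,
    fpn_channels=[1024, 512, 256],   # backbone stages, ordered low to high resolution
    hidden_size=hidden_size,
)
features = torch.randn(batch_size, hidden_size, 8, 8)                    # encoder features
attention_masks = torch.randn(batch_size, num_queries, num_heads, 8, 8)  # bbox_attention output
fpn_features = [
    torch.randn(batch_size, 1024, 16, 16),
    torch.randn(batch_size, 512, 32, 32),
    torch.randn(batch_size, 256, 64, 64),
]
print(mask_head(features, attention_masks, fpn_features).shape)
# torch.Size([2, 1, 64, 64]) == (batch_size * num_queries, 1, H_out, W_out)
```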
979
+
980
+
981
+ class ConditionalDetrMHAttentionMap(nn.Module):
982
+ """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
983
+
984
+ def __init__(
985
+ self,
986
+ hidden_size: int,
987
+ num_attention_heads: int,
988
+ dropout: float = 0.0,
989
+ bias: bool = True,
990
+ ):
991
+ super().__init__()
992
+ self.head_dim = hidden_size // num_attention_heads
993
+ self.scaling = self.head_dim**-0.5
994
+ self.attention_dropout = dropout
995
+
996
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
997
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
998
+
999
+ def forward(
1000
+ self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
1001
+ ):
1002
+ query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
1003
+ key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
1004
+
1005
+ query_states = self.q_proj(query_states).view(query_hidden_shape)
1006
+ key_states = nn.functional.conv2d(
1007
+ key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
1008
+ ).view(key_hidden_shape)
1009
+
1010
+ batch_size, num_queries, num_heads, head_dim = query_states.shape
1011
+ _, _, _, height, width = key_states.shape
1012
+ query_shape = (batch_size * num_heads, num_queries, head_dim)
1013
+ key_shape = (batch_size * num_heads, height * width, head_dim)
1014
+ attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
1015
+
1016
+ query = query_states.transpose(1, 2).contiguous().view(query_shape)
1017
+ key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
1018
+
1019
+ attn_weights = (
1020
+ (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
1021
+ )
1022
+
1023
+ if attention_mask is not None:
1024
+ attn_weights = attn_weights + attention_mask
1025
+
1026
+ attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
1027
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
1028
+
1029
+ return attn_weights
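
Illustrative usage of the attention-map module above, assuming the class is in scope: queries come from the decoder output, keys from the projected backbone feature map.

```python
import torch

bbox_attention = ConditionalDetrMHAttentionMap(hidden_size=256, num_attention_heads=8)
decoder_output = torch.randn(2, 300, 256)    # (batch_size, num_queries, hidden_size)
feature_map = torch.randn(2, 256, 25, 34)    # (batch_size, hidden_size, height, width)
attention_maps = bbox_attention(decoder_output, feature_map)
print(attention_maps.shape)  # torch.Size([2, 300, 8, 25, 34])
```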
1030
+
1031
+
963
1032
  @auto_docstring
964
- # Copied from transformers.models.detr.modeling_detr.DetrPreTrainedModel with Detr->ConditionalDetr
965
1033
  class ConditionalDetrPreTrainedModel(PreTrainedModel):
966
1034
  config: ConditionalDetrConfig
967
1035
  base_model_prefix = "model"
968
1036
  main_input_name = "pixel_values"
969
1037
  input_modalities = ("image",)
970
1038
  _no_split_modules = [r"ConditionalDetrConvEncoder", r"ConditionalDetrEncoderLayer", r"ConditionalDetrDecoderLayer"]
1039
+ supports_gradient_checkpointing = True
1040
+ _supports_sdpa = True
1041
+ _supports_flash_attn = True
1042
+ _supports_attention_backend = True
1043
+ _supports_flex_attn = True  # Uses create_bidirectional_mask for attention masking
1044
+ _keys_to_ignore_on_load_unexpected = [
1045
+ r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
1046
+ ]
971
1047
 
972
1048
  @torch.no_grad()
973
1049
  def _init_weights(self, module):
974
1050
  std = self.config.init_std
975
1051
  xavier_std = self.config.init_xavier_std
976
1052
 
977
- if isinstance(module, ConditionalDetrMHAttentionMap):
978
- init.zeros_(module.k_linear.bias)
979
- init.zeros_(module.q_linear.bias)
980
- init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
981
- init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
1053
+ if isinstance(module, ConditionalDetrMaskHeadSmallConv):
1054
+ # ConditionalDetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
1055
+ for m in module.modules():
1056
+ if isinstance(m, nn.Conv2d):
1057
+ init.kaiming_uniform_(m.weight, a=1)
1058
+ if m.bias is not None:
1059
+ init.constant_(m.bias, 0)
1060
+ elif isinstance(module, ConditionalDetrMHAttentionMap):
1061
+ init.zeros_(module.k_proj.bias)
1062
+ init.zeros_(module.q_proj.bias)
1063
+ init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
1064
+ init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
982
1065
  elif isinstance(module, ConditionalDetrLearnedPositionEmbedding):
983
1066
  init.uniform_(module.row_embeddings.weight)
984
1067
  init.uniform_(module.column_embeddings.weight)
985
- if isinstance(module, (nn.Linear, nn.Conv2d)):
1068
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
986
1069
  init.normal_(module.weight, mean=0.0, std=std)
987
1070
  if module.bias is not None:
988
1071
  init.zeros_(module.bias)
@@ -996,50 +1079,38 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
996
1079
  init.zeros_(module.bias)
997
1080
 
998
1081
 
999
- # Copied from transformers.models.detr.modeling_detr.DetrEncoder with Detr->ConditionalDetr,DETR->ConditionalDETR
1000
1082
  class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
1001
1083
  """
1002
- Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
1003
- [`ConditionalDetrEncoderLayer`].
1004
-
1005
- The encoder updates the flattened feature map through multiple self-attention layers.
1006
-
1007
- Small tweak for ConditionalDETR:
1008
-
1009
- - object_queries are added to the forward pass.
1084
+ Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
1085
+ [`ConditionalDetrEncoderLayer`] modules.
1010
1086
 
1011
1087
  Args:
1012
- config: ConditionalDetrConfig
1088
+ config (`ConditionalDetrConfig`): Model configuration object.
1013
1089
  """
1014
1090
 
1091
+ _can_record_outputs = {"hidden_states": ConditionalDetrEncoderLayer, "attentions": ConditionalDetrSelfAttention}
1092
+
1015
1093
  def __init__(self, config: ConditionalDetrConfig):
1016
1094
  super().__init__(config)
1017
1095
 
1018
1096
  self.dropout = config.dropout
1019
- self.layerdrop = config.encoder_layerdrop
1020
-
1021
1097
  self.layers = nn.ModuleList([ConditionalDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
1022
1098
 
1023
- # in the original ConditionalDETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default
1024
-
1025
1099
  # Initialize weights and apply final processing
1026
1100
  self.post_init()
1027
1101
 
1102
+ @check_model_inputs()
1028
1103
  def forward(
1029
1104
  self,
1030
1105
  inputs_embeds=None,
1031
1106
  attention_mask=None,
1032
- object_queries=None,
1033
- output_attentions=None,
1034
- output_hidden_states=None,
1035
- return_dict=None,
1036
- **kwargs,
1037
- ):
1107
+ spatial_position_embeddings=None,
1108
+ **kwargs: Unpack[TransformersKwargs],
1109
+ ) -> BaseModelOutput:
1038
1110
  r"""
1039
1111
  Args:
1040
1112
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1041
1113
  Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
1042
-
1043
1114
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1044
1115
  Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
1045
1116
 
@@ -1047,69 +1118,44 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
1047
1118
  - 0 for pixel features that are padding (i.e. **masked**).
1048
1119
 
1049
1120
  [What are attention masks?](../glossary#attention-mask)
1050
-
1051
- object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1052
- Object queries that are added to the queries in each self-attention layer.
1053
-
1054
- output_attentions (`bool`, *optional*):
1055
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1056
- returned tensors for more detail.
1057
- output_hidden_states (`bool`, *optional*):
1058
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1059
- for more detail.
1060
- return_dict (`bool`, *optional*):
1061
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1121
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1122
+ Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
1062
1123
  """
1063
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1064
- output_hidden_states = (
1065
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1066
- )
1067
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1068
-
1069
1124
  hidden_states = inputs_embeds
1070
1125
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1071
1126
 
1072
1127
  # expand attention_mask
1073
1128
  if attention_mask is not None:
1074
1129
  # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
1075
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
1076
-
1077
- encoder_states = () if output_hidden_states else None
1078
- all_attentions = () if output_attentions else None
1079
- for i, encoder_layer in enumerate(self.layers):
1080
- if output_hidden_states:
1081
- encoder_states = encoder_states + (hidden_states,)
1082
- # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
1083
- to_drop = False
1084
- if self.training:
1085
- dropout_probability = torch.rand([])
1086
- if dropout_probability < self.layerdrop: # skip the layer
1087
- to_drop = True
1130
+ attention_mask = create_bidirectional_mask(
1131
+ config=self.config,
1132
+ input_embeds=inputs_embeds,
1133
+ attention_mask=attention_mask,
1134
+ )
1088
1135
 
1089
- if to_drop:
1090
- layer_outputs = (None, None)
1091
- else:
1092
- # we add object_queries as extra input to the encoder_layer
1093
- layer_outputs = encoder_layer(
1094
- hidden_states,
1095
- attention_mask,
1096
- object_queries=object_queries,
1097
- output_attentions=output_attentions,
1098
- )
1099
-
1100
- hidden_states = layer_outputs[0]
1101
-
1102
- if output_attentions:
1103
- all_attentions = all_attentions + (layer_outputs[1],)
1104
-
1105
- if output_hidden_states:
1106
- encoder_states = encoder_states + (hidden_states,)
1107
-
1108
- if not return_dict:
1109
- return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
1110
- return BaseModelOutput(
1111
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
1112
- )
1136
+ for encoder_layer in self.layers:
1137
+ # we add spatial_position_embeddings as extra input to the encoder_layer
1138
+ hidden_states = encoder_layer(
1139
+ hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
1140
+ )
1141
+
1142
+ return BaseModelOutput(last_hidden_state=hidden_states)
1143
+
1144
+
1145
+ # Generates sine positional embeddings for 2D (x, y) coordinates
1146
+ def gen_sine_position_embeddings(pos_tensor, d_model):
1147
+ scale = 2 * math.pi
1148
+ dim = d_model // 2
1149
+ dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
1150
+ dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
1151
+ x_embed = pos_tensor[:, :, 0] * scale
1152
+ y_embed = pos_tensor[:, :, 1] * scale
1153
+ pos_x = x_embed[:, :, None] / dim_t
1154
+ pos_y = y_embed[:, :, None] / dim_t
1155
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
1156
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
1157
+ pos = torch.cat((pos_y, pos_x), dim=2)
1158
+ return pos.to(pos_tensor.dtype)
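
A quick check of the helper above: normalized `(x, y)` reference points of shape `(batch_size, num_queries, 2)` map to sine/cosine embeddings of size `d_model`, with the y embedding placed before the x embedding.

```python
import torch

reference_points = torch.rand(2, 300, 2)
embeddings = gen_sine_position_embeddings(reference_points, d_model=256)
print(embeddings.shape)  # torch.Size([2, 300, 256])
```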
1113
1159
 
1114
1160
 
1115
1161
  class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
@@ -1127,39 +1173,44 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
1127
1173
  config: ConditionalDetrConfig
1128
1174
  """
1129
1175
 
1176
+ _can_record_outputs = {
1177
+ "hidden_states": ConditionalDetrDecoderLayer,
1178
+ "attentions": OutputRecorder(ConditionalDetrDecoderSelfAttention, layer_name="self_attn", index=1),
1179
+ "cross_attentions": OutputRecorder(ConditionalDetrDecoderCrossAttention, layer_name="encoder_attn", index=1),
1180
+ }
1181
+
1130
1182
  def __init__(self, config: ConditionalDetrConfig):
1131
1183
  super().__init__(config)
1184
+ self.hidden_size = config.d_model
1185
+
1132
1186
  self.dropout = config.dropout
1133
1187
  self.layerdrop = config.decoder_layerdrop
1134
1188
 
1135
1189
  self.layers = nn.ModuleList([ConditionalDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
1136
1190
  # in Conditional DETR, the decoder uses layernorm after the last decoder layer output
1137
1191
  self.layernorm = nn.LayerNorm(config.d_model)
1138
- d_model = config.d_model
1139
- self.gradient_checkpointing = False
1140
1192
 
1141
1193
  # query_scale is the FFN applied on f to generate transformation T
1142
- self.query_scale = MLP(d_model, d_model, d_model, 2)
1143
- self.ref_point_head = MLP(d_model, d_model, 2, 2)
1194
+ self.query_scale = ConditionalDetrMLPPredictionHead(self.hidden_size, self.hidden_size, self.hidden_size, 2)
1195
+ self.ref_point_head = ConditionalDetrMLPPredictionHead(self.hidden_size, self.hidden_size, 2, 2)
1144
1196
  for layer_id in range(config.decoder_layers - 1):
1145
- self.layers[layer_id + 1].ca_qpos_proj = None
1197
+ # Set q_pos_proj to None for layers after the first (only first layer uses query position embeddings)
1198
+ self.layers[layer_id + 1].encoder_attn.q_pos_proj = None
1146
1199
 
1147
1200
  # Initialize weights and apply final processing
1148
1201
  self.post_init()
1149
1202
 
1203
+ @check_model_inputs()
1150
1204
  def forward(
1151
1205
  self,
1152
1206
  inputs_embeds=None,
1153
1207
  attention_mask=None,
1154
1208
  encoder_hidden_states=None,
1155
1209
  encoder_attention_mask=None,
1156
- object_queries=None,
1157
- query_position_embeddings=None,
1158
- output_attentions=None,
1159
- output_hidden_states=None,
1160
- return_dict=None,
1161
- **kwargs,
1162
- ):
1210
+ spatial_position_embeddings=None,
1211
+ object_queries_position_embeddings=None,
1212
+ **kwargs: Unpack[TransformersKwargs],
1213
+ ) -> ConditionalDetrDecoderOutput:
1163
1214
  r"""
1164
1215
  Args:
1165
1216
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -1182,46 +1233,28 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
1182
1233
  - 1 for pixels that are real (i.e. **not masked**),
1183
1234
  - 0 for pixels that are padding (i.e. **masked**).
1184
1235
 
1185
- object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1186
- Position embeddings that are added to the queries and keys in each cross-attention layer.
1187
- query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1236
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1237
+ Spatial position embeddings that are added to the queries and keys in each cross-attention layer.
1238
+ object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1188
1239
  Position embeddings that are added to the queries and keys in each self-attention layer.
1189
- output_attentions (`bool`, *optional*):
1190
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1191
- returned tensors for more detail.
1192
- output_hidden_states (`bool`, *optional*):
1193
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1194
- for more detail.
1195
- return_dict (`bool`, *optional*):
1196
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1197
1240
  """
1198
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1199
- output_hidden_states = (
1200
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1201
- )
1202
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1203
-
1204
1241
  if inputs_embeds is not None:
1205
1242
  hidden_states = inputs_embeds
1206
- input_shape = inputs_embeds.size()[:-1]
1207
1243
 
1208
1244
  # expand encoder attention mask
1209
1245
  if encoder_hidden_states is not None and encoder_attention_mask is not None:
1210
1246
  # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
1211
- encoder_attention_mask = _prepare_4d_attention_mask(
1212
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1247
+ encoder_attention_mask = create_bidirectional_mask(
1248
+ self.config,
1249
+ inputs_embeds,
1250
+ encoder_attention_mask,
1213
1251
  )
1214
1252
 
1215
1253
  # optional intermediate hidden states
1216
1254
  intermediate = () if self.config.auxiliary_loss else None
1217
1255
 
1218
- # decoder layers
1219
- all_hidden_states = () if output_hidden_states else None
1220
- all_self_attns = () if output_attentions else None
1221
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1222
-
1223
1256
  reference_points_before_sigmoid = self.ref_point_head(
1224
- query_position_embeddings
1257
+ object_queries_position_embeddings
1225
1258
  ) # [num_queries, batch_size, 2]
1226
1259
  reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
1227
1260
  obj_center = reference_points[..., :2].transpose(0, 1)
@@ -1229,9 +1262,6 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
1229
1262
  query_sine_embed_before_transformation = gen_sine_position_embeddings(obj_center, self.config.d_model)
1230
1263
 
1231
1264
  for idx, decoder_layer in enumerate(self.layers):
1232
- # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
1233
- if output_hidden_states:
1234
- all_hidden_states += (hidden_states,)
1235
1265
  if self.training:
1236
1266
  dropout_probability = torch.rand([])
1237
1267
  if dropout_probability < self.layerdrop:
@@ -1243,59 +1273,31 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
1243
1273
  # apply transformation
1244
1274
  query_sine_embed = query_sine_embed_before_transformation * pos_transformation
1245
1275
 
1246
- layer_outputs = decoder_layer(
1276
+ hidden_states = decoder_layer(
1247
1277
  hidden_states,
1248
- None, # attention_mask
1249
- object_queries,
1250
- query_position_embeddings,
1278
+ None,
1279
+ spatial_position_embeddings,
1280
+ object_queries_position_embeddings,
1251
1281
  query_sine_embed,
1252
1282
  encoder_hidden_states, # as a positional argument for gradient checkpointing
1253
1283
  encoder_attention_mask=encoder_attention_mask,
1254
- output_attentions=output_attentions,
1255
1284
  is_first=(idx == 0),
1285
+ **kwargs,
1256
1286
  )
1257
1287
 
1258
- hidden_states = layer_outputs[0]
1259
-
1260
1288
  if self.config.auxiliary_loss:
1261
1289
  hidden_states = self.layernorm(hidden_states)
1262
1290
  intermediate += (hidden_states,)
1263
1291
 
1264
- if output_attentions:
1265
- all_self_attns += (layer_outputs[1],)
1266
-
1267
- if encoder_hidden_states is not None:
1268
- all_cross_attentions += (layer_outputs[2],)
1269
-
1270
1292
  # finally, apply layernorm
1271
1293
  hidden_states = self.layernorm(hidden_states)
1272
1294
 
1273
- # add hidden states from the last decoder layer
1274
- if output_hidden_states:
1275
- all_hidden_states += (hidden_states,)
1276
-
1277
1295
  # stack intermediate decoder activations
1278
1296
  if self.config.auxiliary_loss:
1279
1297
  intermediate = torch.stack(intermediate)
1280
1298
 
1281
- if not return_dict:
1282
- return tuple(
1283
- v
1284
- for v in [
1285
- hidden_states,
1286
- all_hidden_states,
1287
- all_self_attns,
1288
- all_cross_attentions,
1289
- intermediate,
1290
- reference_points,
1291
- ]
1292
- if v is not None
1293
- )
1294
1299
  return ConditionalDetrDecoderOutput(
1295
1300
  last_hidden_state=hidden_states,
1296
- hidden_states=all_hidden_states,
1297
- attentions=all_self_attns,
1298
- cross_attentions=all_cross_attentions,
1299
1301
  intermediate_hidden_states=intermediate,
1300
1302
  reference_points=reference_points,
1301
1303
  )
@@ -1303,23 +1305,24 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
1303
1305
 
1304
1306
  @auto_docstring(
1305
1307
  custom_intro="""
1306
- The bare Conditional DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw
1307
- hidden-states without any specific head on top.
1308
+ The bare CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
1309
+ any specific head on top.
1308
1310
  """
1309
1311
  )
1310
1312
  class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1311
1313
  def __init__(self, config: ConditionalDetrConfig):
1312
1314
  super().__init__(config)
1313
1315
 
1314
- # Create backbone + positional encoding
1315
- backbone = ConditionalDetrConvEncoder(config)
1316
- object_queries = build_position_encoding(config)
1317
- self.backbone = ConditionalDetrConvModel(backbone, object_queries)
1318
-
1319
- # Create projection layer
1320
- self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
1316
+ self.backbone = ConditionalDetrConvEncoder(config)
1321
1317
 
1318
+ if config.position_embedding_type == "sine":
1319
+ self.position_embedding = ConditionalDetrSinePositionEmbedding(config.d_model // 2, normalize=True)
1320
+ elif config.position_embedding_type == "learned":
1321
+ self.position_embedding = ConditionalDetrLearnedPositionEmbedding(config.d_model // 2)
1322
+ else:
1323
+ raise ValueError(f"Not supported {config.position_embedding_type}")
1322
1324
  self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
1325
+ self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
1323
1326
 
1324
1327
  self.encoder = ConditionalDetrEncoder(config)
1325
1328
  self.decoder = ConditionalDetrDecoder(config)
@@ -1328,14 +1331,15 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1328
1331
  self.post_init()
1329
1332
 
1330
1333
  def freeze_backbone(self):
1331
- for name, param in self.backbone.conv_encoder.model.named_parameters():
1334
+ for _, param in self.backbone.model.named_parameters():
1332
1335
  param.requires_grad_(False)
1333
1336
 
1334
1337
  def unfreeze_backbone(self):
1335
- for name, param in self.backbone.conv_encoder.model.named_parameters():
1338
+ for _, param in self.backbone.model.named_parameters():
1336
1339
  param.requires_grad_(True)
1337
1340
 
1338
1341
  @auto_docstring
1342
+ @can_return_tuple
1339
1343
  def forward(
1340
1344
  self,
1341
1345
  pixel_values: torch.FloatTensor,
@@ -1344,11 +1348,8 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1344
1348
  encoder_outputs: torch.FloatTensor | None = None,
1345
1349
  inputs_embeds: torch.FloatTensor | None = None,
1346
1350
  decoder_inputs_embeds: torch.FloatTensor | None = None,
1347
- output_attentions: bool | None = None,
1348
- output_hidden_states: bool | None = None,
1349
- return_dict: bool | None = None,
1350
- **kwargs,
1351
- ) -> tuple[torch.FloatTensor] | ConditionalDetrModelOutput:
1351
+ **kwargs: Unpack[TransformersKwargs],
1352
+ ) -> ConditionalDetrModelOutput:
1352
1353
  r"""
1353
1354
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1354
1355
  Not used by default. Can be used to mask object queries.
@@ -1384,12 +1385,6 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1384
1385
  >>> list(last_hidden_states.shape)
1385
1386
  [1, 300, 256]
1386
1387
  ```"""
1387
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1388
- output_hidden_states = (
1389
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1390
- )
1391
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1392
-
1393
1388
  batch_size, num_channels, height, width = pixel_values.shape
1394
1389
  device = pixel_values.device
1395
1390
 
@@ -1399,7 +1394,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1399
1394
  # First, send pixel_values + pixel_mask through the backbone to obtain the features
1400
1395
  # pixel_values should be of shape (batch_size, num_channels, height, width)
1401
1396
  # pixel_mask should be of shape (batch_size, height, width)
1402
- features, object_queries_list = self.backbone(pixel_values, pixel_mask)
1397
+ features = self.backbone(pixel_values, pixel_mask)
1403
1398
 
1404
1399
  # get final feature map and downsampled mask
1405
1400
  feature_map, mask = features[-1]
@@ -1410,53 +1405,52 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1410
1405
  # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
1411
1406
  projected_feature_map = self.input_projection(feature_map)
1412
1407
 
1413
- # Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1408
+ # Generate position embeddings
1409
+ spatial_position_embeddings = self.position_embedding(
1410
+ shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
1411
+ )
1412
+
1413
+ # Third, flatten the feature map of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1414
1414
  # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
1415
1415
  flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
1416
- object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
1417
1416
 
1418
1417
  flattened_mask = mask.flatten(1)
1419
1418
 
1420
- # Fourth, sent flattened_features + flattened_mask + object_queries through encoder
1419
+ # Fourth, sent flattened_features + flattened_mask + spatial_position_embeddings through encoder
1421
1420
  # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
1422
1421
  # flattened_mask is a Tensor of shape (batch_size, height*width)
1423
1422
  if encoder_outputs is None:
1424
1423
  encoder_outputs = self.encoder(
1425
1424
  inputs_embeds=flattened_features,
1426
1425
  attention_mask=flattened_mask,
1427
- object_queries=object_queries,
1428
- output_attentions=output_attentions,
1429
- output_hidden_states=output_hidden_states,
1430
- return_dict=return_dict,
1426
+ spatial_position_embeddings=spatial_position_embeddings,
1427
+ **kwargs,
1431
1428
  )
1432
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1433
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1429
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
1430
+ elif not isinstance(encoder_outputs, BaseModelOutput):
1434
1431
  encoder_outputs = BaseModelOutput(
1435
1432
  last_hidden_state=encoder_outputs[0],
1436
1433
  hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1437
1434
  attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1438
1435
  )
1439
1436
 
1440
- # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
1441
- query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
1442
- queries = torch.zeros_like(query_position_embeddings)
1437
+ # Fifth, send query embeddings through the decoder (which is conditioned on the encoder output)
1438
+ object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
1439
+ batch_size, 1, 1
1440
+ )
1441
+ queries = torch.zeros_like(object_queries_position_embeddings)
1443
1442
 
1444
1443
  # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
1445
1444
  decoder_outputs = self.decoder(
1446
1445
  inputs_embeds=queries,
1447
1446
  attention_mask=None,
1448
- object_queries=object_queries,
1449
- query_position_embeddings=query_position_embeddings,
1450
- encoder_hidden_states=encoder_outputs[0],
1447
+ spatial_position_embeddings=spatial_position_embeddings,
1448
+ object_queries_position_embeddings=object_queries_position_embeddings,
1449
+ encoder_hidden_states=encoder_outputs.last_hidden_state,
1451
1450
  encoder_attention_mask=flattened_mask,
1452
- output_attentions=output_attentions,
1453
- output_hidden_states=output_hidden_states,
1454
- return_dict=return_dict,
1451
+ **kwargs,
1455
1452
  )
1456
1453
 
1457
- if not return_dict:
1458
- return decoder_outputs + encoder_outputs
1459
-
1460
1454
  return ConditionalDetrModelOutput(
1461
1455
  last_hidden_state=decoder_outputs.last_hidden_state,
1462
1456
  decoder_hidden_states=decoder_outputs.hidden_states,
@@ -1470,45 +1464,26 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1470
1464
  )
1471
1465
 
1472
1466
 
1473
- # Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->ConditionalDetr
1474
- class ConditionalDetrMLPPredictionHead(nn.Module):
1475
- """
1476
- Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
1477
- height and width of a bounding box w.r.t. an image.
1478
-
1479
- Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1480
-
1481
- """
1482
-
1483
- def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
1484
- super().__init__()
1485
- self.num_layers = num_layers
1486
- h = [hidden_dim] * (num_layers - 1)
1487
- self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
1488
-
1489
- def forward(self, x):
1490
- for i, layer in enumerate(self.layers):
1491
- x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
1492
- return x
1467
+ def inverse_sigmoid(x, eps=1e-5):
1468
+ x = x.clamp(min=0, max=1)
1469
+ x1 = x.clamp(min=eps)
1470
+ x2 = (1 - x).clamp(min=eps)
1471
+ return torch.log(x1 / x2)
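
`inverse_sigmoid` is a clamped logit: it inverts `torch.sigmoid`, with the `eps` clamping keeping the output (and its gradient) finite near 0 and 1. A minimal check, assuming the function above is in scope:

```python
import torch

x = torch.tensor([0.0, 0.25, 0.5, 0.75])
print(torch.sigmoid(inverse_sigmoid(x)))  # ~[1e-05, 0.25, 0.5, 0.75]
```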
1493
1472
 
1494
1473
 
1495
1474
  @auto_docstring(
1496
1475
  custom_intro="""
1497
- Conditional DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on
1498
- top, for tasks such as COCO detection.
1476
+ CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
1477
+ such as COCO detection.
1499
1478
  """
1500
1479
  )
1501
1480
  class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1502
1481
  def __init__(self, config: ConditionalDetrConfig):
1503
1482
  super().__init__(config)
1504
1483
 
1505
- # CONDITIONAL DETR encoder-decoder model
1484
+ # CONDITIONAL_DETR encoder-decoder model
1506
1485
  self.model = ConditionalDetrModel(config)
1507
-
1508
- # Object detection heads
1509
- self.class_labels_classifier = nn.Linear(
1510
- config.d_model, config.num_labels
1511
- ) # We add one for the "no object" class
1486
+ self.class_labels_classifier = nn.Linear(config.d_model, config.num_labels)
1512
1487
  self.bbox_predictor = ConditionalDetrMLPPredictionHead(
1513
1488
  input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
1514
1489
  )
@@ -1516,11 +1491,8 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1516
1491
  # Initialize weights and apply final processing
1517
1492
  self.post_init()
1518
1493
 
1519
- # taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
1520
- def _set_aux_loss(self, outputs_class, outputs_coord):
1521
- return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
1522
-
1523
1494
  @auto_docstring
1495
+ @can_return_tuple
1524
1496
  def forward(
1525
1497
  self,
1526
1498
  pixel_values: torch.FloatTensor,
@@ -1530,11 +1502,8 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1530
1502
  inputs_embeds: torch.FloatTensor | None = None,
1531
1503
  decoder_inputs_embeds: torch.FloatTensor | None = None,
1532
1504
  labels: list[dict] | None = None,
1533
- output_attentions: bool | None = None,
1534
- output_hidden_states: bool | None = None,
1535
- return_dict: bool | None = None,
1536
- **kwargs,
1537
- ) -> tuple[torch.FloatTensor] | ConditionalDetrObjectDetectionOutput:
1505
+ **kwargs: Unpack[TransformersKwargs],
1506
+ ) -> ConditionalDetrObjectDetectionOutput:
1538
1507
  r"""
1539
1508
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1540
1509
  Not used by default. Can be used to mask object queries.
@@ -1584,8 +1553,6 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1584
1553
  Detected remote with confidence 0.683 at location [334.48, 73.49, 366.37, 190.01]
1585
1554
  Detected couch with confidence 0.535 at location [0.52, 1.19, 640.35, 475.1]
1586
1555
  ```"""
1587
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1588
-
1589
1556
  # First, send images through the CONDITIONAL_DETR base model to obtain encoder + decoder outputs
1590
1557
  outputs = self.model(
1591
1558
  pixel_values,
@@ -1594,9 +1561,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1594
1561
  encoder_outputs=encoder_outputs,
1595
1562
  inputs_embeds=inputs_embeds,
1596
1563
  decoder_inputs_embeds=decoder_inputs_embeds,
1597
- output_attentions=output_attentions,
1598
- output_hidden_states=output_hidden_states,
1599
- return_dict=return_dict,
1564
+ **kwargs,
1600
1565
  )
1601
1566
 
1602
1567
  sequence_output = outputs[0]
@@ -1604,11 +1569,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1604
1569
  # class logits + predicted bounding boxes
1605
1570
  logits = self.class_labels_classifier(sequence_output)
1606
1571
 
1607
- # Index [-2] is valid only if `output_attentions` and `output_hidden_states`
1608
- # are not specified, otherwise it will be another index which is hard to determine.
1609
- # Leave it as is, because it's not a common case to use
1610
- # return_dict=False + output_attentions=True / output_hidden_states=True
1611
- reference = outputs.reference_points if return_dict else outputs[-2]
1572
+ reference = outputs.reference_points
1612
1573
  reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
1613
1574
 
1614
1575
  hs = sequence_output
@@ -1622,7 +1583,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1622
1583
  outputs_class, outputs_coord = None, None
1623
1584
  if self.config.auxiliary_loss:
1624
1585
  outputs_coords = []
1625
- intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
1586
+ intermediate = outputs.intermediate_hidden_states
1626
1587
  outputs_class = self.class_labels_classifier(intermediate)
1627
1588
  for lvl in range(intermediate.shape[0]):
1628
1589
  tmp = self.bbox_predictor(intermediate[lvl])
@@ -1634,13 +1595,6 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1634
1595
  logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
1635
1596
  )
1636
1597
 
1637
- if not return_dict:
1638
- if auxiliary_outputs is not None:
1639
- output = (logits, pred_boxes) + auxiliary_outputs + outputs
1640
- else:
1641
- output = (logits, pred_boxes) + outputs
1642
- return ((loss, loss_dict) + output) if loss is not None else output
1643
-
1644
1598
  return ConditionalDetrObjectDetectionOutput(
1645
1599
  loss=loss,
1646
1600
  loss_dict=loss_dict,
@@ -1656,14 +1610,38 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1656
1610
  encoder_attentions=outputs.encoder_attentions,
1657
1611
  )
1658
1612
 
1613
+ # taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
1614
+ def _set_aux_loss(self, outputs_class, outputs_coord):
1615
+ return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
1616
+
1659
1617
 
1660
1618
  @auto_docstring(
1661
1619
  custom_intro="""
1662
- Conditional DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top,
1663
- for tasks such as COCO panoptic.
1620
+ CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
1621
+ such as COCO panoptic.
1664
1622
  """
1665
1623
  )
1666
1624
  class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
1625
+ _checkpoint_conversion_mapping = {
1626
+ "bbox_attention.q_linear": "bbox_attention.q_proj",
1627
+ "bbox_attention.k_linear": "bbox_attention.k_proj",
1628
+ # Mask head refactor
1629
+ "mask_head.lay1": "mask_head.conv1.conv",
1630
+ "mask_head.gn1": "mask_head.conv1.norm",
1631
+ "mask_head.lay2": "mask_head.conv2.conv",
1632
+ "mask_head.gn2": "mask_head.conv2.norm",
1633
+ "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter",
1634
+ "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv",
1635
+ "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm",
1636
+ "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
1637
+ "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv",
1638
+ "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm",
1639
+ "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter",
1640
+ "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv",
1641
+ "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm",
1642
+ "mask_head.out_lay": "mask_head.output_conv",
1643
+ }
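
The mapping above renames old checkpoint keys to the refactored module layout. A hedged sketch of how such a fragment-replacement map could be applied to a state dict; `rename_state_dict_keys` is a hypothetical helper shown for illustration only, not the library's actual conversion machinery:

```python
def rename_state_dict_keys(state_dict: dict, mapping: dict[str, str]) -> dict:
    # Replace the first matching old fragment in each key with its new counterpart.
    renamed = {}
    for key, value in state_dict.items():
        for old_fragment, new_fragment in mapping.items():
            if old_fragment in key:
                key = key.replace(old_fragment, new_fragment)
                break
        renamed[key] = value
    return renamed
```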
1644
+
1667
1645
  def __init__(self, config: ConditionalDetrConfig):
1668
1646
  super().__init__(config)
1669
1647
 
@@ -1672,20 +1650,21 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
1672
1650
 
1673
1651
  # segmentation head
1674
1652
  hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
1675
- intermediate_channel_sizes = self.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes
1653
+ intermediate_channel_sizes = self.conditional_detr.model.backbone.intermediate_channel_sizes
1676
1654
 
1677
1655
  self.mask_head = ConditionalDetrMaskHeadSmallConv(
1678
- hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
1679
- )
1680
-
1681
- self.bbox_attention = ConditionalDetrMHAttentionMap(
1682
- hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
1656
+ input_channels=hidden_size + number_of_heads,
1657
+ fpn_channels=intermediate_channel_sizes[::-1][-3:],
1658
+ hidden_size=hidden_size,
1659
+ activation_function=config.activation_function,
1683
1660
  )
1684
1661
 
1662
+ self.bbox_attention = ConditionalDetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)
1685
1663
  # Initialize weights and apply final processing
1686
1664
  self.post_init()
1687
1665
 
1688
1666
  @auto_docstring
1667
+ @can_return_tuple
1689
1668
  def forward(
1690
1669
  self,
1691
1670
  pixel_values: torch.FloatTensor,
@@ -1695,20 +1674,20 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
1695
1674
  inputs_embeds: torch.FloatTensor | None = None,
1696
1675
  decoder_inputs_embeds: torch.FloatTensor | None = None,
1697
1676
  labels: list[dict] | None = None,
1698
- output_attentions: bool | None = None,
1699
- output_hidden_states: bool | None = None,
1700
- return_dict: bool | None = None,
1701
- **kwargs,
1677
+ **kwargs: Unpack[TransformersKwargs],
1702
1678
  ) -> tuple[torch.FloatTensor] | ConditionalDetrSegmentationOutput:
1703
1679
  r"""
1704
1680
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1705
- Not used by default. Can be used to mask object queries.
1681
+ Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
1682
+
1683
+ - 1 for queries that are **not masked**,
1684
+ - 0 for queries that are **masked**.
1706
1685
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1707
- Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1708
- can choose to directly pass a flattened representation of an image.
1686
+ Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
1687
+ multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
1709
1688
  decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1710
1689
  Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1711
- embedded representation.
1690
+ embedded representation. Useful for tasks that require custom query initialization.
1712
1691
  labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1713
1692
  Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
1714
1693
  dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
@@ -1721,26 +1700,21 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
 
  ```python
  >>> import io
- >>> import requests
+ >>> import httpx
+ >>> from io import BytesIO
  >>> from PIL import Image
  >>> import torch
  >>> import numpy
 
- >>> from transformers import (
- ... AutoImageProcessor,
- ... ConditionalDetrConfig,
- ... ConditionalDetrForSegmentation,
- ... )
+ >>> from transformers import AutoImageProcessor, ConditionalDetrForSegmentation
  >>> from transformers.image_transforms import rgb_to_id
 
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
- >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> with httpx.stream("GET", url) as response:
+ ... image = Image.open(BytesIO(response.read()))
 
- >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
-
- >>> # randomly initialize all weights of the model
- >>> config = ConditionalDetrConfig()
- >>> model = ConditionalDetrForSegmentation(config)
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/conditional_detr-resnet-50-panoptic")
+ >>> model = ConditionalDetrForSegmentation.from_pretrained("facebook/conditional_detr-resnet-50-panoptic")
 
  >>> # prepare image for the model
  >>> inputs = image_processor(images=image, return_tensors="pt")
@@ -1751,89 +1725,88 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
  >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
  >>> # Segmentation results are returned as a list of dictionaries
  >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
+
  >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
  >>> panoptic_seg = result[0]["segmentation"]
+ >>> panoptic_seg.shape
+ torch.Size([300, 500])
  >>> # Get prediction score and segment_id to class_id mapping of each segment
  >>> panoptic_segments_info = result[0]["segments_info"]
+ >>> len(panoptic_segments_info)
+ 5
  ```"""
 
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
  batch_size, num_channels, height, width = pixel_values.shape
  device = pixel_values.device
 
  if pixel_mask is None:
  pixel_mask = torch.ones((batch_size, height, width), device=device)
 
- # First, get list of feature maps and object_queries
- features, object_queries_list = self.conditional_detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
+ vision_features = self.conditional_detr.model.backbone(pixel_values, pixel_mask)
+ feature_map, mask = vision_features[-1]
 
- # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
- feature_map, mask = features[-1]
- batch_size, num_channels, height, width = feature_map.shape
+ # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
  projected_feature_map = self.conditional_detr.model.input_projection(feature_map)
-
- # Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
- # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
  flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
- object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
-
+ spatial_position_embeddings = self.conditional_detr.model.position_embedding(
+ shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
+ )
  flattened_mask = mask.flatten(1)
 
- # Fourth, sent flattened_features + flattened_mask + object_queries through encoder
- # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
- # flattened_mask is a Tensor of shape (batch_size, height*width)
  if encoder_outputs is None:
  encoder_outputs = self.conditional_detr.model.encoder(
  inputs_embeds=flattened_features,
  attention_mask=flattened_mask,
- object_queries=object_queries,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
- encoder_outputs = BaseModelOutput(
- last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ spatial_position_embeddings=spatial_position_embeddings,
+ **kwargs,
  )
 
- # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
- query_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
- batch_size, 1, 1
- )
- queries = torch.zeros_like(query_position_embeddings)
+ object_queries_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(
+ 0
+ ).repeat(batch_size, 1, 1)
+
+ # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
+ if decoder_inputs_embeds is not None:
+ queries = decoder_inputs_embeds
+ else:
+ queries = torch.zeros_like(object_queries_position_embeddings)
 
- # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
  decoder_outputs = self.conditional_detr.model.decoder(
  inputs_embeds=queries,
- attention_mask=None,
- object_queries=object_queries,
- query_position_embeddings=query_position_embeddings,
- encoder_hidden_states=encoder_outputs[0],
+ attention_mask=decoder_attention_mask,
+ spatial_position_embeddings=spatial_position_embeddings,
+ object_queries_position_embeddings=object_queries_position_embeddings,
+ encoder_hidden_states=encoder_outputs.last_hidden_state,
  encoder_attention_mask=flattened_mask,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ **kwargs,
  )
 
  sequence_output = decoder_outputs[0]
 
- # Sixth, compute logits, pred_boxes and pred_masks
  logits = self.conditional_detr.class_labels_classifier(sequence_output)
  pred_boxes = self.conditional_detr.bbox_predictor(sequence_output).sigmoid()
 
- memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
- mask = flattened_mask.view(batch_size, height, width)
+ height, width = feature_map.shape[-2:]
+ memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
+ batch_size, self.config.d_model, height, width
+ )
+ attention_mask = flattened_mask.view(batch_size, height, width)
+
+ if attention_mask is not None:
+ min_dtype = torch.finfo(memory.dtype).min
+ attention_mask = torch.where(
+ attention_mask.unsqueeze(1).unsqueeze(1),
+ torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
+ min_dtype,
+ )
 
- # FIXME h_boxes takes the last one computed, keep this in mind
- # important: we need to reverse the mask, since in the original implementation the mask works reversed
- # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
- bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)
+ bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)
 
- seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])
+ seg_masks = self.mask_head(
+ features=projected_feature_map,
+ attention_masks=bbox_mask,
+ fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
+ )
 
  pred_masks = seg_masks.view(
  batch_size, self.conditional_detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]
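The new forward body above replaces the old inverted boolean mask (`mask=~mask`) with an additive float mask built via `torch.where` and `torch.finfo(...).min` before it is handed to `bbox_attention`. A minimal, self-contained sketch of that convention — variable names here are illustrative, not the module's own:

```python
# Sketch of the additive-mask convention used in the new forward body: a boolean
# "valid pixel" mask of shape (batch, H, W) is broadcast to (batch, 1, 1, H, W) and
# turned into 0.0 (keep) / finfo.min (mask), so it can simply be added to attention logits.
import torch

batch_size, height, width = 2, 4, 6
memory_dtype = torch.float32

valid = torch.ones(batch_size, height, width, dtype=torch.bool)
valid[:, :, -2:] = False  # pretend the last two columns are padding

min_dtype = torch.finfo(memory_dtype).min
additive_mask = torch.where(
    valid.unsqueeze(1).unsqueeze(1),                 # (batch, 1, 1, H, W)
    torch.tensor(0.0, dtype=memory_dtype),           # attended positions contribute nothing
    torch.tensor(min_dtype, dtype=memory_dtype),     # masked positions are pushed to ~-inf
)

# Adding the mask to attention scores before softmax drives padded positions to ~0 probability.
scores = torch.randn(batch_size, 1, 1, height, width, dtype=memory_dtype)
probs = torch.softmax((scores + additive_mask).flatten(2), dim=-1)
```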
@@ -1843,20 +1816,13 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
  if labels is not None:
  outputs_class, outputs_coord = None, None
  if self.config.auxiliary_loss:
- intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
+ intermediate = decoder_outputs.intermediate_hidden_states
  outputs_class = self.conditional_detr.class_labels_classifier(intermediate)
  outputs_coord = self.conditional_detr.bbox_predictor(intermediate).sigmoid()
  loss, loss_dict, auxiliary_outputs = self.loss_function(
- logits, labels, self.device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
+ logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
  )
 
- if not return_dict:
- if auxiliary_outputs is not None:
- output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
- else:
- output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
- return ((loss, loss_dict) + output) if loss is not None else output
-
  return ConditionalDetrSegmentationOutput(
  loss=loss,
  loss_dict=loss_dict,
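With the `if not return_dict:` branch removed and `@can_return_tuple` applied to `forward`, the method now always assembles a `ConditionalDetrSegmentationOutput`. A hedged sketch of what tuple-style callers can do instead, using a small randomly initialized model the way the now-removed doctest did (output values are not meaningful):

```python
# Sketch only: the forward now always returns a ConditionalDetrSegmentationOutput
# (a ModelOutput), so tuple-style consumers can index it or call .to_tuple().
import torch
from transformers import ConditionalDetrConfig, ConditionalDetrForSegmentation

config = ConditionalDetrConfig()                       # randomly initialized, as in the old doctest
model = ConditionalDetrForSegmentation(config).eval()

pixel_values = torch.zeros(1, 3, 64, 64)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

print(type(outputs).__name__)    # ConditionalDetrSegmentationOutput
logits, pred_boxes = outputs.logits, outputs.pred_boxes
as_tuple = outputs.to_tuple()    # plain tuple of the non-None fields, mirroring the old tuple return
first_item = outputs[0]          # indexing also works, much like the old tuple output
```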
@@ -1874,120 +1840,6 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
  )
 
 
- def _expand(tensor, length: int):
- return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
-
-
- # Copied from transformers.models.detr.modeling_detr.DetrMaskHeadSmallConv with Detr->ConditionalDetr
- class ConditionalDetrMaskHeadSmallConv(nn.Module):
- """
- Simple convolutional head, using group norm. Upsampling is done using a FPN approach
- """
-
- def __init__(self, dim, fpn_dims, context_dim):
- super().__init__()
-
- if dim % 8 != 0:
- raise ValueError(
- "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
- " GroupNorm is set to 8"
- )
-
- inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
-
- self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
- self.gn1 = nn.GroupNorm(8, dim)
- self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
- self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
- self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
- self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
- self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
- self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
- self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
- self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
- self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
-
- self.dim = dim
-
- self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
- self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
- self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
-
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- init.kaiming_uniform_(m.weight, a=1)
- init.constant_(m.bias, 0)
-
- def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]):
- # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
- # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
- # We expand the projected feature map to match the number of heads.
- x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
-
- x = self.lay1(x)
- x = self.gn1(x)
- x = nn.functional.relu(x)
- x = self.lay2(x)
- x = self.gn2(x)
- x = nn.functional.relu(x)
-
- cur_fpn = self.adapter1(fpns[0])
- if cur_fpn.size(0) != x.size(0):
- cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
- x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
- x = self.lay3(x)
- x = self.gn3(x)
- x = nn.functional.relu(x)
-
- cur_fpn = self.adapter2(fpns[1])
- if cur_fpn.size(0) != x.size(0):
- cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
- x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
- x = self.lay4(x)
- x = self.gn4(x)
- x = nn.functional.relu(x)
-
- cur_fpn = self.adapter3(fpns[2])
- if cur_fpn.size(0) != x.size(0):
- cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
- x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
- x = self.lay5(x)
- x = self.gn5(x)
- x = nn.functional.relu(x)
-
- x = self.out_lay(x)
- return x
-
-
- # Copied from transformers.models.detr.modeling_detr.DetrMHAttentionMap with Detr->ConditionalDetr
- class ConditionalDetrMHAttentionMap(nn.Module):
- """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
-
- def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
- super().__init__()
- self.num_heads = num_heads
- self.hidden_dim = hidden_dim
- self.dropout = nn.Dropout(dropout)
-
- self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
- self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
-
- self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
-
- def forward(self, q, k, mask: Tensor | None = None):
- q = self.q_linear(q)
- k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
- queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
- keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
- weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
-
- if mask is not None:
- weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
- weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
- weights = self.dropout(weights)
- return weights
-
-
  __all__ = [
  "ConditionalDetrForObjectDetection",
  "ConditionalDetrForSegmentation",