transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registries.
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,9 @@
-# coding=utf-8
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/conditional_detr/modular_conditional_detr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_conditional_detr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # Copyright 2022 Microsoft Research Asia and The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,40 +17,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch Conditional DETR model."""
-
 import math
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Optional, Union
 
 import torch
-from torch import Tensor, nn
+from torch import nn
 
 from ... import initialization as init
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...backbone_utils import load_backbone
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
-from ...modeling_utils import PreTrainedModel
-from ...utils import ModelOutput, auto_docstring, is_timm_available, logging, requires_backends
-from ...utils.backbone_utils import load_backbone
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring
+from ...utils.generic import OutputRecorder, can_return_tuple, check_model_inputs
 from .configuration_conditional_detr import ConditionalDetrConfig
 
 
-if is_timm_available():
-    from timm import create_model
-
-
-logger = logging.get_logger(__name__)
-
-
 @dataclass
 @auto_docstring(
     custom_intro="""
-    Base class for outputs of the Conditional DETR decoder. This class adds one attribute to
-    BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder activations, i.e. the output
-    of each decoder layer, each of them gone through a layernorm. This is useful when training the model with auxiliary
-    decoding losses.
+    Base class for outputs of the CONDITIONAL_DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
+    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
+    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
     """
 )
 class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
@@ -61,17 +59,17 @@ class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
         Reference points (reference points of each layer of the decoder).
     """
 
-    intermediate_hidden_states: Optional[torch.FloatTensor] = None
-    reference_points: Optional[tuple[torch.FloatTensor]] = None
+    intermediate_hidden_states: torch.FloatTensor | None = None
+
+    reference_points: tuple[torch.FloatTensor] | None = None
 
 
 @dataclass
 @auto_docstring(
     custom_intro="""
-    Base class for outputs of the Conditional DETR encoder-decoder model. This class adds one attribute to
-    Seq2SeqModelOutput, namely an optional stack of intermediate decoder activations, i.e. the output of each decoder
-    layer, each of them gone through a layernorm. This is useful when training the model with auxiliary decoding
-    losses.
+    Base class for outputs of the CONDITIONAL_DETR encoder-decoder model. This class adds one attribute to Seq2SeqModelOutput,
+    namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
+    gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
     """
 )
 class ConditionalDetrModelOutput(Seq2SeqModelOutput):
@@ -85,8 +83,9 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
         Reference points (reference points of each layer of the decoder).
     """
 
-    intermediate_hidden_states: Optional[torch.FloatTensor] = None
-    reference_points: Optional[tuple[torch.FloatTensor]] = None
+    intermediate_hidden_states: torch.FloatTensor | None = None
+
+    reference_points: tuple[torch.FloatTensor] | None = None
 
 
 @dataclass
@@ -95,7 +94,6 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
    Output type of [`ConditionalDetrForObjectDetection`].
    """
 )
-# Copied from transformers.models.detr.modeling_detr.DetrObjectDetectionOutput with Detr->ConditionalDetr
 class ConditionalDetrObjectDetectionOutput(ModelOutput):
     r"""
     loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
@@ -119,18 +117,18 @@ class ConditionalDetrObjectDetectionOutput(ModelOutput):
         Sequence of hidden-states at the output of the last layer of the decoder of the model.
     """
 
-    loss: Optional[torch.FloatTensor] = None
-    loss_dict: Optional[dict] = None
-    logits: Optional[torch.FloatTensor] = None
-    pred_boxes: Optional[torch.FloatTensor] = None
-    auxiliary_outputs: Optional[list[dict]] = None
-    last_hidden_state: Optional[torch.FloatTensor] = None
-    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
-    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
-    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
-    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
-    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+    loss: torch.FloatTensor | None = None
+    loss_dict: dict | None = None
+    logits: torch.FloatTensor | None = None
+    pred_boxes: torch.FloatTensor | None = None
+    auxiliary_outputs: list[dict] | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    decoder_attentions: tuple[torch.FloatTensor] | None = None
+    cross_attentions: tuple[torch.FloatTensor] | None = None
+    encoder_last_hidden_state: torch.FloatTensor | None = None
+    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    encoder_attentions: tuple[torch.FloatTensor] | None = None
 
 
 @dataclass
@@ -139,7 +137,6 @@ class ConditionalDetrObjectDetectionOutput(ModelOutput):
    Output type of [`ConditionalDetrForSegmentation`].
    """
 )
-# Copied from transformers.models.detr.modeling_detr.DetrSegmentationOutput with Detr->ConditionalDetr
 class ConditionalDetrSegmentationOutput(ModelOutput):
     r"""
     loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
@@ -169,22 +166,21 @@ class ConditionalDetrSegmentationOutput(ModelOutput):
         Sequence of hidden-states at the output of the last layer of the decoder of the model.
     """
 
-    loss: Optional[torch.FloatTensor] = None
-    loss_dict: Optional[dict] = None
-    logits: Optional[torch.FloatTensor] = None
-    pred_boxes: Optional[torch.FloatTensor] = None
-    pred_masks: Optional[torch.FloatTensor] = None
-    auxiliary_outputs: Optional[list[dict]] = None
-    last_hidden_state: Optional[torch.FloatTensor] = None
-    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
-    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
-    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
-    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
-    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
-
-
-# Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->ConditionalDetr
+    loss: torch.FloatTensor | None = None
+    loss_dict: dict | None = None
+    logits: torch.FloatTensor | None = None
+    pred_boxes: torch.FloatTensor | None = None
+    pred_masks: torch.FloatTensor | None = None
+    auxiliary_outputs: list[dict] | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    decoder_attentions: tuple[torch.FloatTensor] | None = None
+    cross_attentions: tuple[torch.FloatTensor] | None = None
+    encoder_last_hidden_state: torch.FloatTensor | None = None
+    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    encoder_attentions: tuple[torch.FloatTensor] | None = None
182
+
183
+
188
184
  class ConditionalDetrFrozenBatchNorm2d(nn.Module):
189
185
  """
190
186
  BatchNorm2d where the batch statistics and the affine parameters are fixed.
@@ -224,7 +220,6 @@ class ConditionalDetrFrozenBatchNorm2d(nn.Module):
224
220
  return x * scale + bias
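For readers skimming this hunk: a standalone, illustrative re-implementation of the frozen batch-norm idea (statistics and affine parameters live in buffers, so nothing is ever trained or updated). Only torch is assumed; the class name here is hypothetical, not the one in this file.

import torch
from torch import nn


class FrozenBatchNorm2d(nn.Module):
    # Illustrative sketch: weight, bias and running statistics are buffers, so they
    # never receive gradients and are never touched by an optimizer.
    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reshape to (1, C, 1, 1) so the statistics broadcast over N, H, W.
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        mean = self.running_mean.reshape(1, -1, 1, 1)
        var = self.running_var.reshape(1, -1, 1, 1)
        scale = weight * (var + self.eps).rsqrt()
        return x * scale + bias - mean * scale


print(FrozenBatchNorm2d(64)(torch.randn(2, 64, 32, 32)).shape)  # torch.Size([2, 64, 32, 32])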
225
221
 
226
222
 
227
- # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->ConditionalDetr
228
223
  def replace_batch_norm(model):
229
224
  r"""
230
225
  Recursively replace all `torch.nn.BatchNorm2d` with `ConditionalDetrFrozenBatchNorm2d`.
@@ -249,7 +244,6 @@ def replace_batch_norm(model):
249
244
  replace_batch_norm(module)
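The recursion above is the usual "walk the children and swap every BatchNorm2d" pattern. A simplified sketch under minimal assumptions; unlike the code in this file it wraps the existing BatchNorm2d and pins it to eval mode rather than copying its statistics into a dedicated frozen module, and all names are illustrative.

import torch
from torch import nn


class _FrozenBN(nn.Module):
    """Illustrative stand-in: wraps a trained BatchNorm2d and keeps it in eval mode."""

    def __init__(self, bn: nn.BatchNorm2d):
        super().__init__()
        self.bn = bn.eval().requires_grad_(False)

    def train(self, mode: bool = True):
        # Ignore .train() so the running statistics are never updated again.
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(x)


def freeze_batch_norms(model: nn.Module) -> None:
    # Recursively swap every BatchNorm2d child for the frozen wrapper.
    for name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(model, name, _FrozenBN(child))
        else:
            freeze_batch_norms(child)


backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
freeze_batch_norms(backbone)
print(backbone)  # the BatchNorm2d child is now wrapped and frozen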
250
245
 
251
246
 
252
- # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->ConditionalDetr
253
247
  class ConditionalDetrConvEncoder(nn.Module):
254
248
  """
255
249
  Convolutional backbone, using either the AutoBackbone API or one from the timm library.
@@ -263,47 +257,25 @@ class ConditionalDetrConvEncoder(nn.Module):
263
257
 
264
258
  self.config = config
265
259
 
266
- # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
267
- if config.use_timm_backbone:
268
- # We default to values which were previously hard-coded. This enables configurability from the config
269
- # using backbone arguments, while keeping the default behavior the same.
270
- requires_backends(self, ["timm"])
271
- kwargs = getattr(config, "backbone_kwargs", {})
272
- kwargs = {} if kwargs is None else kwargs.copy()
273
- out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
274
- num_channels = kwargs.pop("in_chans", config.num_channels)
275
- if config.dilation:
276
- kwargs["output_stride"] = kwargs.get("output_stride", 16)
277
- backbone = create_model(
278
- config.backbone,
279
- pretrained=config.use_pretrained_backbone,
280
- features_only=True,
281
- out_indices=out_indices,
282
- in_chans=num_channels,
283
- **kwargs,
284
- )
285
- else:
286
- backbone = load_backbone(config)
260
+ backbone = load_backbone(config)
261
+ self.intermediate_channel_sizes = backbone.channels
287
262
 
288
263
  # replace batch norm by frozen batch norm
289
264
  with torch.no_grad():
290
265
  replace_batch_norm(backbone)
291
- self.model = backbone
292
- self.intermediate_channel_sizes = (
293
- self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
294
- )
295
266
 
296
- backbone_model_type = None
297
- if config.backbone is not None:
298
- backbone_model_type = config.backbone
299
- elif config.backbone_config is not None:
300
- backbone_model_type = config.backbone_config.model_type
301
- else:
302
- raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
267
+ # We used to load with timm library directly instead of the AutoBackbone API
268
+ # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
269
+ is_timm_model = False
270
+ if hasattr(backbone, "_backbone"):
271
+ backbone = backbone._backbone
272
+ is_timm_model = True
273
+ self.model = backbone
303
274
 
275
+ backbone_model_type = config.backbone_config.model_type
304
276
  if "resnet" in backbone_model_type:
305
277
  for name, parameter in self.model.named_parameters():
306
- if config.use_timm_backbone:
278
+ if is_timm_model:
307
279
  if "layer2" not in name and "layer3" not in name and "layer4" not in name:
308
280
  parameter.requires_grad_(False)
309
281
  else:
@@ -312,7 +284,9 @@ class ConditionalDetrConvEncoder(nn.Module):
312
284
 
313
285
  def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
314
286
  # send pixel_values through the model to get list of feature maps
315
- features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
287
+ features = self.model(pixel_values)
288
+ if isinstance(features, dict):
289
+ features = features.feature_maps
316
290
 
317
291
  out = []
318
292
  for feature_map in features:
@@ -322,66 +296,58 @@ class ConditionalDetrConvEncoder(nn.Module):
322
296
  return out
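The loop body that pairs each feature map with a mask is cut off by this hunk; the usual pattern (a hedged sketch, not the exact elided code) is to downsample the boolean pixel mask to every feature-map resolution with nearest-neighbour interpolation.

import torch
from torch import nn


def pair_features_with_mask(features: list[torch.Tensor], pixel_mask: torch.Tensor):
    # features: list of (batch, channels, h_i, w_i) maps from the backbone
    # pixel_mask: (batch, H, W) bool mask, True for real pixels, False for padding
    out = []
    for feature_map in features:
        # interpolate expects float input; nearest-neighbour keeps the mask binary
        mask = nn.functional.interpolate(
            pixel_mask[None].float(), size=feature_map.shape[-2:], mode="nearest"
        ).to(torch.bool)[0]
        out.append((feature_map, mask))
    return out


feats = [torch.randn(2, 256, 32, 32), torch.randn(2, 512, 16, 16)]
mask = torch.ones(2, 128, 128, dtype=torch.bool)
print([m.shape for _, m in pair_features_with_mask(feats, mask)])  # [(2, 32, 32), (2, 16, 16)] masks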
323
297
 
324
298
 
325
- # Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->ConditionalDetr
326
- class ConditionalDetrConvModel(nn.Module):
327
- """
328
- This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
329
- """
330
-
331
- def __init__(self, conv_encoder, position_embedding):
332
- super().__init__()
333
- self.conv_encoder = conv_encoder
334
- self.position_embedding = position_embedding
335
-
336
- def forward(self, pixel_values, pixel_mask):
337
- # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
338
- out = self.conv_encoder(pixel_values, pixel_mask)
339
- pos = []
340
- for feature_map, mask in out:
341
- # position encoding
342
- pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
343
-
344
- return out, pos
345
-
346
-
347
299
  class ConditionalDetrSinePositionEmbedding(nn.Module):
348
300
  """
349
301
  This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
350
302
  need paper, generalized to work on images.
351
303
  """
352
304
 
353
- def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
305
+ def __init__(
306
+ self,
307
+ num_position_features: int = 64,
308
+ temperature: int = 10000,
309
+ normalize: bool = False,
310
+ scale: float | None = None,
311
+ ):
354
312
  super().__init__()
355
- self.embedding_dim = embedding_dim
356
- self.temperature = temperature
357
- self.normalize = normalize
358
313
  if scale is not None and normalize is False:
359
314
  raise ValueError("normalize should be True if scale is passed")
360
- if scale is None:
361
- scale = 2 * math.pi
362
- self.scale = scale
315
+ self.num_position_features = num_position_features
316
+ self.temperature = temperature
317
+ self.normalize = normalize
318
+ self.scale = 2 * math.pi if scale is None else scale
363
319
 
364
- def forward(self, pixel_values, pixel_mask):
365
- if pixel_mask is None:
366
- raise ValueError("No pixel mask provided")
367
- y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
368
- x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
320
+ @compile_compatible_method_lru_cache(maxsize=1)
321
+ def forward(
322
+ self,
323
+ shape: torch.Size,
324
+ device: torch.device | str,
325
+ dtype: torch.dtype,
326
+ mask: torch.Tensor | None = None,
327
+ ) -> torch.Tensor:
328
+ if mask is None:
329
+ mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
330
+ y_embed = mask.cumsum(1, dtype=dtype)
331
+ x_embed = mask.cumsum(2, dtype=dtype)
369
332
  if self.normalize:
370
- y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
371
- x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
333
+ eps = 1e-6
334
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
335
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
372
336
 
373
- dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
374
- dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
337
+ dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
338
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)
375
339
 
376
340
  pos_x = x_embed[:, :, :, None] / dim_t
377
341
  pos_y = y_embed[:, :, :, None] / dim_t
378
342
  pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
379
343
  pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
380
344
  pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
345
+ # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
346
+ # expected by the encoder
347
+ pos = pos.flatten(2).permute(0, 2, 1)
381
348
  return pos
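A condensed standalone sketch of the sine position encoding computed above: cumulative sums over the mask give per-pixel row/column coordinates, which are normalized and pushed through sin/cos at geometrically spaced frequencies, then flattened into the (batch, height * width, 2 * num_pos_feats) sequence the encoder expects. Only torch is assumed; the function name is illustrative.

import math
import torch


def sine_position_embedding(mask: torch.Tensor, num_pos_feats: int = 64, temperature: int = 10000):
    # mask: (batch, height, width), True for valid pixels
    y_embed = mask.cumsum(1, dtype=torch.float32)  # row coordinate at every valid pixel
    x_embed = mask.cumsum(2, dtype=torch.float32)  # column coordinate
    # Normalize coordinates to [0, 2*pi]
    scale = 2 * math.pi
    y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * scale
    x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * scale

    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)

    pos_x = x_embed[..., None] / dim_t  # (batch, h, w, num_pos_feats)
    pos_y = y_embed[..., None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    # Concatenate y and x channels, then flatten (h, w) into a sequence for the encoder
    pos = torch.cat((pos_y, pos_x), dim=3)  # (batch, h, w, 2 * num_pos_feats)
    return pos.flatten(1, 2)                # (batch, h * w, 2 * num_pos_feats)


print(sine_position_embedding(torch.ones(1, 4, 5, dtype=torch.bool)).shape)  # torch.Size([1, 20, 128])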
382
349
 
383
350
 
384
- # Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding with Detr->ConditionalDetr
385
351
  class ConditionalDetrLearnedPositionEmbedding(nn.Module):
386
352
  """
387
353
  This module learns positional embeddings up to a fixed maximum size.
@@ -392,354 +358,385 @@ class ConditionalDetrLearnedPositionEmbedding(nn.Module):
392
358
  self.row_embeddings = nn.Embedding(50, embedding_dim)
393
359
  self.column_embeddings = nn.Embedding(50, embedding_dim)
394
360
 
395
- def forward(self, pixel_values, pixel_mask=None):
396
- height, width = pixel_values.shape[-2:]
397
- width_values = torch.arange(width, device=pixel_values.device)
398
- height_values = torch.arange(height, device=pixel_values.device)
361
+ @compile_compatible_method_lru_cache(maxsize=1)
362
+ def forward(
363
+ self,
364
+ shape: torch.Size,
365
+ device: torch.device | str,
366
+ dtype: torch.dtype,
367
+ mask: torch.Tensor | None = None,
368
+ ):
369
+ height, width = shape[-2:]
370
+ width_values = torch.arange(width, device=device)
371
+ height_values = torch.arange(height, device=device)
399
372
  x_emb = self.column_embeddings(width_values)
400
373
  y_emb = self.row_embeddings(height_values)
401
374
  pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
402
375
  pos = pos.permute(2, 0, 1)
403
376
  pos = pos.unsqueeze(0)
404
- pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
377
+ pos = pos.repeat(shape[0], 1, 1, 1)
378
+ # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
379
+ # expected by the encoder
380
+ pos = pos.flatten(2).permute(0, 2, 1)
405
381
  return pos
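The learned variant builds the same (batch, sequence, hidden) output from two small embedding tables, one per row index and one per column index, broadcast into an (H, W) grid. A compact sketch with illustrative sizes.

import torch
from torch import nn

embedding_dim, H, W, batch = 128, 20, 30, 2
row = nn.Embedding(50, embedding_dim)  # one vector per row index (up to 50 rows)
col = nn.Embedding(50, embedding_dim)  # one vector per column index

x_emb = col(torch.arange(W))  # (W, D)
y_emb = row(torch.arange(H))  # (H, D)
# Broadcast columns across the height axis and rows across the width axis, then concatenate
grid = torch.cat([x_emb.unsqueeze(0).expand(H, -1, -1), y_emb.unsqueeze(1).expand(-1, W, -1)], dim=-1)
pos = grid.permute(2, 0, 1).unsqueeze(0).expand(batch, -1, -1, -1)  # (batch, 2*D, H, W)
print(pos.flatten(2).permute(0, 2, 1).shape)                        # torch.Size([2, 600, 256])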
406
382
 
407
383
 
408
- # Copied from transformers.models.detr.modeling_detr.build_position_encoding with Detr->ConditionalDetr
409
- def build_position_encoding(config):
410
- n_steps = config.d_model // 2
411
- if config.position_embedding_type == "sine":
412
- # TODO find a better way of exposing other arguments
413
- position_embedding = ConditionalDetrSinePositionEmbedding(n_steps, normalize=True)
414
- elif config.position_embedding_type == "learned":
415
- position_embedding = ConditionalDetrLearnedPositionEmbedding(n_steps)
416
- else:
417
- raise ValueError(f"Not supported {config.position_embedding_type}")
384
+ def eager_attention_forward(
385
+ module: nn.Module,
386
+ query: torch.Tensor,
387
+ key: torch.Tensor,
388
+ value: torch.Tensor,
389
+ attention_mask: torch.Tensor | None,
390
+ scaling: float | None = None,
391
+ dropout: float = 0.0,
392
+ **kwargs: Unpack[TransformersKwargs],
393
+ ):
394
+ if scaling is None:
395
+ scaling = query.size(-1) ** -0.5
418
396
 
419
- return position_embedding
397
+ # Take the dot product between "query" and "key" to get the raw attention scores.
398
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
420
399
 
400
+ if attention_mask is not None:
401
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
402
+ attn_weights = attn_weights + attention_mask
421
403
 
422
- # function to generate sine positional embedding for 2d coordinates
423
- def gen_sine_position_embeddings(pos_tensor, d_model):
424
- scale = 2 * math.pi
425
- dim = d_model // 2
426
- dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
427
- dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
428
- x_embed = pos_tensor[:, :, 0] * scale
429
- y_embed = pos_tensor[:, :, 1] * scale
430
- pos_x = x_embed[:, :, None] / dim_t
431
- pos_y = y_embed[:, :, None] / dim_t
432
- pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
433
- pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
434
- pos = torch.cat((pos_y, pos_x), dim=2)
435
- return pos.to(pos_tensor.dtype)
404
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
405
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
436
406
 
407
+ attn_output = torch.matmul(attn_weights, value)
408
+ attn_output = attn_output.transpose(1, 2).contiguous()
437
409
 
438
- def inverse_sigmoid(x, eps=1e-5):
439
- x = x.clamp(min=0, max=1)
440
- x1 = x.clamp(min=eps)
441
- x2 = (1 - x).clamp(min=eps)
442
- return torch.log(x1 / x2)
410
+ return attn_output, attn_weights
443
411
 
444
412
 
445
- # Copied from transformers.models.detr.modeling_detr.DetrAttention
446
- class DetrAttention(nn.Module):
413
+ class ConditionalDetrSelfAttention(nn.Module):
447
414
  """
448
- Multi-headed attention from 'Attention Is All You Need' paper.
415
+ Multi-headed self-attention from 'Attention Is All You Need' paper.
449
416
 
450
- Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
417
+ In CONDITIONAL_DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
451
418
  """
452
419
 
453
420
  def __init__(
454
421
  self,
455
- embed_dim: int,
456
- num_heads: int,
422
+ config: ConditionalDetrConfig,
423
+ hidden_size: int,
424
+ num_attention_heads: int,
457
425
  dropout: float = 0.0,
458
426
  bias: bool = True,
459
427
  ):
460
428
  super().__init__()
461
- self.embed_dim = embed_dim
462
- self.num_heads = num_heads
463
- self.dropout = dropout
464
- self.head_dim = embed_dim // num_heads
465
- if self.head_dim * num_heads != self.embed_dim:
466
- raise ValueError(
467
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
468
- f" {num_heads})."
469
- )
429
+ self.config = config
430
+ self.head_dim = hidden_size // num_attention_heads
470
431
  self.scaling = self.head_dim**-0.5
432
+ self.attention_dropout = dropout
433
+ self.is_causal = False
471
434
 
472
- self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
473
- self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
474
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
475
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
476
-
477
- def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
478
- return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
479
-
480
- def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
481
- return tensor if object_queries is None else tensor + object_queries
435
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
436
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
437
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
438
+ self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
482
439
 
483
440
  def forward(
484
441
  self,
485
442
  hidden_states: torch.Tensor,
486
- attention_mask: Optional[torch.Tensor] = None,
487
- object_queries: Optional[torch.Tensor] = None,
488
- key_value_states: Optional[torch.Tensor] = None,
489
- spatial_position_embeddings: Optional[torch.Tensor] = None,
490
- output_attentions: bool = False,
491
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
492
- """Input shape: Batch x Time x Channel"""
493
- # if key_value_states are provided this layer is used as a cross-attention layer
494
- # for the decoder
495
- is_cross_attention = key_value_states is not None
496
- batch_size, target_len, embed_dim = hidden_states.size()
497
-
498
- # add position embeddings to the hidden states before projecting to queries and keys
499
- if object_queries is not None:
500
- hidden_states_original = hidden_states
501
- hidden_states = self.with_pos_embed(hidden_states, object_queries)
502
-
503
- # add key-value position embeddings to the key value states
504
- if spatial_position_embeddings is not None:
505
- key_value_states_original = key_value_states
506
- key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
507
-
508
- # get query proj
509
- query_states = self.q_proj(hidden_states) * self.scaling
510
- # get key, value proj
511
- if is_cross_attention:
512
- # cross_attentions
513
- key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
514
- value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
515
- else:
516
- # self_attention
517
- key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
518
- value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
443
+ attention_mask: torch.Tensor | None = None,
444
+ position_embeddings: torch.Tensor | None = None,
445
+ **kwargs: Unpack[TransformersKwargs],
446
+ ) -> tuple[torch.Tensor, torch.Tensor]:
447
+ """
448
+ Position embeddings are added to both queries and keys (but not values).
449
+ """
450
+ input_shape = hidden_states.shape[:-1]
451
+ hidden_shape = (*input_shape, -1, self.head_dim)
519
452
 
520
- proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
521
- query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
522
- key_states = key_states.view(*proj_shape)
523
- value_states = value_states.view(*proj_shape)
453
+ query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
524
454
 
525
- source_len = key_states.size(1)
455
+ query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
456
+ key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
457
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
526
458
 
527
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
459
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
460
+ self.config._attn_implementation, eager_attention_forward
461
+ )
528
462
 
529
- if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
530
- raise ValueError(
531
- f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
532
- f" {attn_weights.size()}"
533
- )
463
+ attn_output, attn_weights = attention_interface(
464
+ self,
465
+ query_states,
466
+ key_states,
467
+ value_states,
468
+ attention_mask,
469
+ dropout=0.0 if not self.training else self.attention_dropout,
470
+ scaling=self.scaling,
471
+ **kwargs,
472
+ )
534
473
 
535
- if attention_mask is not None:
536
- if attention_mask.size() != (batch_size, 1, target_len, source_len):
537
- raise ValueError(
538
- f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
539
- f" {attention_mask.size()}"
540
- )
541
- if attention_mask.dtype == torch.bool:
542
- attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
543
- attention_mask, -torch.inf
544
- )
545
- attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
546
- attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
547
-
548
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
549
-
550
- if output_attentions:
551
- # this operation is a bit awkward, but it's required to
552
- # make sure that attn_weights keeps its gradient.
553
- # In order to do so, attn_weights have to reshaped
554
- # twice and have to be reused in the following
555
- attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
556
- attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
557
- else:
558
- attn_weights_reshaped = None
474
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
475
+ attn_output = self.o_proj(attn_output)
476
+ return attn_output, attn_weights
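A minimal sketch of the "positions condition where to attend, but not what is mixed" rule stated in the docstring, using torch's built-in scaled_dot_product_attention instead of the configurable attention interface; single head and illustrative names to keep the shapes obvious.

import torch
from torch import nn
import torch.nn.functional as F


class PosAwareSelfAttention(nn.Module):
    # Illustrative only; single-head so the shapes stay simple.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, position_embeddings: torch.Tensor) -> torch.Tensor:
        # Positions are added before the q/k projections only; values see raw hidden states.
        query_key_input = hidden_states + position_embeddings
        q = self.q_proj(query_key_input)
        k = self.k_proj(query_key_input)
        v = self.v_proj(hidden_states)
        return F.scaled_dot_product_attention(q, k, v)


x = torch.randn(2, 100, 256)    # (batch, sequence, hidden)
pos = torch.randn(2, 100, 256)  # spatial position embeddings, same shape
print(PosAwareSelfAttention(256)(x, pos).shape)  # torch.Size([2, 100, 256])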
559
477
 
560
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
561
478
 
562
- attn_output = torch.bmm(attn_probs, value_states)
479
+ class ConditionalDetrDecoderSelfAttention(nn.Module):
480
+ """
481
+ Multi-headed self-attention for Conditional DETR decoder layers.
563
482
 
564
- if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
565
- raise ValueError(
566
- f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
567
- f" {attn_output.size()}"
568
- )
483
+ This attention module handles separate content and position projections, which are then combined
484
+ before applying standard self-attention. Position embeddings are added to both queries and keys.
485
+ """
569
486
 
570
- attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
571
- attn_output = attn_output.transpose(1, 2)
572
- attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
487
+ def __init__(
488
+ self,
489
+ config: ConditionalDetrConfig,
490
+ hidden_size: int,
491
+ num_attention_heads: int,
492
+ dropout: float = 0.0,
493
+ ):
494
+ super().__init__()
495
+ self.config = config
496
+ self.hidden_size = hidden_size
497
+ self.head_dim = hidden_size // num_attention_heads
498
+ self.scaling = self.head_dim**-0.5
499
+ self.attention_dropout = dropout
500
+ self.is_causal = False
501
+
502
+ # Content and position projections
503
+ self.q_content_proj = nn.Linear(hidden_size, hidden_size)
504
+ self.q_pos_proj = nn.Linear(hidden_size, hidden_size)
505
+ self.k_content_proj = nn.Linear(hidden_size, hidden_size)
506
+ self.k_pos_proj = nn.Linear(hidden_size, hidden_size)
507
+ self.v_proj = nn.Linear(hidden_size, hidden_size)
508
+ self.o_proj = nn.Linear(hidden_size, hidden_size)
509
+
510
+ def forward(
511
+ self,
512
+ hidden_states: torch.Tensor,
513
+ query_position_embeddings: torch.Tensor,
514
+ attention_mask: torch.Tensor | None = None,
515
+ **kwargs: Unpack[TransformersKwargs],
516
+ ) -> tuple[torch.Tensor, torch.Tensor]:
517
+ """
518
+ Args:
519
+ hidden_states (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
520
+ Input hidden states from the decoder layer.
521
+ query_position_embeddings (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
522
+ Position embeddings for queries and keys. Required (unlike standard attention). Processed through
523
+ separate position projections (`q_pos_proj`, `k_pos_proj`) and added to content projections.
524
+ attention_mask (`torch.Tensor` of shape `(batch_size, 1, num_queries, num_queries)`, *optional*):
525
+ Attention mask to avoid attending to padding tokens.
526
+ """
527
+ input_shape = hidden_states.shape[:-1]
528
+ hidden_shape = (*input_shape, -1, self.head_dim)
529
+
530
+ query_states = (
531
+ (self.q_content_proj(hidden_states) + self.q_pos_proj(query_position_embeddings))
532
+ .view(hidden_shape)
533
+ .transpose(1, 2)
534
+ )
535
+ key_states = (
536
+ (self.k_content_proj(hidden_states) + self.k_pos_proj(query_position_embeddings))
537
+ .view(hidden_shape)
538
+ .transpose(1, 2)
539
+ )
540
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
573
541
 
574
- attn_output = self.out_proj(attn_output)
542
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
543
+ self.config._attn_implementation, eager_attention_forward
544
+ )
545
+
546
+ attn_output, attn_weights = attention_interface(
547
+ self,
548
+ query_states,
549
+ key_states,
550
+ value_states,
551
+ attention_mask,
552
+ dropout=0.0 if not self.training else self.attention_dropout,
553
+ scaling=self.scaling,
554
+ **kwargs,
555
+ )
575
556
 
576
- return attn_output, attn_weights_reshaped
557
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
558
+ attn_output = self.o_proj(attn_output)
559
+ return attn_output, attn_weights
577
560
 
578
561
 
579
- class ConditionalDetrAttention(nn.Module):
562
+ class ConditionalDetrDecoderCrossAttention(nn.Module):
580
563
  """
581
- Cross-Attention used in Conditional DETR 'Conditional DETR for Fast Training Convergence' paper.
564
+ Multi-headed cross-attention for Conditional DETR decoder layers.
582
565
 
583
- The key q_proj, k_proj, v_proj are defined outside the attention. This attention allows the dim of q, k to be
584
- different to v.
566
+ This attention module handles the special cross-attention logic in Conditional DETR:
567
+ - Separate content and position projections for queries and keys
568
+ - Concatenation of query sine embeddings with queries (doubling query dimension)
569
+ - Concatenation of key position embeddings with keys (doubling key dimension)
570
+ - Output dimension remains hidden_size despite doubled input dimensions
585
571
  """
586
572
 
587
573
  def __init__(
588
574
  self,
589
- embed_dim: int,
590
- out_dim: int,
591
- num_heads: int,
575
+ config: ConditionalDetrConfig,
576
+ hidden_size: int,
577
+ num_attention_heads: int,
592
578
  dropout: float = 0.0,
593
- bias: bool = True,
594
579
  ):
595
580
  super().__init__()
596
- self.embed_dim = embed_dim
597
- self.out_dim = out_dim
598
- self.num_heads = num_heads
599
- self.dropout = dropout
600
- self.head_dim = embed_dim // num_heads
601
- if self.head_dim * num_heads != self.embed_dim:
602
- raise ValueError(
603
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
604
- f" {num_heads})."
605
- )
606
- # head dimension of values
607
- self.v_head_dim = out_dim // num_heads
608
- if self.v_head_dim * num_heads != self.out_dim:
609
- raise ValueError(
610
- f"out_dim must be divisible by num_heads (got `out_dim`: {self.out_dim} and `num_heads`: {num_heads})."
611
- )
612
- self.scaling = self.head_dim**-0.5
613
-
614
- self.out_proj = nn.Linear(out_dim, out_dim, bias=bias)
615
-
616
- def _qk_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
617
- return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
618
-
619
- def _v_shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
620
- return tensor.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()
581
+ self.config = config
582
+ self.hidden_size = hidden_size
583
+ self.num_attention_heads = num_attention_heads
584
+ self.head_dim = hidden_size // num_attention_heads
585
+ self.attention_dropout = dropout
586
+ self.is_causal = False
587
+
588
+ # Content and position projections
589
+ self.q_content_proj = nn.Linear(hidden_size, hidden_size)
590
+ self.q_pos_proj = nn.Linear(hidden_size, hidden_size)
591
+ self.k_content_proj = nn.Linear(hidden_size, hidden_size)
592
+ self.k_pos_proj = nn.Linear(hidden_size, hidden_size)
593
+ self.v_proj = nn.Linear(hidden_size, hidden_size)
594
+ self.q_pos_sine_proj = nn.Linear(hidden_size, hidden_size)
595
+
596
+ # Output projection: input is hidden_size * 2 (from concatenated q/k), output is hidden_size
597
+ self.o_proj = nn.Linear(hidden_size, hidden_size)
598
+
599
+ # Compute scaling for expanded head_dim (q and k have doubled dimensions after concatenation)
600
+ # This matches the original Conditional DETR implementation where embed_dim * 2 is used
601
+ expanded_head_dim = (hidden_size * 2) // num_attention_heads
602
+ self.scaling = expanded_head_dim**-0.5
621
603
 
622
604
  def forward(
623
605
  self,
624
606
  hidden_states: torch.Tensor,
625
- attention_mask: Optional[torch.Tensor] = None,
626
- key_states: Optional[torch.Tensor] = None,
627
- value_states: Optional[torch.Tensor] = None,
628
- output_attentions: bool = False,
629
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
630
- """Input shape: Batch x Time x Channel"""
631
-
632
- batch_size, target_len, _ = hidden_states.size()
633
-
634
- # get query proj
635
- query_states = hidden_states * self.scaling
636
- # get key, value proj
637
- key_states = self._qk_shape(key_states, -1, batch_size)
638
- value_states = self._v_shape(value_states, -1, batch_size)
639
-
640
- proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
641
- v_proj_shape = (batch_size * self.num_heads, -1, self.v_head_dim)
642
- query_states = self._qk_shape(query_states, target_len, batch_size).view(*proj_shape)
643
- key_states = key_states.view(*proj_shape)
644
- value_states = value_states.view(*v_proj_shape)
645
-
646
- source_len = key_states.size(1)
647
-
648
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
649
-
650
- if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
651
- raise ValueError(
652
- f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
653
- f" {attn_weights.size()}"
654
- )
655
-
656
- if attention_mask is not None:
657
- if attention_mask.size() != (batch_size, 1, target_len, source_len):
658
- raise ValueError(
659
- f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
660
- f" {attention_mask.size()}"
661
- )
662
- if attention_mask.dtype == torch.bool:
663
- attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
664
- attention_mask, -torch.inf
665
- )
666
- attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
667
- attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
668
-
669
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
670
-
671
- if output_attentions:
672
- # this operation is a bit awkward, but it's required to
673
- # make sure that attn_weights keeps its gradient.
674
- # In order to do so, attn_weights have to reshaped
675
- # twice and have to be reused in the following
676
- attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
677
- attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
678
- else:
679
- attn_weights_reshaped = None
607
+ encoder_hidden_states: torch.Tensor,
608
+ query_sine_embed: torch.Tensor,
609
+ encoder_position_embeddings: torch.Tensor,
610
+ query_position_embeddings: torch.Tensor | None = None,
611
+ attention_mask: torch.Tensor | None = None,
612
+ **kwargs: Unpack[TransformersKwargs],
613
+ ) -> tuple[torch.Tensor, torch.Tensor]:
614
+ """
615
+ Args:
616
+ hidden_states (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
617
+ Decoder hidden states (queries).
618
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, encoder_seq_len, hidden_size)`):
619
+ Encoder output hidden states (keys and values).
620
+ query_sine_embed (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`):
621
+ Sine position embeddings for queries. **Concatenated** (not added) with query content,
622
+ doubling the query dimension.
623
+ encoder_position_embeddings (`torch.Tensor` of shape `(batch_size, encoder_seq_len, hidden_size)`):
624
+ Position embeddings for keys. **Concatenated** (not added) with key content, doubling the key dimension.
625
+ query_position_embeddings (`torch.Tensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
626
+ Additional position embeddings. When provided (first layer only), **added** to query content
627
+ before concatenation with `query_sine_embed`. Also causes `encoder_position_embeddings` to be
628
+ added to key content before concatenation.
629
+ attention_mask (`torch.Tensor` of shape `(batch_size, 1, num_queries, encoder_seq_len)`, *optional*):
630
+ Attention mask to avoid attending to padding tokens.
631
+ """
632
+ query_input_shape = hidden_states.shape[:-1]
633
+ kv_input_shape = encoder_hidden_states.shape[:-1]
634
+ query_hidden_shape = (*query_input_shape, self.num_attention_heads, self.head_dim)
635
+ kv_hidden_shape = (*kv_input_shape, self.num_attention_heads, self.head_dim)
636
+
637
+ # Apply content and position projections
638
+ query_input = self.q_content_proj(hidden_states)
639
+ key_input = self.k_content_proj(encoder_hidden_states)
640
+ value_states = self.v_proj(encoder_hidden_states)
641
+ key_pos = self.k_pos_proj(encoder_position_embeddings)
642
+
643
+ # Combine content and position embeddings
644
+ if query_position_embeddings is not None:
645
+ query_input = query_input + self.q_pos_proj(query_position_embeddings)
646
+ key_input = key_input + key_pos
647
+
648
+ # Reshape and concatenate position embeddings (doubling head_dim)
649
+ query_input = query_input.view(query_hidden_shape)
650
+ key_input = key_input.view(kv_hidden_shape)
651
+ query_sine_embed = self.q_pos_sine_proj(query_sine_embed).view(query_hidden_shape)
652
+ key_pos = key_pos.view(kv_hidden_shape)
653
+
654
+ query_states = torch.cat([query_input, query_sine_embed], dim=-1).view(*query_input_shape, -1)
655
+ key_states = torch.cat([key_input, key_pos], dim=-1).view(*kv_input_shape, -1)
656
+
657
+ # Reshape for attention computation
658
+ expanded_head_dim = query_states.shape[-1] // self.num_attention_heads
659
+ query_states = query_states.view(*query_input_shape, self.num_attention_heads, expanded_head_dim).transpose(
660
+ 1, 2
661
+ )
662
+ key_states = key_states.view(*kv_input_shape, self.num_attention_heads, expanded_head_dim).transpose(1, 2)
663
+ value_states = value_states.view(kv_hidden_shape).transpose(1, 2)
680
664
 
681
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
665
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
666
+ self.config._attn_implementation, eager_attention_forward
667
+ )
682
668
 
683
- attn_output = torch.bmm(attn_probs, value_states)
669
+ attn_output, attn_weights = attention_interface(
670
+ self,
671
+ query_states,
672
+ key_states,
673
+ value_states,
674
+ attention_mask,
675
+ dropout=0.0 if not self.training else self.attention_dropout,
676
+ scaling=self.scaling,
677
+ **kwargs,
678
+ )
684
679
 
685
- if attn_output.size() != (batch_size * self.num_heads, target_len, self.v_head_dim):
686
- raise ValueError(
687
- f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.v_head_dim)}, but is"
688
- f" {attn_output.size()}"
689
- )
680
+ attn_output = attn_output.reshape(*query_input_shape, -1).contiguous()
681
+ attn_output = self.o_proj(attn_output)
682
+ return attn_output, attn_weights
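Shape bookkeeping for the concatenation trick described in the docstring: content and sine-position features are concatenated along the per-head dimension, so queries and keys run at 2 * head_dim while values (and therefore the output) keep head_dim, and the softmax scaling uses the expanded width. Illustrative sizes; projections omitted.

import torch

batch, num_heads, head_dim = 2, 8, 32
num_queries, enc_len = 100, 950
hidden_size = num_heads * head_dim  # 256

# Content and positional features, already split into heads
q_content = torch.randn(batch, num_queries, num_heads, head_dim)
q_sine = torch.randn(batch, num_queries, num_heads, head_dim)
k_content = torch.randn(batch, enc_len, num_heads, head_dim)
k_pos = torch.randn(batch, enc_len, num_heads, head_dim)
v = torch.randn(batch, enc_len, num_heads, head_dim)

# Concatenation along the last dim doubles the per-head width of q and k only
q = torch.cat([q_content, q_sine], dim=-1).transpose(1, 2)  # (batch, heads, num_queries, 2 * head_dim)
k = torch.cat([k_content, k_pos], dim=-1).transpose(1, 2)   # (batch, heads, enc_len, 2 * head_dim)
v = v.transpose(1, 2)                                       # (batch, heads, enc_len, head_dim)

# Scaling therefore uses the expanded per-head width, 2 * head_dim
scaling = (2 * head_dim) ** -0.5
attn = torch.softmax(torch.matmul(q, k.transpose(-1, -2)) * scaling, dim=-1)
out = torch.matmul(attn, v)  # (batch, heads, num_queries, head_dim)
print(out.transpose(1, 2).reshape(batch, num_queries, hidden_size).shape)  # torch.Size([2, 100, 256])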
690
683
 
691
- attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.v_head_dim)
692
- attn_output = attn_output.transpose(1, 2)
693
- attn_output = attn_output.reshape(batch_size, target_len, self.out_dim)
694
684
 
695
- attn_output = self.out_proj(attn_output)
685
+ class ConditionalDetrMLP(nn.Module):
686
+ def __init__(self, config: ConditionalDetrConfig, hidden_size: int, intermediate_size: int):
687
+ super().__init__()
688
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
689
+ self.fc2 = nn.Linear(intermediate_size, hidden_size)
690
+ self.activation_fn = ACT2FN[config.activation_function]
691
+ self.activation_dropout = config.activation_dropout
692
+ self.dropout = config.dropout
696
693
 
697
- return attn_output, attn_weights_reshaped
694
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
695
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
696
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
697
+ hidden_states = self.fc2(hidden_states)
698
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
699
+ return hidden_states
698
700
 
699
701
 
700
- # Copied from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->ConditionalDetrEncoderLayer,DetrConfig->ConditionalDetrConfig
701
- class ConditionalDetrEncoderLayer(nn.Module):
702
+ class ConditionalDetrEncoderLayer(GradientCheckpointingLayer):
702
703
  def __init__(self, config: ConditionalDetrConfig):
703
704
  super().__init__()
704
- self.embed_dim = config.d_model
705
- self.self_attn = DetrAttention(
706
- embed_dim=self.embed_dim,
707
- num_heads=config.encoder_attention_heads,
705
+ self.hidden_size = config.d_model
706
+ self.self_attn = ConditionalDetrSelfAttention(
707
+ config=config,
708
+ hidden_size=self.hidden_size,
709
+ num_attention_heads=config.encoder_attention_heads,
708
710
  dropout=config.attention_dropout,
709
711
  )
710
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
712
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
711
713
  self.dropout = config.dropout
712
- self.activation_fn = ACT2FN[config.activation_function]
713
- self.activation_dropout = config.activation_dropout
714
- self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
715
- self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
716
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
714
+ self.mlp = ConditionalDetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
715
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size)
717
716
 
718
717
  def forward(
719
718
  self,
720
719
  hidden_states: torch.Tensor,
721
720
  attention_mask: torch.Tensor,
722
- object_queries: Optional[torch.Tensor] = None,
723
- output_attentions: bool = False,
724
- ):
721
+ spatial_position_embeddings: torch.Tensor | None = None,
722
+ **kwargs: Unpack[TransformersKwargs],
723
+ ) -> torch.Tensor:
725
724
  """
726
725
  Args:
727
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
726
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
728
727
  attention_mask (`torch.FloatTensor`): attention mask of size
729
728
  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
730
729
  values.
731
- object_queries (`torch.FloatTensor`, *optional*):
732
- Object queries (also called content embeddings), to be added to the hidden states.
733
- output_attentions (`bool`, *optional*):
734
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
735
- returned tensors for more detail.
730
+ spatial_position_embeddings (`torch.FloatTensor`, *optional*):
731
+ Spatial position embeddings (2D positional encodings of image locations), to be added to both
732
+ the queries and keys in self-attention (but not to values).
736
733
  """
737
734
  residual = hidden_states
738
- hidden_states, attn_weights = self.self_attn(
735
+ hidden_states, _ = self.self_attn(
739
736
  hidden_states=hidden_states,
740
737
  attention_mask=attention_mask,
741
- object_queries=object_queries,
742
- output_attentions=output_attentions,
738
+ position_embeddings=spatial_position_embeddings,
739
+ **kwargs,
743
740
  )
744
741
 
745
742
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -747,12 +744,7 @@ class ConditionalDetrEncoderLayer(nn.Module):
747
744
  hidden_states = self.self_attn_layer_norm(hidden_states)
748
745
 
749
746
  residual = hidden_states
750
- hidden_states = self.activation_fn(self.fc1(hidden_states))
751
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
752
-
753
- hidden_states = self.fc2(hidden_states)
754
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
755
-
747
+ hidden_states = self.mlp(hidden_states)
756
748
  hidden_states = residual + hidden_states
757
749
  hidden_states = self.final_layer_norm(hidden_states)
758
750
 
@@ -761,80 +753,55 @@ class ConditionalDetrEncoderLayer(nn.Module):
761
753
  clamp_value = torch.finfo(hidden_states.dtype).max - 1000
762
754
  hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
763
755
 
764
- outputs = (hidden_states,)
765
-
766
- if output_attentions:
767
- outputs += (attn_weights,)
768
-
769
- return outputs
756
+ return hidden_states
770
757
 
771
758
 
772
759
  class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
773
760
  def __init__(self, config: ConditionalDetrConfig):
774
761
  super().__init__()
775
- self.embed_dim = config.d_model
776
-
777
- d_model = config.d_model
778
- # Decoder Self-Attention projections
779
- self.sa_qcontent_proj = nn.Linear(d_model, d_model)
780
- self.sa_qpos_proj = nn.Linear(d_model, d_model)
781
- self.sa_kcontent_proj = nn.Linear(d_model, d_model)
782
- self.sa_kpos_proj = nn.Linear(d_model, d_model)
783
- self.sa_v_proj = nn.Linear(d_model, d_model)
784
-
785
- self.self_attn = ConditionalDetrAttention(
786
- embed_dim=self.embed_dim,
787
- out_dim=self.embed_dim,
788
- num_heads=config.decoder_attention_heads,
762
+ self.hidden_size = config.d_model
763
+ self.self_attn = ConditionalDetrDecoderSelfAttention(
764
+ config=config,
765
+ hidden_size=self.hidden_size,
766
+ num_attention_heads=config.decoder_attention_heads,
789
767
  dropout=config.attention_dropout,
790
768
  )
791
769
  self.dropout = config.dropout
792
- self.activation_fn = ACT2FN[config.activation_function]
793
- self.activation_dropout = config.activation_dropout
794
-
795
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
796
-
797
- # Decoder Cross-Attention projections
798
- self.ca_qcontent_proj = nn.Linear(d_model, d_model)
799
- self.ca_qpos_proj = nn.Linear(d_model, d_model)
800
- self.ca_kcontent_proj = nn.Linear(d_model, d_model)
801
- self.ca_kpos_proj = nn.Linear(d_model, d_model)
802
- self.ca_v_proj = nn.Linear(d_model, d_model)
803
- self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
804
770
 
805
- self.encoder_attn = ConditionalDetrAttention(
806
- self.embed_dim * 2, self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout
771
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
772
+ self.encoder_attn = ConditionalDetrDecoderCrossAttention(
773
+ config=config,
774
+ hidden_size=self.hidden_size,
775
+ num_attention_heads=config.decoder_attention_heads,
776
+ dropout=config.attention_dropout,
807
777
  )
808
- self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
809
- self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
810
- self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
811
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
812
- self.nhead = config.decoder_attention_heads
778
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
779
+ self.mlp = ConditionalDetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
780
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size)
813
781
 
814
782
  def forward(
815
783
  self,
816
784
  hidden_states: torch.Tensor,
817
- attention_mask: Optional[torch.Tensor] = None,
818
- object_queries: Optional[torch.Tensor] = None,
819
- query_position_embeddings: Optional[torch.Tensor] = None,
820
- query_sine_embed: Optional[torch.Tensor] = None,
821
- encoder_hidden_states: Optional[torch.Tensor] = None,
822
- encoder_attention_mask: Optional[torch.Tensor] = None,
823
- output_attentions: Optional[bool] = False,
824
- is_first: Optional[bool] = False,
825
- ):
785
+ attention_mask: torch.Tensor | None = None,
786
+ spatial_position_embeddings: torch.Tensor | None = None,
787
+ query_position_embeddings: torch.Tensor | None = None,
788
+ query_sine_embed: torch.Tensor | None = None,
789
+ encoder_hidden_states: torch.Tensor | None = None,
790
+ encoder_attention_mask: torch.Tensor | None = None,
791
+ is_first: bool | None = False,
792
+ **kwargs: Unpack[TransformersKwargs],
793
+ ) -> torch.Tensor:
826
794
  """
827
795
  Args:
828
796
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
829
797
  attention_mask (`torch.FloatTensor`): attention mask of size
830
798
  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
831
799
  values.
832
- object_queries (`torch.FloatTensor`, *optional*):
833
- object_queries that are added to the queries and keys
834
- in the cross-attention layer.
800
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
801
+ Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
835
802
  query_position_embeddings (`torch.FloatTensor`, *optional*):
836
803
  object_queries that are added to the queries and keys
837
- in the self-attention layer.
804
+ in the self-attention layer.
838
805
  encoder_hidden_states (`torch.FloatTensor`):
839
806
  cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
840
807
  encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
@@ -846,108 +813,49 @@ class ConditionalDetrDecoderLayer(GradientCheckpointingLayer):
846
813
  """
847
814
  residual = hidden_states
848
815
 
849
- # ========== Begin of Self-Attention =============
850
- # Apply projections here
851
- # shape: num_queries x batch_size x 256
852
- q_content = self.sa_qcontent_proj(
853
- hidden_states
854
- ) # target is the input of the first decoder layer. zero by default.
855
- q_pos = self.sa_qpos_proj(query_position_embeddings)
856
- k_content = self.sa_kcontent_proj(hidden_states)
857
- k_pos = self.sa_kpos_proj(query_position_embeddings)
858
- v = self.sa_v_proj(hidden_states)
859
-
860
- _, num_queries, n_model = q_content.shape
861
-
862
- q = q_content + q_pos
863
- k = k_content + k_pos
864
- hidden_states, self_attn_weights = self.self_attn(
865
- hidden_states=q,
816
+ hidden_states, _ = self.self_attn(
817
+ hidden_states=hidden_states,
818
+ query_position_embeddings=query_position_embeddings,
866
819
  attention_mask=attention_mask,
867
- key_states=k,
868
- value_states=v,
869
- output_attentions=output_attentions,
820
+ **kwargs,
870
821
  )
871
- # ============ End of Self-Attention =============
872
822
 
873
823
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
874
824
  hidden_states = residual + hidden_states
875
825
  hidden_states = self.self_attn_layer_norm(hidden_states)
876
826
 
877
- # ========== Begin of Cross-Attention =============
878
- # Apply projections here
879
- # shape: num_queries x batch_size x 256
880
- q_content = self.ca_qcontent_proj(hidden_states)
881
- k_content = self.ca_kcontent_proj(encoder_hidden_states)
882
- v = self.ca_v_proj(encoder_hidden_states)
883
-
884
- batch_size, num_queries, n_model = q_content.shape
885
- _, source_len, _ = k_content.shape
886
-
887
- k_pos = self.ca_kpos_proj(object_queries)
888
-
889
- # For the first decoder layer, we concatenate the positional embedding predicted from
890
- # the object query (the positional embedding) into the original query (key) in DETR.
891
- if is_first:
892
- q_pos = self.ca_qpos_proj(query_position_embeddings)
893
- q = q_content + q_pos
894
- k = k_content + k_pos
895
- else:
896
- q = q_content
897
- k = k_content
898
-
899
- q = q.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
900
- query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
901
- query_sine_embed = query_sine_embed.view(batch_size, num_queries, self.nhead, n_model // self.nhead)
902
- q = torch.cat([q, query_sine_embed], dim=3).view(batch_size, num_queries, n_model * 2)
903
- k = k.view(batch_size, source_len, self.nhead, n_model // self.nhead)
904
- k_pos = k_pos.view(batch_size, source_len, self.nhead, n_model // self.nhead)
905
- k = torch.cat([k, k_pos], dim=3).view(batch_size, source_len, n_model * 2)
906
-
907
- # Cross-Attention Block
908
- cross_attn_weights = None
909
827
  if encoder_hidden_states is not None:
910
828
  residual = hidden_states
911
829
 
912
- hidden_states, cross_attn_weights = self.encoder_attn(
913
- hidden_states=q,
830
+ hidden_states, _ = self.encoder_attn(
831
+ hidden_states=hidden_states,
832
+ encoder_hidden_states=encoder_hidden_states,
914
833
  attention_mask=encoder_attention_mask,
915
- key_states=k,
916
- value_states=v,
917
- output_attentions=output_attentions,
834
+ query_sine_embed=query_sine_embed,
835
+ encoder_position_embeddings=spatial_position_embeddings,
836
+ # Only pass query_position_embeddings for the first layer
837
+ query_position_embeddings=query_position_embeddings if is_first else None,
838
+ **kwargs,
918
839
  )
919
840
 
920
841
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
921
842
  hidden_states = residual + hidden_states
922
843
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
923
844
 
924
- # ============ End of Cross-Attention =============
925
-
926
845
  # Fully Connected
927
846
  residual = hidden_states
928
- hidden_states = self.activation_fn(self.fc1(hidden_states))
929
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
930
- hidden_states = self.fc2(hidden_states)
931
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
847
+ hidden_states = self.mlp(hidden_states)
932
848
  hidden_states = residual + hidden_states
933
849
  hidden_states = self.final_layer_norm(hidden_states)
934
850
 
935
- outputs = (hidden_states,)
851
+ return hidden_states
936
852
 
937
- if output_attentions:
938
- outputs += (self_attn_weights, cross_attn_weights)
939
853
 
940
- return outputs
941
-
942
-
943
- # Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with DetrMLPPredictionHead->MLP
944
- class MLP(nn.Module):
854
+ class ConditionalDetrMLPPredictionHead(nn.Module):
945
855
  """
946
856
  Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
947
857
  height and width of a bounding box w.r.t. an image.
948
858
 
949
- Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
950
-
951
859
  """
952
860
 
953
861
  def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
@@ -962,29 +870,202 @@ class MLP(nn.Module):
962
870
  return x
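The body of this prediction head is elided by the hunk; the classic DETR-style pattern it refers to (an approximation, not the exact hidden code) stacks num_layers - 1 Linear+ReLU blocks and a final Linear that emits the four normalized box coordinates.

import torch
from torch import nn
import torch.nn.functional as F


class MLPPredictionHead(nn.Module):
    # Sketch of the FFN box head: e.g. 256 -> 256 -> 256 -> 4 (cx, cy, w, h)
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(nn.Linear(a, b) for a, b in zip(dims[:-1], dims[1:]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x)
        return x


head = MLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
print(head(torch.randn(2, 100, 256)).sigmoid().shape)  # (2, 100, 4) normalized boxes per query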
963
871
 
964
872
 
873
+ class ConditionalDetrConvBlock(nn.Module):
874
+ """Basic conv block: Conv3x3 -> GroupNorm -> Activation."""
875
+
876
+ def __init__(self, in_channels: int, out_channels: int, activation: str = "relu"):
877
+ super().__init__()
878
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
879
+ self.norm = nn.GroupNorm(min(8, out_channels), out_channels)
880
+ self.activation = ACT2FN[activation]
881
+
882
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
883
+ return self.activation(self.norm(self.conv(x)))
884
+
885
+
886
+ class ConditionalDetrFPNFusionStage(nn.Module):
887
+ """Single FPN fusion stage combining low-resolution features with high-resolution FPN features."""
888
+
889
+ def __init__(self, fpn_channels: int, current_channels: int, output_channels: int, activation: str = "relu"):
890
+ super().__init__()
891
+ self.fpn_adapter = nn.Conv2d(fpn_channels, current_channels, kernel_size=1)
892
+ self.refine = ConditionalDetrConvBlock(current_channels, output_channels, activation)
893
+
894
+ def forward(self, features: torch.Tensor, fpn_features: torch.Tensor) -> torch.Tensor:
895
+ """
896
+ Args:
897
+ features: Current features to upsample, shape (B*Q, current_channels, H_in, W_in)
898
+ fpn_features: FPN features at target resolution, shape (B*Q, fpn_channels, H_out, W_out)
899
+
900
+ Returns:
901
+ Fused and refined features, shape (B*Q, output_channels, H_out, W_out)
902
+ """
903
+ fpn_features = self.fpn_adapter(fpn_features)
904
+ features = nn.functional.interpolate(features, size=fpn_features.shape[-2:], mode="nearest")
905
+ return self.refine(fpn_features + features)
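A standalone equivalent of one fusion step with illustrative channel counts: a 1x1 conv adapts the FPN features, the coarse features are upsampled to the FPN resolution, the two are added, and a conv block refines the sum.

import torch
from torch import nn

coarse = torch.randn(4, 128, 16, 16)  # (B*Q, current_channels, H_in, W_in)
fpn = torch.randn(4, 512, 32, 32)     # (B*Q, fpn_channels, H_out, W_out)

adapter = nn.Conv2d(512, 128, kernel_size=1)
refine = nn.Sequential(nn.Conv2d(128, 64, 3, padding=1), nn.GroupNorm(8, 64), nn.ReLU())

fpn = adapter(fpn)                                                        # (4, 128, 32, 32)
coarse = nn.functional.interpolate(coarse, size=fpn.shape[-2:], mode="nearest")
print(refine(fpn + coarse).shape)                                         # torch.Size([4, 64, 32, 32])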
906
+
907
+
908
+ class ConditionalDetrMaskHeadSmallConv(nn.Module):
909
+ """
910
+ Segmentation mask head that generates per-query masks using FPN-based progressive upsampling.
911
+
912
+ Combines attention maps (spatial localization) with encoder features (semantics) and progressively
913
+ upsamples through multiple scales, fusing with FPN features for high-resolution detail.
914
+ """
915
+
916
+ def __init__(
917
+ self,
918
+ input_channels: int,
919
+ fpn_channels: list[int],
920
+ hidden_size: int,
921
+ activation_function: str = "relu",
922
+ ):
923
+ super().__init__()
924
+ if input_channels % 8 != 0:
925
+ raise ValueError(f"input_channels must be divisible by 8, got {input_channels}")
926
+
927
+ self.conv1 = ConditionalDetrConvBlock(input_channels, input_channels, activation_function)
928
+ self.conv2 = ConditionalDetrConvBlock(input_channels, hidden_size // 2, activation_function)
929
+
930
+ # Progressive channel reduction: /2 -> /4 -> /8 -> /16
931
+ self.fpn_stages = nn.ModuleList(
932
+ [
933
+ ConditionalDetrFPNFusionStage(
934
+ fpn_channels[0], hidden_size // 2, hidden_size // 4, activation_function
935
+ ),
936
+ ConditionalDetrFPNFusionStage(
937
+ fpn_channels[1], hidden_size // 4, hidden_size // 8, activation_function
938
+ ),
939
+ ConditionalDetrFPNFusionStage(
940
+ fpn_channels[2], hidden_size // 8, hidden_size // 16, activation_function
941
+ ),
942
+ ]
943
+ )
944
+
945
+ self.output_conv = nn.Conv2d(hidden_size // 16, 1, kernel_size=3, padding=1)
946
+
947
+ def forward(
948
+ self,
949
+ features: torch.Tensor,
950
+ attention_masks: torch.Tensor,
951
+ fpn_features: list[torch.Tensor],
952
+ ) -> torch.Tensor:
953
+ """
954
+ Args:
955
+ features: Encoder output features, shape (batch_size, hidden_size, H, W)
956
+ attention_masks: Cross-attention maps from decoder, shape (batch_size, num_queries, num_heads, H, W)
957
+ fpn_features: List of 3 FPN features from low to high resolution, each (batch_size, C, H, W)
958
+
959
+ Returns:
960
+ Predicted masks, shape (batch_size * num_queries, 1, output_H, output_W)
961
+ """
962
+ num_queries = attention_masks.shape[1]
963
+
964
+ # Expand to (batch_size * num_queries) dimension
965
+ features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
966
+ attention_masks = attention_masks.flatten(0, 1)
967
+ fpn_features = [
968
+ fpn_feat.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1) for fpn_feat in fpn_features
969
+ ]
970
+
971
+ hidden_states = torch.cat([features, attention_masks], dim=1)
972
+ hidden_states = self.conv1(hidden_states)
973
+ hidden_states = self.conv2(hidden_states)
974
+
975
+ for fpn_stage, fpn_feat in zip(self.fpn_stages, fpn_features):
976
+ hidden_states = fpn_stage(hidden_states, fpn_feat)
977
+
978
+ return self.output_conv(hidden_states)
979
+
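A point worth illustrating in the forward above is the per-query broadcast: encoder features and FPN maps exist once per image, while attention maps exist once per query, so everything is expanded to a shared `(batch_size * num_queries)` leading dimension before concatenation. A small shape check with plain torch (all sizes illustrative):

```python
# Shape check for the per-query broadcast used in the forward above (all sizes illustrative).
import torch

batch_size, num_queries, hidden_size, num_heads, H, W = 2, 300, 256, 8, 25, 34

features = torch.randn(batch_size, hidden_size, H, W)                    # one per image
attention_masks = torch.randn(batch_size, num_queries, num_heads, H, W)  # one per query

# (B, C, H, W) -> (B, Q, C, H, W) -> (B*Q, C, H, W)
features = features.unsqueeze(1).expand(-1, num_queries, -1, -1, -1).flatten(0, 1)
# (B, Q, heads, H, W) -> (B*Q, heads, H, W)
attention_masks = attention_masks.flatten(0, 1)

stacked = torch.cat([features, attention_masks], dim=1)
print(stacked.shape)  # torch.Size([600, 264, 25, 34]); 264 = hidden_size + num_heads
```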
980
+
981
+ class ConditionalDetrMHAttentionMap(nn.Module):
982
+ """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
983
+
984
+ def __init__(
985
+ self,
986
+ hidden_size: int,
987
+ num_attention_heads: int,
988
+ dropout: float = 0.0,
989
+ bias: bool = True,
990
+ ):
991
+ super().__init__()
992
+ self.head_dim = hidden_size // num_attention_heads
993
+ self.scaling = self.head_dim**-0.5
994
+ self.attention_dropout = dropout
995
+
996
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
997
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
998
+
999
+ def forward(
1000
+ self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: torch.Tensor | None = None
1001
+ ):
1002
+ query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim)
1003
+ key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:])
1004
+
1005
+ query_states = self.q_proj(query_states).view(query_hidden_shape)
1006
+ key_states = nn.functional.conv2d(
1007
+ key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias
1008
+ ).view(key_hidden_shape)
1009
+
1010
+ batch_size, num_queries, num_heads, head_dim = query_states.shape
1011
+ _, _, _, height, width = key_states.shape
1012
+ query_shape = (batch_size * num_heads, num_queries, head_dim)
1013
+ key_shape = (batch_size * num_heads, height * width, head_dim)
1014
+ attn_weights_shape = (batch_size, num_heads, num_queries, height, width)
1015
+
1016
+ query = query_states.transpose(1, 2).contiguous().view(query_shape)
1017
+ key = key_states.permute(0, 1, 3, 4, 2).contiguous().view(key_shape)
1018
+
1019
+ attn_weights = (
1020
+ (torch.matmul(query * self.scaling, key.transpose(1, 2))).view(attn_weights_shape).transpose(1, 2)
1021
+ )
1022
+
1023
+ if attention_mask is not None:
1024
+ attn_weights = attn_weights + attention_mask
1025
+
1026
+ attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size())
1027
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
1028
+
1029
+ return attn_weights
1030
+
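The key projection above reuses the `k_proj` weights as a 1x1 convolution so the linear projection can be applied directly to the `(batch, channels, height, width)` feature map without reshaping. A quick sanity check that the two formulations agree (plain torch, illustrative sizes):

```python
# Sanity check: a 1x1 convolution built from nn.Linear's weights equals applying that
# Linear at every spatial location (plain torch, illustrative sizes).
import torch
from torch import nn

hidden_size, H, W = 256, 7, 9
k_proj = nn.Linear(hidden_size, hidden_size, bias=True)
feature_map = torch.randn(1, hidden_size, H, W)

# As in the forward above: reuse the Linear weights as a (out, in, 1, 1) conv kernel
via_conv = nn.functional.conv2d(feature_map, k_proj.weight.unsqueeze(-1).unsqueeze(-1), k_proj.bias)

# Reference: channels-last, apply the Linear, channels-first again
via_linear = k_proj(feature_map.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(torch.allclose(via_conv, via_linear, atol=1e-5))  # True
```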
1031
+
965
1032
  @auto_docstring
966
- # Copied from transformers.models.detr.modeling_detr.DetrPreTrainedModel with Detr->ConditionalDetr
967
1033
  class ConditionalDetrPreTrainedModel(PreTrainedModel):
968
1034
  config: ConditionalDetrConfig
969
1035
  base_model_prefix = "model"
970
1036
  main_input_name = "pixel_values"
971
1037
  input_modalities = ("image",)
972
1038
  _no_split_modules = [r"ConditionalDetrConvEncoder", r"ConditionalDetrEncoderLayer", r"ConditionalDetrDecoderLayer"]
1039
+ supports_gradient_checkpointing = True
1040
+ _supports_sdpa = True
1041
+ _supports_flash_attn = True
1042
+ _supports_attention_backend = True
1043
+ _supports_flex_attn = True # Uses create_bidirectional_mask for attention masking
1044
+ _keys_to_ignore_on_load_unexpected = [
1045
+ r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
1046
+ ]
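The ignore pattern above appears intended to silence the `num_batches_tracked` buffers of the backbone's downsample BatchNorm layers when older checkpoints still carry them. A quick look at which keys it matches (the key names below are illustrative):

```python
# Quick look at what the ignore pattern above matches (the key names below are illustrative).
import re

pattern = r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
keys = [
    "detr.model.backbone.model.layer2.0.downsample.1.num_batches_tracked",  # ignored
    "detr.model.backbone.model.layer2.0.downsample.1.running_mean",         # kept
]
for key in keys:
    print(key, "->", bool(re.search(pattern, key)))
```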
973
1047
 
974
1048
  @torch.no_grad()
975
1049
  def _init_weights(self, module):
976
1050
  std = self.config.init_std
977
1051
  xavier_std = self.config.init_xavier_std
978
1052
 
979
- if isinstance(module, ConditionalDetrMHAttentionMap):
980
- init.zeros_(module.k_linear.bias)
981
- init.zeros_(module.q_linear.bias)
982
- init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
983
- init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
1053
+ if isinstance(module, ConditionalDetrMaskHeadSmallConv):
1054
+ # ConditionalDetrMaskHeadSmallConv uses kaiming initialization for all its Conv2d layers
1055
+ for m in module.modules():
1056
+ if isinstance(m, nn.Conv2d):
1057
+ init.kaiming_uniform_(m.weight, a=1)
1058
+ if m.bias is not None:
1059
+ init.constant_(m.bias, 0)
1060
+ elif isinstance(module, ConditionalDetrMHAttentionMap):
1061
+ init.zeros_(module.k_proj.bias)
1062
+ init.zeros_(module.q_proj.bias)
1063
+ init.xavier_uniform_(module.k_proj.weight, gain=xavier_std)
1064
+ init.xavier_uniform_(module.q_proj.weight, gain=xavier_std)
984
1065
  elif isinstance(module, ConditionalDetrLearnedPositionEmbedding):
985
1066
  init.uniform_(module.row_embeddings.weight)
986
1067
  init.uniform_(module.column_embeddings.weight)
987
- if isinstance(module, (nn.Linear, nn.Conv2d)):
1068
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
988
1069
  init.normal_(module.weight, mean=0.0, std=std)
989
1070
  if module.bias is not None:
990
1071
  init.zeros_(module.bias)
@@ -998,50 +1079,38 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
998
1079
  init.zeros_(module.bias)
999
1080
 
1000
1081
 
1001
- # Copied from transformers.models.detr.modeling_detr.DetrEncoder with Detr->ConditionalDetr,DETR->ConditionalDETR
1002
1082
  class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
1003
1083
  """
1004
- Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
1005
- [`ConditionalDetrEncoderLayer`].
1006
-
1007
- The encoder updates the flattened feature map through multiple self-attention layers.
1008
-
1009
- Small tweak for ConditionalDETR:
1010
-
1011
- - object_queries are added to the forward pass.
1084
+ Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of
1085
+ [`ConditionalDetrEncoderLayer`] modules.
1012
1086
 
1013
1087
  Args:
1014
- config: ConditionalDetrConfig
1088
+ config (`ConditionalDetrConfig`): Model configuration object.
1015
1089
  """
1016
1090
 
1091
+ _can_record_outputs = {"hidden_states": ConditionalDetrEncoderLayer, "attentions": ConditionalDetrSelfAttention}
1092
+
1017
1093
  def __init__(self, config: ConditionalDetrConfig):
1018
1094
  super().__init__(config)
1019
1095
 
1020
1096
  self.dropout = config.dropout
1021
- self.layerdrop = config.encoder_layerdrop
1022
-
1023
1097
  self.layers = nn.ModuleList([ConditionalDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
1024
1098
 
1025
- # in the original ConditionalDETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default
1026
-
1027
1099
  # Initialize weights and apply final processing
1028
1100
  self.post_init()
1029
1101
 
1102
+ @check_model_inputs()
1030
1103
  def forward(
1031
1104
  self,
1032
1105
  inputs_embeds=None,
1033
1106
  attention_mask=None,
1034
- object_queries=None,
1035
- output_attentions=None,
1036
- output_hidden_states=None,
1037
- return_dict=None,
1038
- **kwargs,
1039
- ):
1107
+ spatial_position_embeddings=None,
1108
+ **kwargs: Unpack[TransformersKwargs],
1109
+ ) -> BaseModelOutput:
1040
1110
  r"""
1041
1111
  Args:
1042
1112
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1043
1113
  Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
1044
-
1045
1114
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1046
1115
  Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
1047
1116
 
@@ -1049,69 +1118,44 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
1049
1118
  - 0 for pixel features that are padding (i.e. **masked**).
1050
1119
 
1051
1120
  [What are attention masks?](../glossary#attention-mask)
1052
-
1053
- object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1054
- Object queries that are added to the queries in each self-attention layer.
1055
-
1056
- output_attentions (`bool`, *optional*):
1057
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1058
- returned tensors for more detail.
1059
- output_hidden_states (`bool`, *optional*):
1060
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1061
- for more detail.
1062
- return_dict (`bool`, *optional*):
1063
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1121
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1122
+ Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
1064
1123
  """
1065
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1066
- output_hidden_states = (
1067
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1068
- )
1069
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1070
-
1071
1124
  hidden_states = inputs_embeds
1072
1125
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1073
1126
 
1074
1127
  # expand attention_mask
1075
1128
  if attention_mask is not None:
1076
1129
  # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
1077
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
1078
-
1079
- encoder_states = () if output_hidden_states else None
1080
- all_attentions = () if output_attentions else None
1081
- for i, encoder_layer in enumerate(self.layers):
1082
- if output_hidden_states:
1083
- encoder_states = encoder_states + (hidden_states,)
1084
- # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
1085
- to_drop = False
1086
- if self.training:
1087
- dropout_probability = torch.rand([])
1088
- if dropout_probability < self.layerdrop: # skip the layer
1089
- to_drop = True
1130
+ attention_mask = create_bidirectional_mask(
1131
+ config=self.config,
1132
+ input_embeds=inputs_embeds,
1133
+ attention_mask=attention_mask,
1134
+ )
1090
1135
 
1091
- if to_drop:
1092
- layer_outputs = (None, None)
1093
- else:
1094
- # we add object_queries as extra input to the encoder_layer
1095
- layer_outputs = encoder_layer(
1096
- hidden_states,
1097
- attention_mask,
1098
- object_queries=object_queries,
1099
- output_attentions=output_attentions,
1100
- )
1101
-
1102
- hidden_states = layer_outputs[0]
1103
-
1104
- if output_attentions:
1105
- all_attentions = all_attentions + (layer_outputs[1],)
1106
-
1107
- if output_hidden_states:
1108
- encoder_states = encoder_states + (hidden_states,)
1109
-
1110
- if not return_dict:
1111
- return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
1112
- return BaseModelOutput(
1113
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
1114
- )
1136
+ for encoder_layer in self.layers:
1137
+ # we add spatial_position_embeddings as extra input to the encoder_layer
1138
+ hidden_states = encoder_layer(
1139
+ hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs
1140
+ )
1141
+
1142
+ return BaseModelOutput(last_hidden_state=hidden_states)
1143
+
1144
+
1145
+ # Function to generate sine positional embeddings for 2D coordinates
1146
+ def gen_sine_position_embeddings(pos_tensor, d_model):
1147
+ scale = 2 * math.pi
1148
+ dim = d_model // 2
1149
+ dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
1150
+ dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
1151
+ x_embed = pos_tensor[:, :, 0] * scale
1152
+ y_embed = pos_tensor[:, :, 1] * scale
1153
+ pos_x = x_embed[:, :, None] / dim_t
1154
+ pos_y = y_embed[:, :, None] / dim_t
1155
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
1156
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
1157
+ pos = torch.cat((pos_y, pos_x), dim=2)
1158
+ return pos.to(pos_tensor.dtype)
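`gen_sine_position_embeddings` turns normalized (x, y) box centers into a `d_model`-dimensional sine/cosine embedding, with the y-based half concatenated before the x-based half. A shape sketch, assuming the function defined directly above and plain torch:

```python
# Shape sketch: normalized (x, y) centers -> d_model-dimensional embeddings.
# Assumes gen_sine_position_embeddings as defined directly above.
import torch

batch_size, num_queries, d_model = 2, 300, 256
obj_center = torch.rand(batch_size, num_queries, 2)  # normalized cx, cy in [0, 1]

pos = gen_sine_position_embeddings(obj_center, d_model)
print(pos.shape)  # torch.Size([2, 300, 256]); first half built from y, second half from x
```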
1115
1159
 
1116
1160
 
1117
1161
  class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
@@ -1129,39 +1173,44 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
1129
1173
  config: ConditionalDetrConfig
1130
1174
  """
1131
1175
 
1176
+ _can_record_outputs = {
1177
+ "hidden_states": ConditionalDetrDecoderLayer,
1178
+ "attentions": OutputRecorder(ConditionalDetrDecoderSelfAttention, layer_name="self_attn", index=1),
1179
+ "cross_attentions": OutputRecorder(ConditionalDetrDecoderCrossAttention, layer_name="encoder_attn", index=1),
1180
+ }
1181
+
1132
1182
  def __init__(self, config: ConditionalDetrConfig):
1133
1183
  super().__init__(config)
1184
+ self.hidden_size = config.d_model
1185
+
1134
1186
  self.dropout = config.dropout
1135
1187
  self.layerdrop = config.decoder_layerdrop
1136
1188
 
1137
1189
  self.layers = nn.ModuleList([ConditionalDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
1138
1190
  # in Conditional DETR, the decoder uses layernorm after the last decoder layer output
1139
1191
  self.layernorm = nn.LayerNorm(config.d_model)
1140
- d_model = config.d_model
1141
- self.gradient_checkpointing = False
1142
1192
 
1143
1193
  # query_scale is the FFN applied on f to generate transformation T
1144
- self.query_scale = MLP(d_model, d_model, d_model, 2)
1145
- self.ref_point_head = MLP(d_model, d_model, 2, 2)
1194
+ self.query_scale = ConditionalDetrMLPPredictionHead(self.hidden_size, self.hidden_size, self.hidden_size, 2)
1195
+ self.ref_point_head = ConditionalDetrMLPPredictionHead(self.hidden_size, self.hidden_size, 2, 2)
1146
1196
  for layer_id in range(config.decoder_layers - 1):
1147
- self.layers[layer_id + 1].ca_qpos_proj = None
1197
+ # Set q_pos_proj to None for layers after the first (only first layer uses query position embeddings)
1198
+ self.layers[layer_id + 1].encoder_attn.q_pos_proj = None
1148
1199
 
1149
1200
  # Initialize weights and apply final processing
1150
1201
  self.post_init()
1151
1202
 
1203
+ @check_model_inputs()
1152
1204
  def forward(
1153
1205
  self,
1154
1206
  inputs_embeds=None,
1155
1207
  attention_mask=None,
1156
1208
  encoder_hidden_states=None,
1157
1209
  encoder_attention_mask=None,
1158
- object_queries=None,
1159
- query_position_embeddings=None,
1160
- output_attentions=None,
1161
- output_hidden_states=None,
1162
- return_dict=None,
1163
- **kwargs,
1164
- ):
1210
+ spatial_position_embeddings=None,
1211
+ object_queries_position_embeddings=None,
1212
+ **kwargs: Unpack[TransformersKwargs],
1213
+ ) -> ConditionalDetrDecoderOutput:
1165
1214
  r"""
1166
1215
  Args:
1167
1216
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -1184,46 +1233,28 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
1184
1233
  - 1 for pixels that are real (i.e. **not masked**),
1185
1234
  - 0 for pixels that are padding (i.e. **masked**).
1186
1235
 
1187
- object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1188
- Position embeddings that are added to the queries and keys in each cross-attention layer.
1189
- query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
1236
+ spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1237
+ Spatial position embeddings that are added to the queries and keys in each cross-attention layer.
1238
+ object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
1190
1239
  , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
1191
- output_attentions (`bool`, *optional*):
1192
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1193
- returned tensors for more detail.
1194
- output_hidden_states (`bool`, *optional*):
1195
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1196
- for more detail.
1197
- return_dict (`bool`, *optional*):
1198
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1199
1240
  """
1200
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1201
- output_hidden_states = (
1202
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1203
- )
1204
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1205
-
1206
1241
  if inputs_embeds is not None:
1207
1242
  hidden_states = inputs_embeds
1208
- input_shape = inputs_embeds.size()[:-1]
1209
1243
 
1210
1244
  # expand encoder attention mask
1211
1245
  if encoder_hidden_states is not None and encoder_attention_mask is not None:
1212
1246
  # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
1213
- encoder_attention_mask = _prepare_4d_attention_mask(
1214
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1247
+ encoder_attention_mask = create_bidirectional_mask(
1248
+ self.config,
1249
+ inputs_embeds,
1250
+ encoder_attention_mask,
1215
1251
  )
1216
1252
 
1217
1253
  # optional intermediate hidden states
1218
1254
  intermediate = () if self.config.auxiliary_loss else None
1219
1255
 
1220
- # decoder layers
1221
- all_hidden_states = () if output_hidden_states else None
1222
- all_self_attns = () if output_attentions else None
1223
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1224
-
1225
1256
  reference_points_before_sigmoid = self.ref_point_head(
1226
- query_position_embeddings
1257
+ object_queries_position_embeddings
1227
1258
  ) # [num_queries, batch_size, 2]
1228
1259
  reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
1229
1260
  obj_center = reference_points[..., :2].transpose(0, 1)
@@ -1231,9 +1262,6 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
1231
1262
  query_sine_embed_before_transformation = gen_sine_position_embeddings(obj_center, self.config.d_model)
1232
1263
 
1233
1264
  for idx, decoder_layer in enumerate(self.layers):
1234
- # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
1235
- if output_hidden_states:
1236
- all_hidden_states += (hidden_states,)
1237
1265
  if self.training:
1238
1266
  dropout_probability = torch.rand([])
1239
1267
  if dropout_probability < self.layerdrop:
@@ -1245,59 +1273,31 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
1245
1273
  # apply transformation
1246
1274
  query_sine_embed = query_sine_embed_before_transformation * pos_transformation
1247
1275
 
1248
- layer_outputs = decoder_layer(
1276
+ hidden_states = decoder_layer(
1249
1277
  hidden_states,
1250
- None, # attention_mask
1251
- object_queries,
1252
- query_position_embeddings,
1278
+ None,
1279
+ spatial_position_embeddings,
1280
+ object_queries_position_embeddings,
1253
1281
  query_sine_embed,
1254
1282
  encoder_hidden_states, # as a positional argument for gradient checkpointing
1255
1283
  encoder_attention_mask=encoder_attention_mask,
1256
- output_attentions=output_attentions,
1257
1284
  is_first=(idx == 0),
1285
+ **kwargs,
1258
1286
  )
1259
1287
 
1260
- hidden_states = layer_outputs[0]
1261
-
1262
1288
  if self.config.auxiliary_loss:
1263
1289
  hidden_states = self.layernorm(hidden_states)
1264
1290
  intermediate += (hidden_states,)
1265
1291
 
1266
- if output_attentions:
1267
- all_self_attns += (layer_outputs[1],)
1268
-
1269
- if encoder_hidden_states is not None:
1270
- all_cross_attentions += (layer_outputs[2],)
1271
-
1272
1292
  # finally, apply layernorm
1273
1293
  hidden_states = self.layernorm(hidden_states)
1274
1294
 
1275
- # add hidden states from the last decoder layer
1276
- if output_hidden_states:
1277
- all_hidden_states += (hidden_states,)
1278
-
1279
1295
  # stack intermediate decoder activations
1280
1296
  if self.config.auxiliary_loss:
1281
1297
  intermediate = torch.stack(intermediate)
1282
1298
 
1283
- if not return_dict:
1284
- return tuple(
1285
- v
1286
- for v in [
1287
- hidden_states,
1288
- all_hidden_states,
1289
- all_self_attns,
1290
- all_cross_attentions,
1291
- intermediate,
1292
- reference_points,
1293
- ]
1294
- if v is not None
1295
- )
1296
1299
  return ConditionalDetrDecoderOutput(
1297
1300
  last_hidden_state=hidden_states,
1298
- hidden_states=all_hidden_states,
1299
- attentions=all_self_attns,
1300
- cross_attentions=all_cross_attentions,
1301
1301
  intermediate_hidden_states=intermediate,
1302
1302
  reference_points=reference_points,
1303
1303
  )
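The decoder forward above is where Conditional DETR builds its conditional spatial query: reference points are predicted from the object query position embeddings, their sine embedding is computed once, and each layer rescales it with `query_scale` applied to the current hidden states (the reference implementation uses an identity transformation at the first layer, hence the `is_first` flag). A conceptual sketch, illustrative only:

```python
# Conceptual sketch only: ref_point_head and query_scale below are two-layer MLP stand-ins
# for the ConditionalDetrMLPPredictionHead instances above, and the sine embedding is a
# random placeholder for gen_sine_position_embeddings(reference_points, d_model).
import torch
from torch import nn

d_model, batch_size, num_queries = 256, 2, 300

ref_point_head = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 2))
query_scale = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, d_model))

object_queries_position_embeddings = torch.randn(batch_size, num_queries, d_model)
hidden_states = torch.randn(batch_size, num_queries, d_model)

reference_points = ref_point_head(object_queries_position_embeddings).sigmoid()  # (B, Q, 2) centers in [0, 1]
query_sine_embed = torch.randn(batch_size, num_queries, d_model)                 # placeholder sine embedding

is_first = False
pos_transformation = 1 if is_first else query_scale(hidden_states)
conditional_query = query_sine_embed * pos_transformation  # what the cross-attention layer consumes
print(reference_points.shape, conditional_query.shape)     # (2, 300, 2) and (2, 300, 256)
```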
@@ -1305,23 +1305,24 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
1305
1305
 
1306
1306
  @auto_docstring(
1307
1307
  custom_intro="""
1308
- The bare Conditional DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw
1309
- hidden-states without any specific head on top.
1308
+ The bare CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without
1309
+ any specific head on top.
1310
1310
  """
1311
1311
  )
1312
1312
  class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1313
1313
  def __init__(self, config: ConditionalDetrConfig):
1314
1314
  super().__init__(config)
1315
1315
 
1316
- # Create backbone + positional encoding
1317
- backbone = ConditionalDetrConvEncoder(config)
1318
- object_queries = build_position_encoding(config)
1319
- self.backbone = ConditionalDetrConvModel(backbone, object_queries)
1320
-
1321
- # Create projection layer
1322
- self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
1316
+ self.backbone = ConditionalDetrConvEncoder(config)
1323
1317
 
1318
+ if config.position_embedding_type == "sine":
1319
+ self.position_embedding = ConditionalDetrSinePositionEmbedding(config.d_model // 2, normalize=True)
1320
+ elif config.position_embedding_type == "learned":
1321
+ self.position_embedding = ConditionalDetrLearnedPositionEmbedding(config.d_model // 2)
1322
+ else:
1323
+ raise ValueError(f"Not supported {config.position_embedding_type}")
1324
1324
  self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
1325
+ self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
1325
1326
 
1326
1327
  self.encoder = ConditionalDetrEncoder(config)
1327
1328
  self.decoder = ConditionalDetrDecoder(config)
@@ -1330,27 +1331,25 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1330
1331
  self.post_init()
1331
1332
 
1332
1333
  def freeze_backbone(self):
1333
- for name, param in self.backbone.conv_encoder.model.named_parameters():
1334
+ for _, param in self.backbone.model.named_parameters():
1334
1335
  param.requires_grad_(False)
1335
1336
 
1336
1337
  def unfreeze_backbone(self):
1337
- for name, param in self.backbone.conv_encoder.model.named_parameters():
1338
+ for _, param in self.backbone.model.named_parameters():
1338
1339
  param.requires_grad_(True)
1339
1340
 
1340
1341
  @auto_docstring
1342
+ @can_return_tuple
1341
1343
  def forward(
1342
1344
  self,
1343
1345
  pixel_values: torch.FloatTensor,
1344
- pixel_mask: Optional[torch.LongTensor] = None,
1345
- decoder_attention_mask: Optional[torch.LongTensor] = None,
1346
- encoder_outputs: Optional[torch.FloatTensor] = None,
1347
- inputs_embeds: Optional[torch.FloatTensor] = None,
1348
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1349
- output_attentions: Optional[bool] = None,
1350
- output_hidden_states: Optional[bool] = None,
1351
- return_dict: Optional[bool] = None,
1352
- **kwargs,
1353
- ) -> Union[tuple[torch.FloatTensor], ConditionalDetrModelOutput]:
1346
+ pixel_mask: torch.LongTensor | None = None,
1347
+ decoder_attention_mask: torch.LongTensor | None = None,
1348
+ encoder_outputs: torch.FloatTensor | None = None,
1349
+ inputs_embeds: torch.FloatTensor | None = None,
1350
+ decoder_inputs_embeds: torch.FloatTensor | None = None,
1351
+ **kwargs: Unpack[TransformersKwargs],
1352
+ ) -> ConditionalDetrModelOutput:
1354
1353
  r"""
1355
1354
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1356
1355
  Not used by default. Can be used to mask object queries.
@@ -1386,12 +1385,6 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1386
1385
  >>> list(last_hidden_states.shape)
1387
1386
  [1, 300, 256]
1388
1387
  ```"""
1389
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1390
- output_hidden_states = (
1391
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1392
- )
1393
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1394
-
1395
1388
  batch_size, num_channels, height, width = pixel_values.shape
1396
1389
  device = pixel_values.device
1397
1390
 
@@ -1401,7 +1394,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1401
1394
  # First, send pixel_values + pixel_mask through Backbone to obtain the features
1402
1395
  # pixel_values should be of shape (batch_size, num_channels, height, width)
1403
1396
  # pixel_mask should be of shape (batch_size, height, width)
1404
- features, object_queries_list = self.backbone(pixel_values, pixel_mask)
1397
+ features = self.backbone(pixel_values, pixel_mask)
1405
1398
 
1406
1399
  # get final feature map and downsampled mask
1407
1400
  feature_map, mask = features[-1]
@@ -1412,53 +1405,52 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1412
1405
  # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
1413
1406
  projected_feature_map = self.input_projection(feature_map)
1414
1407
 
1415
- # Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1408
+ # Generate position embeddings
1409
+ spatial_position_embeddings = self.position_embedding(
1410
+ shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
1411
+ )
1412
+
1413
+ # Third, flatten the feature map of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1416
1414
  # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
1417
1415
  flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
1418
- object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
1419
1416
 
1420
1417
  flattened_mask = mask.flatten(1)
1421
1418
 
1422
- # Fourth, sent flattened_features + flattened_mask + object_queries through encoder
1419
+ # Fourth, sent flattened_features + flattened_mask + spatial_position_embeddings through encoder
1423
1420
  # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
1424
1421
  # flattened_mask is a Tensor of shape (batch_size, height*width)
1425
1422
  if encoder_outputs is None:
1426
1423
  encoder_outputs = self.encoder(
1427
1424
  inputs_embeds=flattened_features,
1428
1425
  attention_mask=flattened_mask,
1429
- object_queries=object_queries,
1430
- output_attentions=output_attentions,
1431
- output_hidden_states=output_hidden_states,
1432
- return_dict=return_dict,
1426
+ spatial_position_embeddings=spatial_position_embeddings,
1427
+ **kwargs,
1433
1428
  )
1434
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1435
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1429
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
1430
+ elif not isinstance(encoder_outputs, BaseModelOutput):
1436
1431
  encoder_outputs = BaseModelOutput(
1437
1432
  last_hidden_state=encoder_outputs[0],
1438
1433
  hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1439
1434
  attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1440
1435
  )
1441
1436
 
1442
- # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
1443
- query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
1444
- queries = torch.zeros_like(query_position_embeddings)
1437
+ # Fifth, sent query embeddings through the decoder (which is conditioned on the encoder output)
1438
+ object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(
1439
+ batch_size, 1, 1
1440
+ )
1441
+ queries = torch.zeros_like(object_queries_position_embeddings)
1445
1442
 
1446
1443
  # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
1447
1444
  decoder_outputs = self.decoder(
1448
1445
  inputs_embeds=queries,
1449
1446
  attention_mask=None,
1450
- object_queries=object_queries,
1451
- query_position_embeddings=query_position_embeddings,
1452
- encoder_hidden_states=encoder_outputs[0],
1447
+ spatial_position_embeddings=spatial_position_embeddings,
1448
+ object_queries_position_embeddings=object_queries_position_embeddings,
1449
+ encoder_hidden_states=encoder_outputs.last_hidden_state,
1453
1450
  encoder_attention_mask=flattened_mask,
1454
- output_attentions=output_attentions,
1455
- output_hidden_states=output_hidden_states,
1456
- return_dict=return_dict,
1451
+ **kwargs,
1457
1452
  )
1458
1453
 
1459
- if not return_dict:
1460
- return decoder_outputs + encoder_outputs
1461
-
1462
1454
  return ConditionalDetrModelOutput(
1463
1455
  last_hidden_state=decoder_outputs.last_hidden_state,
1464
1456
  decoder_hidden_states=decoder_outputs.hidden_states,
@@ -1472,45 +1464,26 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
1472
1464
  )
1473
1465
 
1474
1466
 
1475
- # Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->ConditionalDetr
1476
- class ConditionalDetrMLPPredictionHead(nn.Module):
1477
- """
1478
- Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
1479
- height and width of a bounding box w.r.t. an image.
1480
-
1481
- Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
1482
-
1483
- """
1484
-
1485
- def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
1486
- super().__init__()
1487
- self.num_layers = num_layers
1488
- h = [hidden_dim] * (num_layers - 1)
1489
- self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
1490
-
1491
- def forward(self, x):
1492
- for i, layer in enumerate(self.layers):
1493
- x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
1494
- return x
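The class removed above is the plain MLP prediction head that the decoder and the bounding-box predictor still rely on (the refactored file defines it elsewhere): `num_layers` linear layers with ReLU in between and no activation on the last one. A minimal equivalent sketch, not the transformers class:

```python
# Minimal equivalent sketch of that head (not the transformers class): Linear layers with
# ReLU in between and no activation on the final layer.
import torch
from torch import nn


class MLPPredictionHead(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(nn.Linear(a, b) for a, b in zip(dims[:-1], dims[1:]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                x = nn.functional.relu(x)
        return x


# Same configuration as the bbox_predictor below: d_model -> d_model -> d_model -> 4
bbox_predictor = MLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
print(bbox_predictor(torch.randn(2, 300, 256)).sigmoid().shape)  # torch.Size([2, 300, 4])
```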
1467
+ def inverse_sigmoid(x, eps=1e-5):
1468
+ x = x.clamp(min=0, max=1)
1469
+ x1 = x.clamp(min=eps)
1470
+ x2 = (1 - x).clamp(min=eps)
1471
+ return torch.log(x1 / x2)
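`inverse_sigmoid` is a clamped logit: it maps probabilities back to unbounded space so reference-point coordinates can be combined with predicted box deltas before a final sigmoid. A quick check, assuming the function defined just above (torch.logit is used only for comparison):

```python
# Quick check, using the inverse_sigmoid defined above: it is a clamped logit, and it
# round-trips through sigmoid for values away from the boundaries.
import torch

p = torch.tensor([0.0, 0.25, 0.5, 0.9, 1.0])
unbounded = inverse_sigmoid(p)
print(torch.allclose(unbounded, torch.logit(p, eps=1e-5), atol=1e-4))  # True (boundary clamping differs slightly)
print(torch.allclose(torch.sigmoid(unbounded[1:4]), p[1:4]))           # True
```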
1495
1472
 
1496
1473
 
1497
1474
  @auto_docstring(
1498
1475
  custom_intro="""
1499
- Conditional DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on
1500
- top, for tasks such as COCO detection.
1476
+ CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) with object detection heads on top, for tasks
1477
+ such as COCO detection.
1501
1478
  """
1502
1479
  )
1503
1480
  class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1504
1481
  def __init__(self, config: ConditionalDetrConfig):
1505
1482
  super().__init__(config)
1506
1483
 
1507
- # CONDITIONAL DETR encoder-decoder model
1484
+ # CONDITIONAL_DETR encoder-decoder model
1508
1485
  self.model = ConditionalDetrModel(config)
1509
-
1510
- # Object detection heads
1511
- self.class_labels_classifier = nn.Linear(
1512
- config.d_model, config.num_labels
1513
- ) # We add one for the "no object" class
1486
+ self.class_labels_classifier = nn.Linear(config.d_model, config.num_labels)
1514
1487
  self.bbox_predictor = ConditionalDetrMLPPredictionHead(
1515
1488
  input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
1516
1489
  )
@@ -1518,25 +1491,19 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1518
1491
  # Initialize weights and apply final processing
1519
1492
  self.post_init()
1520
1493
 
1521
- # taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
1522
- def _set_aux_loss(self, outputs_class, outputs_coord):
1523
- return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
1524
-
1525
1494
  @auto_docstring
1495
+ @can_return_tuple
1526
1496
  def forward(
1527
1497
  self,
1528
1498
  pixel_values: torch.FloatTensor,
1529
- pixel_mask: Optional[torch.LongTensor] = None,
1530
- decoder_attention_mask: Optional[torch.LongTensor] = None,
1531
- encoder_outputs: Optional[torch.FloatTensor] = None,
1532
- inputs_embeds: Optional[torch.FloatTensor] = None,
1533
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1534
- labels: Optional[list[dict]] = None,
1535
- output_attentions: Optional[bool] = None,
1536
- output_hidden_states: Optional[bool] = None,
1537
- return_dict: Optional[bool] = None,
1538
- **kwargs,
1539
- ) -> Union[tuple[torch.FloatTensor], ConditionalDetrObjectDetectionOutput]:
1499
+ pixel_mask: torch.LongTensor | None = None,
1500
+ decoder_attention_mask: torch.LongTensor | None = None,
1501
+ encoder_outputs: torch.FloatTensor | None = None,
1502
+ inputs_embeds: torch.FloatTensor | None = None,
1503
+ decoder_inputs_embeds: torch.FloatTensor | None = None,
1504
+ labels: list[dict] | None = None,
1505
+ **kwargs: Unpack[TransformersKwargs],
1506
+ ) -> ConditionalDetrObjectDetectionOutput:
1540
1507
  r"""
1541
1508
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1542
1509
  Not used by default. Can be used to mask object queries.
@@ -1586,8 +1553,6 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1586
1553
  Detected remote with confidence 0.683 at location [334.48, 73.49, 366.37, 190.01]
1587
1554
  Detected couch with confidence 0.535 at location [0.52, 1.19, 640.35, 475.1]
1588
1555
  ```"""
1589
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1590
-
1591
1556
  # First, send images through CONDITIONAL_DETR base model to obtain encoder + decoder outputs
1592
1557
  outputs = self.model(
1593
1558
  pixel_values,
@@ -1596,9 +1561,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1596
1561
  encoder_outputs=encoder_outputs,
1597
1562
  inputs_embeds=inputs_embeds,
1598
1563
  decoder_inputs_embeds=decoder_inputs_embeds,
1599
- output_attentions=output_attentions,
1600
- output_hidden_states=output_hidden_states,
1601
- return_dict=return_dict,
1564
+ **kwargs,
1602
1565
  )
1603
1566
 
1604
1567
  sequence_output = outputs[0]
@@ -1606,11 +1569,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1606
1569
  # class logits + predicted bounding boxes
1607
1570
  logits = self.class_labels_classifier(sequence_output)
1608
1571
 
1609
- # Index [-2] is valid only if `output_attentions` and `output_hidden_states`
1610
- # are not specified, otherwise it will be another index which is hard to determine.
1611
- # Leave it as is, because it's not a common case to use
1612
- # return_dict=False + output_attentions=True / output_hidden_states=True
1613
- reference = outputs.reference_points if return_dict else outputs[-2]
1572
+ reference = outputs.reference_points
1614
1573
  reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
1615
1574
 
1616
1575
  hs = sequence_output
@@ -1624,7 +1583,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1624
1583
  outputs_class, outputs_coord = None, None
1625
1584
  if self.config.auxiliary_loss:
1626
1585
  outputs_coords = []
1627
- intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
1586
+ intermediate = outputs.intermediate_hidden_states
1628
1587
  outputs_class = self.class_labels_classifier(intermediate)
1629
1588
  for lvl in range(intermediate.shape[0]):
1630
1589
  tmp = self.bbox_predictor(intermediate[lvl])
@@ -1636,13 +1595,6 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1636
1595
  logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
1637
1596
  )
1638
1597
 
1639
- if not return_dict:
1640
- if auxiliary_outputs is not None:
1641
- output = (logits, pred_boxes) + auxiliary_outputs + outputs
1642
- else:
1643
- output = (logits, pred_boxes) + outputs
1644
- return ((loss, loss_dict) + output) if loss is not None else output
1645
-
1646
1598
  return ConditionalDetrObjectDetectionOutput(
1647
1599
  loss=loss,
1648
1600
  loss_dict=loss_dict,
@@ -1658,14 +1610,38 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
1658
1610
  encoder_attentions=outputs.encoder_attentions,
1659
1611
  )
1660
1612
 
1613
+ # taken from https://github.com/Atten4Vis/conditionalDETR/blob/master/models/conditional_detr.py
1614
+ def _set_aux_loss(self, outputs_class, outputs_coord):
1615
+ return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
1616
+
1661
1617
 
1662
1618
  @auto_docstring(
1663
1619
  custom_intro="""
1664
- Conditional DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top,
1665
- for tasks such as COCO panoptic.
1620
+ CONDITIONAL_DETR Model (consisting of a backbone and encoder-decoder Transformer) with a segmentation head on top, for tasks
1621
+ such as COCO panoptic.
1666
1622
  """
1667
1623
  )
1668
1624
  class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
1625
+ _checkpoint_conversion_mapping = {
1626
+ "bbox_attention.q_linear": "bbox_attention.q_proj",
1627
+ "bbox_attention.k_linear": "bbox_attention.k_proj",
1628
+ # Mask head refactor
1629
+ "mask_head.lay1": "mask_head.conv1.conv",
1630
+ "mask_head.gn1": "mask_head.conv1.norm",
1631
+ "mask_head.lay2": "mask_head.conv2.conv",
1632
+ "mask_head.gn2": "mask_head.conv2.norm",
1633
+ "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter",
1634
+ "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv",
1635
+ "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm",
1636
+ "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
1637
+ "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv",
1638
+ "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm",
1639
+ "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter",
1640
+ "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv",
1641
+ "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm",
1642
+ "mask_head.out_lay": "mask_head.output_conv",
1643
+ }
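Each entry in `_checkpoint_conversion_mapping` rewrites an old checkpoint parameter prefix to the refactored module path (the flat `lay*`/`gn*`/`adapter*` layers now live under `conv1`, `conv2` and `fpn_stages`). An illustrative sketch of what such a prefix mapping accomplishes on state-dict keys; this is not the library's actual loading code:

```python
# Illustrative only: what a prefix mapping like the one above does to state-dict keys.
# This is NOT the library's loading code.
mapping = {
    "bbox_attention.q_linear": "bbox_attention.q_proj",
    "mask_head.lay1": "mask_head.conv1.conv",
    "mask_head.gn1": "mask_head.conv1.norm",
    "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter",
}

def rename(key: str) -> str:
    for old, new in mapping.items():
        if key.startswith(old):
            return new + key[len(old):]
    return key

old_keys = [
    "mask_head.lay1.weight",
    "mask_head.gn1.bias",
    "mask_head.adapter2.weight",
    "bbox_attention.q_linear.weight",
]
print([rename(k) for k in old_keys])
# ['mask_head.conv1.conv.weight', 'mask_head.conv1.norm.bias',
#  'mask_head.fpn_stages.1.fpn_adapter.weight', 'bbox_attention.q_proj.weight']
```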
1644
+
1669
1645
  def __init__(self, config: ConditionalDetrConfig):
1670
1646
  super().__init__(config)
1671
1647
 
@@ -1674,43 +1650,44 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
1674
1650
 
1675
1651
  # segmentation head
1676
1652
  hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads
1677
- intermediate_channel_sizes = self.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes
1653
+ intermediate_channel_sizes = self.conditional_detr.model.backbone.intermediate_channel_sizes
1678
1654
 
1679
1655
  self.mask_head = ConditionalDetrMaskHeadSmallConv(
1680
- hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size
1681
- )
1682
-
1683
- self.bbox_attention = ConditionalDetrMHAttentionMap(
1684
- hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std
1656
+ input_channels=hidden_size + number_of_heads,
1657
+ fpn_channels=intermediate_channel_sizes[::-1][-3:],
1658
+ hidden_size=hidden_size,
1659
+ activation_function=config.activation_function,
1685
1660
  )
1686
1661
 
1662
+ self.bbox_attention = ConditionalDetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0)
1687
1663
  # Initialize weights and apply final processing
1688
1664
  self.post_init()
1689
1665
 
1690
1666
  @auto_docstring
1667
+ @can_return_tuple
1691
1668
  def forward(
1692
1669
  self,
1693
1670
  pixel_values: torch.FloatTensor,
1694
- pixel_mask: Optional[torch.LongTensor] = None,
1695
- decoder_attention_mask: Optional[torch.FloatTensor] = None,
1696
- encoder_outputs: Optional[torch.FloatTensor] = None,
1697
- inputs_embeds: Optional[torch.FloatTensor] = None,
1698
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1699
- labels: Optional[list[dict]] = None,
1700
- output_attentions: Optional[bool] = None,
1701
- output_hidden_states: Optional[bool] = None,
1702
- return_dict: Optional[bool] = None,
1703
- **kwargs,
1704
- ) -> Union[tuple[torch.FloatTensor], ConditionalDetrSegmentationOutput]:
1671
+ pixel_mask: torch.LongTensor | None = None,
1672
+ decoder_attention_mask: torch.FloatTensor | None = None,
1673
+ encoder_outputs: torch.FloatTensor | None = None,
1674
+ inputs_embeds: torch.FloatTensor | None = None,
1675
+ decoder_inputs_embeds: torch.FloatTensor | None = None,
1676
+ labels: list[dict] | None = None,
1677
+ **kwargs: Unpack[TransformersKwargs],
1678
+ ) -> tuple[torch.FloatTensor] | ConditionalDetrSegmentationOutput:
1705
1679
  r"""
1706
1680
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
1707
- Not used by default. Can be used to mask object queries.
1681
+ Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`:
1682
+
1683
+ - 1 for queries that are **not masked**,
1684
+ - 0 for queries that are **masked**.
1708
1685
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1709
- Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1710
- can choose to directly pass a flattened representation of an image.
1686
+ Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires
1687
+ multi-scale features from the backbone that are not available when bypassing it with inputs_embeds.
1711
1688
  decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1712
1689
  Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1713
- embedded representation.
1690
+ embedded representation. Useful for tasks that require custom query initialization.
1714
1691
  labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1715
1692
  Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each
1716
1693
  dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels,
@@ -1723,26 +1700,21 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
1723
1700
 
1724
1701
  ```python
1725
1702
  >>> import io
1726
- >>> import requests
1703
+ >>> import httpx
1704
+ >>> from io import BytesIO
1727
1705
  >>> from PIL import Image
1728
1706
  >>> import torch
1729
1707
  >>> import numpy
1730
1708
 
1731
- >>> from transformers import (
1732
- ... AutoImageProcessor,
1733
- ... ConditionalDetrConfig,
1734
- ... ConditionalDetrForSegmentation,
1735
- ... )
1709
+ >>> from transformers import AutoImageProcessor, ConditionalDetrForSegmentation
1736
1710
  >>> from transformers.image_transforms import rgb_to_id
1737
1711
 
1738
1712
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1739
- >>> image = Image.open(requests.get(url, stream=True).raw)
1740
-
1741
- >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
1713
+ >>> with httpx.stream("GET", url) as response:
1714
+ ... image = Image.open(BytesIO(response.read()))
1742
1715
 
1743
- >>> # randomly initialize all weights of the model
1744
- >>> config = ConditionalDetrConfig()
1745
- >>> model = ConditionalDetrForSegmentation(config)
1716
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/conditional_detr-resnet-50-panoptic")
1717
+ >>> model = ConditionalDetrForSegmentation.from_pretrained("facebook/conditional_detr-resnet-50-panoptic")
1746
1718
 
1747
1719
  >>> # prepare image for the model
1748
1720
  >>> inputs = image_processor(images=image, return_tensors="pt")
@@ -1753,89 +1725,88 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
1753
1725
  >>> # Use the `post_process_panoptic_segmentation` method of the `image_processor` to retrieve post-processed panoptic segmentation maps
1754
1726
  >>> # Segmentation results are returned as a list of dictionaries
1755
1727
  >>> result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[(300, 500)])
1728
+
1756
1729
  >>> # A tensor of shape (height, width) where each value denotes a segment id, filled with -1 if no segment is found
1757
1730
  >>> panoptic_seg = result[0]["segmentation"]
1731
+ >>> panoptic_seg.shape
1732
+ torch.Size([300, 500])
1758
1733
  >>> # Get prediction score and segment_id to class_id mapping of each segment
1759
1734
  >>> panoptic_segments_info = result[0]["segments_info"]
1735
+ >>> len(panoptic_segments_info)
1736
+ 5
1760
1737
  ```"""
1761
1738
 
1762
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1763
-
1764
1739
  batch_size, num_channels, height, width = pixel_values.shape
1765
1740
  device = pixel_values.device
1766
1741
 
1767
1742
  if pixel_mask is None:
1768
1743
  pixel_mask = torch.ones((batch_size, height, width), device=device)
1769
1744
 
1770
- # First, get list of feature maps and object_queries
1771
- features, object_queries_list = self.conditional_detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
1745
+ vision_features = self.conditional_detr.model.backbone(pixel_values, pixel_mask)
1746
+ feature_map, mask = vision_features[-1]
1772
1747
 
1773
- # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
1774
- feature_map, mask = features[-1]
1775
- batch_size, num_channels, height, width = feature_map.shape
1748
+ # Apply 1x1 conv to map (batch_size, C, H, W) -> (batch_size, hidden_size, H, W), then flatten to (batch_size, HW, hidden_size)
1776
1749
  projected_feature_map = self.conditional_detr.model.input_projection(feature_map)
1777
-
1778
- # Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
1779
- # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
1780
1750
  flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
1781
- object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
1782
-
1751
+ spatial_position_embeddings = self.conditional_detr.model.position_embedding(
1752
+ shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask
1753
+ )
1783
1754
  flattened_mask = mask.flatten(1)
1784
1755
 
1785
- # Fourth, sent flattened_features + flattened_mask + object_queries through encoder
1786
- # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
1787
- # flattened_mask is a Tensor of shape (batch_size, height*width)
1788
1756
  if encoder_outputs is None:
1789
1757
  encoder_outputs = self.conditional_detr.model.encoder(
1790
1758
  inputs_embeds=flattened_features,
1791
1759
  attention_mask=flattened_mask,
1792
- object_queries=object_queries,
1793
- output_attentions=output_attentions,
1794
- output_hidden_states=output_hidden_states,
1795
- return_dict=return_dict,
1796
- )
1797
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1798
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1799
- encoder_outputs = BaseModelOutput(
1800
- last_hidden_state=encoder_outputs[0],
1801
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1802
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1760
+ spatial_position_embeddings=spatial_position_embeddings,
1761
+ **kwargs,
1803
1762
  )
1804
1763
 
1805
- # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
1806
- query_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
1807
- batch_size, 1, 1
1808
- )
1809
- queries = torch.zeros_like(query_position_embeddings)
1764
+ object_queries_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(
1765
+ 0
1766
+ ).repeat(batch_size, 1, 1)
1767
+
1768
+ # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros
1769
+ if decoder_inputs_embeds is not None:
1770
+ queries = decoder_inputs_embeds
1771
+ else:
1772
+ queries = torch.zeros_like(object_queries_position_embeddings)
1810
1773
 
1811
- # decoder outputs consists of (dec_features, dec_hidden, dec_attn)
1812
1774
  decoder_outputs = self.conditional_detr.model.decoder(
1813
1775
  inputs_embeds=queries,
1814
- attention_mask=None,
1815
- object_queries=object_queries,
1816
- query_position_embeddings=query_position_embeddings,
1817
- encoder_hidden_states=encoder_outputs[0],
1776
+ attention_mask=decoder_attention_mask,
1777
+ spatial_position_embeddings=spatial_position_embeddings,
1778
+ object_queries_position_embeddings=object_queries_position_embeddings,
1779
+ encoder_hidden_states=encoder_outputs.last_hidden_state,
1818
1780
  encoder_attention_mask=flattened_mask,
1819
- output_attentions=output_attentions,
1820
- output_hidden_states=output_hidden_states,
1821
- return_dict=return_dict,
1781
+ **kwargs,
1822
1782
  )
1823
1783
 
1824
1784
  sequence_output = decoder_outputs[0]
1825
1785
 
1826
- # Sixth, compute logits, pred_boxes and pred_masks
1827
1786
  logits = self.conditional_detr.class_labels_classifier(sequence_output)
1828
1787
  pred_boxes = self.conditional_detr.bbox_predictor(sequence_output).sigmoid()
1829
1788
 
1830
- memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width)
1831
- mask = flattened_mask.view(batch_size, height, width)
1789
+ height, width = feature_map.shape[-2:]
1790
+ memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view(
1791
+ batch_size, self.config.d_model, height, width
1792
+ )
1793
+ attention_mask = flattened_mask.view(batch_size, height, width)
1832
1794
 
1833
- # FIXME h_boxes takes the last one computed, keep this in mind
1834
- # important: we need to reverse the mask, since in the original implementation the mask works reversed
1835
- # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32)
1836
- bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask)
1795
+ if attention_mask is not None:
1796
+ min_dtype = torch.finfo(memory.dtype).min
1797
+ attention_mask = torch.where(
1798
+ attention_mask.unsqueeze(1).unsqueeze(1),
1799
+ torch.tensor(0.0, device=memory.device, dtype=memory.dtype),
1800
+ min_dtype,
1801
+ )
1837
1802
 
1838
- seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]])
1803
+ bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=attention_mask)
1804
+
1805
+ seg_masks = self.mask_head(
1806
+ features=projected_feature_map,
1807
+ attention_masks=bbox_mask,
1808
+ fpn_features=[vision_features[2][0], vision_features[1][0], vision_features[0][0]],
1809
+ )
1839
1810
 
1840
1811
  pred_masks = seg_masks.view(
1841
1812
  batch_size, self.conditional_detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]
@@ -1845,20 +1816,13 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
1845
1816
  if labels is not None:
1846
1817
  outputs_class, outputs_coord = None, None
1847
1818
  if self.config.auxiliary_loss:
1848
- intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1]
1819
+ intermediate = decoder_outputs.intermediate_hidden_states
1849
1820
  outputs_class = self.conditional_detr.class_labels_classifier(intermediate)
1850
1821
  outputs_coord = self.conditional_detr.bbox_predictor(intermediate).sigmoid()
1851
1822
  loss, loss_dict, auxiliary_outputs = self.loss_function(
1852
- logits, labels, self.device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
1823
+ logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord
1853
1824
  )
1854
1825
 
1855
- if not return_dict:
1856
- if auxiliary_outputs is not None:
1857
- output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs
1858
- else:
1859
- output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs
1860
- return ((loss, loss_dict) + output) if loss is not None else output
1861
-
1862
1826
  return ConditionalDetrSegmentationOutput(
1863
1827
  loss=loss,
1864
1828
  loss_dict=loss_dict,
@@ -1876,120 +1840,6 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
1876
1840
  )
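In the segmentation forward above, the boolean pixel mask is converted into an additive attention mask before being handed to `bbox_attention`: kept positions contribute 0.0 to the attention scores, padded positions get the dtype's most negative value so they vanish after the softmax. A minimal sketch of that conversion (sizes illustrative):

```python
# Minimal sketch of that conversion (sizes illustrative): True positions stay attendable
# (0.0 added to the scores), False positions get the most negative representable value.
import torch

batch_size, height, width = 2, 4, 5
dtype = torch.float32

pixel_level_mask = torch.ones(batch_size, height, width, dtype=torch.bool)
pixel_level_mask[:, :, -2:] = False  # pretend the right edge of each image is padding

min_dtype = torch.finfo(dtype).min
additive_mask = torch.where(
    pixel_level_mask.unsqueeze(1).unsqueeze(1),  # (B, 1, 1, H, W), broadcast over queries/heads
    torch.tensor(0.0, dtype=dtype),
    min_dtype,
)
print(additive_mask.shape)        # torch.Size([2, 1, 1, 4, 5])
print(additive_mask[0, 0, 0, 0])  # 0.0 for kept pixels, ~-3.4e38 for the two padded ones
```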
1877
1841
 
1878
1842
 
1879
- def _expand(tensor, length: int):
1880
- return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
1881
-
1882
-
1883
- # Copied from transformers.models.detr.modeling_detr.DetrMaskHeadSmallConv with Detr->ConditionalDetr
1884
- class ConditionalDetrMaskHeadSmallConv(nn.Module):
1885
- """
1886
- Simple convolutional head, using group norm. Upsampling is done using a FPN approach
1887
- """
1888
-
1889
- def __init__(self, dim, fpn_dims, context_dim):
1890
- super().__init__()
1891
-
1892
- if dim % 8 != 0:
1893
- raise ValueError(
1894
- "The hidden_size + number of attention heads must be divisible by 8 as the number of groups in"
1895
- " GroupNorm is set to 8"
1896
- )
1897
-
1898
- inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
1899
-
1900
- self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
1901
- self.gn1 = nn.GroupNorm(8, dim)
1902
- self.lay2 = nn.Conv2d(dim, inter_dims[1], 3, padding=1)
1903
- self.gn2 = nn.GroupNorm(min(8, inter_dims[1]), inter_dims[1])
1904
- self.lay3 = nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
1905
- self.gn3 = nn.GroupNorm(min(8, inter_dims[2]), inter_dims[2])
1906
- self.lay4 = nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
1907
- self.gn4 = nn.GroupNorm(min(8, inter_dims[3]), inter_dims[3])
1908
- self.lay5 = nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
1909
- self.gn5 = nn.GroupNorm(min(8, inter_dims[4]), inter_dims[4])
1910
- self.out_lay = nn.Conv2d(inter_dims[4], 1, 3, padding=1)
1911
-
1912
- self.dim = dim
1913
-
1914
- self.adapter1 = nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
1915
- self.adapter2 = nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
1916
- self.adapter3 = nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
1917
-
1918
- for m in self.modules():
1919
- if isinstance(m, nn.Conv2d):
1920
- init.kaiming_uniform_(m.weight, a=1)
1921
- init.constant_(m.bias, 0)
1922
-
1923
- def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]):
1924
- # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with
1925
- # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32).
1926
- # We expand the projected feature map to match the number of heads.
1927
- x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
1928
-
1929
- x = self.lay1(x)
1930
- x = self.gn1(x)
1931
- x = nn.functional.relu(x)
1932
- x = self.lay2(x)
1933
- x = self.gn2(x)
1934
- x = nn.functional.relu(x)
1935
-
1936
- cur_fpn = self.adapter1(fpns[0])
1937
- if cur_fpn.size(0) != x.size(0):
1938
- cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1939
- x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1940
- x = self.lay3(x)
1941
- x = self.gn3(x)
1942
- x = nn.functional.relu(x)
1943
-
1944
- cur_fpn = self.adapter2(fpns[1])
1945
- if cur_fpn.size(0) != x.size(0):
1946
- cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1947
- x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1948
- x = self.lay4(x)
1949
- x = self.gn4(x)
1950
- x = nn.functional.relu(x)
1951
-
1952
- cur_fpn = self.adapter3(fpns[2])
1953
- if cur_fpn.size(0) != x.size(0):
1954
- cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
1955
- x = cur_fpn + nn.functional.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
1956
- x = self.lay5(x)
1957
- x = self.gn5(x)
1958
- x = nn.functional.relu(x)
1959
-
1960
- x = self.out_lay(x)
1961
- return x
1962
-
1963
-
1964
- # Copied from transformers.models.detr.modeling_detr.DetrMHAttentionMap with Detr->ConditionalDetr
1965
- class ConditionalDetrMHAttentionMap(nn.Module):
1966
- """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
1967
-
1968
- def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None):
1969
- super().__init__()
1970
- self.num_heads = num_heads
1971
- self.hidden_dim = hidden_dim
1972
- self.dropout = nn.Dropout(dropout)
1973
-
1974
- self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
1975
- self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
1976
-
1977
- self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
1978
-
1979
- def forward(self, q, k, mask: Optional[Tensor] = None):
1980
- q = self.q_linear(q)
1981
- k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
1982
- queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
1983
- keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
1984
- weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
1985
-
1986
- if mask is not None:
1987
- weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
1988
- weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
1989
- weights = self.dropout(weights)
1990
- return weights
1991
-
1992
-
1993
1843
  __all__ = [
1994
1844
  "ConditionalDetrForObjectDetection",
1995
1845
  "ConditionalDetrForSegmentation",