transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
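For readers who want to reproduce a summary like the listing below, here is a minimal Python sketch, not the registry's actual diff tool, that unpacks two locally downloaded wheels (e.g. fetched with `pip download transformers==5.1.0 --no-deps`) and counts added/removed lines per `.py` file using the standard library's zipfile and difflib. The wheel paths at the bottom are hypothetical local filenames.

import difflib
import zipfile


def wheel_texts(path: str) -> dict[str, list[str]]:
    # Map each .py member of the wheel to its decoded source lines.
    out = {}
    with zipfile.ZipFile(path) as zf:
        for name in zf.namelist():
            if name.endswith(".py"):
                out[name] = zf.read(name).decode("utf-8", "replace").splitlines()
    return out


def summarize(old_whl: str, new_whl: str) -> None:
    # Print "path +added -removed" for every .py file that differs,
    # roughly mirroring the "files changed" listing below.
    old, new = wheel_texts(old_whl), wheel_texts(new_whl)
    for name in sorted(old.keys() | new.keys()):
        added = removed = 0
        for line in difflib.unified_diff(old.get(name, []), new.get(name, []), lineterm=""):
            # Skip the "+++" / "---" file headers; count real +/- lines.
            if line.startswith("+") and not line.startswith("+++"):
                added += 1
            elif line.startswith("-") and not line.startswith("---"):
                removed += 1
        if added or removed:
            print(f"{name} +{added} -{removed}")


if __name__ == "__main__":
    summarize(
        "transformers-5.0.0rc2-py3-none-any.whl",  # hypothetical local paths
        "transformers-5.1.0-py3-none-any.whl",
    )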
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -1,4 +1,9 @@
-# coding=utf-8
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/deformable_detr/modular_deformable_detr.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_deformable_detr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # Copyright 2022 SenseTime and The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
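A note on the banner added above: under the modular-transformers setup, this checked-in modeling file is generated code, so hand edits are rejected by CI. Changes belong in the companion `modular_deformable_detr.py`, which typically declares the model by inheriting from another model's classes; the converter script in the transformers repo then unrolls that inheritance into the flat file this diff shows. A hypothetical sketch of the pattern (the class name and the inheritance source are illustrative assumptions, not taken from this diff):

```python
# modular_deformable_detr.py (sketch): declare-by-inheritance, then let the
# converter inline the parent's code into the generated modeling_deformable_detr.py.
from transformers.models.detr.modeling_detr import DetrMLPPredictionHead  # assumed parent model


class DeformableDetrMLPPredictionHead(DetrMLPPredictionHead):  # hypothetical; the real modular file may differ
    pass  # the generated file receives a full, flattened copy of the parent class
```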
@@ -12,128 +17,54 @@
12
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
18
  # See the License for the specific language governing permissions and
14
19
  # limitations under the License.
15
- """PyTorch Deformable DETR model."""
16
-
17
20
  import math
18
21
  import warnings
22
+ from collections.abc import Callable
19
23
  from dataclasses import dataclass
20
- from typing import Any, Optional, Union
21
24
 
22
25
  import torch
26
+ import torch.nn as nn
23
27
  import torch.nn.functional as F
24
- from torch import Tensor, nn
28
+ from torch import Tensor
25
29
 
26
30
  from ... import initialization as init
27
31
  from ...activations import ACT2FN
32
+ from ...backbone_utils import load_backbone
28
33
  from ...integrations import use_kernel_forward_from_hub
29
- from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
30
34
  from ...modeling_layers import GradientCheckpointingLayer
31
- from ...modeling_outputs import BaseModelOutput
32
- from ...modeling_utils import PreTrainedModel
33
- from ...pytorch_utils import meshgrid
34
- from ...utils import (
35
- ModelOutput,
36
- auto_docstring,
37
- is_timm_available,
38
- logging,
39
- requires_backends,
40
- )
41
- from ...utils.backbone_utils import load_backbone
35
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
36
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
+ from ...processing_utils import Unpack
38
+ from ...pytorch_utils import compile_compatible_method_lru_cache, meshgrid
39
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check
40
+ from ...utils.generic import OutputRecorder, can_return_tuple, check_model_inputs
42
41
  from .configuration_deformable_detr import DeformableDetrConfig
43
42
 
44
43
 
45
- logger = logging.get_logger(__name__)
46
-
47
-
48
- if is_timm_available():
49
- from timm import create_model
50
-
51
-
52
- logger = logging.get_logger(__name__)
53
-
54
-
55
- @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
56
- class MultiScaleDeformableAttention(nn.Module):
57
- def forward(
58
- self,
59
- value: Tensor,
60
- value_spatial_shapes: Tensor,
61
- value_spatial_shapes_list: list[tuple],
62
- level_start_index: Tensor,
63
- sampling_locations: Tensor,
64
- attention_weights: Tensor,
65
- im2col_step: int,
66
- ):
67
- batch_size, _, num_heads, hidden_dim = value.shape
68
- _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
69
- value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
70
- sampling_grids = 2 * sampling_locations - 1
71
- sampling_value_list = []
72
- for level_id, (height, width) in enumerate(value_spatial_shapes_list):
73
- # batch_size, height*width, num_heads, hidden_dim
74
- # -> batch_size, height*width, num_heads*hidden_dim
75
- # -> batch_size, num_heads*hidden_dim, height*width
76
- # -> batch_size*num_heads, hidden_dim, height, width
77
- value_l_ = (
78
- value_list[level_id]
79
- .flatten(2)
80
- .transpose(1, 2)
81
- .reshape(batch_size * num_heads, hidden_dim, height, width)
82
- )
83
- # batch_size, num_queries, num_heads, num_points, 2
84
- # -> batch_size, num_heads, num_queries, num_points, 2
85
- # -> batch_size*num_heads, num_queries, num_points, 2
86
- sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
87
- # batch_size*num_heads, hidden_dim, num_queries, num_points
88
- sampling_value_l_ = nn.functional.grid_sample(
89
- value_l_,
90
- sampling_grid_l_,
91
- mode="bilinear",
92
- padding_mode="zeros",
93
- align_corners=False,
94
- )
95
- sampling_value_list.append(sampling_value_l_)
96
- # (batch_size, num_queries, num_heads, num_levels, num_points)
97
- # -> (batch_size, num_heads, num_queries, num_levels, num_points)
98
- # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
99
- attention_weights = attention_weights.transpose(1, 2).reshape(
100
- batch_size * num_heads, 1, num_queries, num_levels * num_points
101
- )
102
- output = (
103
- (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
104
- .sum(-1)
105
- .view(batch_size, num_heads * hidden_dim, num_queries)
106
- )
107
- return output.transpose(1, 2).contiguous()
108
-
109
-
110
44
  @dataclass
111
45
  @auto_docstring(
112
46
  custom_intro="""
113
- Base class for outputs of the DeformableDetrDecoder. This class adds two attributes to
114
- BaseModelOutputWithCrossAttentions, namely:
115
- - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
116
- - a stacked tensor of intermediate reference points.
47
+ Base class for outputs of the DEFORMABLE_DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
48
+ namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
49
+ gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
117
50
  """
118
51
  )
119
- class DeformableDetrDecoderOutput(ModelOutput):
+ class DeformableDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
      r"""
-     intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
-         Stacked intermediate hidden states (output of each layer of the decoder).
-     intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
-         Stacked intermediate reference points (reference points of each layer of the decoder).
      cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
          used to compute the weighted average in the cross-attention heads.
+     intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
+         Intermediate decoder activations, i.e. the output of each decoder layer, each of which has gone through a
+         layernorm.
+     intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
+         Stacked intermediate reference points (reference points of each layer of the decoder).
      """

-     last_hidden_state: Optional[torch.FloatTensor] = None
-     intermediate_hidden_states: Optional[torch.FloatTensor] = None
-     intermediate_reference_points: Optional[torch.FloatTensor] = None
-     hidden_states: Optional[tuple[torch.FloatTensor]] = None
-     attentions: Optional[tuple[torch.FloatTensor]] = None
-     cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+     intermediate_hidden_states: torch.FloatTensor | None = None
+
+     intermediate_reference_points: torch.FloatTensor | None = None


  @dataclass
@@ -160,18 +91,18 @@ class DeformableDetrModelOutput(ModelOutput):
          Logits of predicted bounding boxes coordinates in the first stage.
      """

-     init_reference_points: Optional[torch.FloatTensor] = None
-     last_hidden_state: Optional[torch.FloatTensor] = None
-     intermediate_hidden_states: Optional[torch.FloatTensor] = None
-     intermediate_reference_points: Optional[torch.FloatTensor] = None
-     decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
-     decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
-     cross_attentions: Optional[tuple[torch.FloatTensor]] = None
-     encoder_last_hidden_state: Optional[torch.FloatTensor] = None
-     encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
-     encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
-     enc_outputs_class: Optional[torch.FloatTensor] = None
-     enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+     init_reference_points: torch.FloatTensor | None = None
+     last_hidden_state: torch.FloatTensor | None = None
+     intermediate_hidden_states: torch.FloatTensor | None = None
+     intermediate_reference_points: torch.FloatTensor | None = None
+     decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+     decoder_attentions: tuple[torch.FloatTensor] | None = None
+     cross_attentions: tuple[torch.FloatTensor] | None = None
+     encoder_last_hidden_state: torch.FloatTensor | None = None
+     encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+     encoder_attentions: tuple[torch.FloatTensor] | None = None
+     enc_outputs_class: torch.FloatTensor | None = None
+     enc_outputs_coord_logits: torch.FloatTensor | None = None


  @dataclass
@@ -199,10 +130,10 @@ class DeformableDetrObjectDetectionOutput(ModelOutput):
          Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
          and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
          `pred_boxes`) for each decoder layer.
-     init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
-         Initial reference points sent through the Transformer decoder.
      last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
          Sequence of hidden-states at the output of the last layer of the decoder of the model.
+     init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+         Initial reference points sent through the Transformer decoder.
      intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
          Stacked intermediate hidden states (output of each layer of the decoder).
      intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
@@ -215,33 +146,81 @@ class DeformableDetrObjectDetectionOutput(ModelOutput):
          Logits of predicted bounding boxes coordinates in the first stage.
      """

-     loss: Optional[torch.FloatTensor] = None
-     loss_dict: Optional[dict] = None
-     logits: Optional[torch.FloatTensor] = None
-     pred_boxes: Optional[torch.FloatTensor] = None
-     auxiliary_outputs: Optional[list[dict]] = None
-     init_reference_points: Optional[torch.FloatTensor] = None
-     last_hidden_state: Optional[torch.FloatTensor] = None
-     intermediate_hidden_states: Optional[torch.FloatTensor] = None
-     intermediate_reference_points: Optional[torch.FloatTensor] = None
-     decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
-     decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
-     cross_attentions: Optional[tuple[torch.FloatTensor]] = None
-     encoder_last_hidden_state: Optional[torch.FloatTensor] = None
-     encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
-     encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
-     enc_outputs_class: Any = None
-     enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
+     loss: torch.FloatTensor | None = None
+     loss_dict: dict | None = None
+     logits: torch.FloatTensor | None = None
+     pred_boxes: torch.FloatTensor | None = None
+     auxiliary_outputs: list[dict] | None = None
+     last_hidden_state: torch.FloatTensor | None = None
+     decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+     decoder_attentions: tuple[torch.FloatTensor] | None = None
+     cross_attentions: tuple[torch.FloatTensor] | None = None
+     encoder_last_hidden_state: torch.FloatTensor | None = None
+     encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+     encoder_attentions: tuple[torch.FloatTensor] | None = None
+
+     init_reference_points: torch.FloatTensor | None = None
+     intermediate_hidden_states: torch.FloatTensor | None = None
+     intermediate_reference_points: torch.FloatTensor | None = None
+     enc_outputs_class: torch.FloatTensor | None = None
+     enc_outputs_coord_logits: torch.FloatTensor | None = None


- def inverse_sigmoid(x, eps=1e-5):
-     x = x.clamp(min=0, max=1)
-     x1 = x.clamp(min=eps)
-     x2 = (1 - x).clamp(min=eps)
-     return torch.log(x1 / x2)
+ @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+ class MultiScaleDeformableAttention(nn.Module):
+     def forward(
+         self,
+         value: Tensor,
+         value_spatial_shapes: Tensor,
+         value_spatial_shapes_list: list[tuple],
+         level_start_index: Tensor,
+         sampling_locations: Tensor,
+         attention_weights: Tensor,
+         im2col_step: int,
+     ):
+         batch_size, _, num_heads, hidden_dim = value.shape
+         _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+         value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
+         sampling_grids = 2 * sampling_locations - 1
+         sampling_value_list = []
+         for level_id, (height, width) in enumerate(value_spatial_shapes_list):
+             # batch_size, height*width, num_heads, hidden_dim
+             # -> batch_size, height*width, num_heads*hidden_dim
+             # -> batch_size, num_heads*hidden_dim, height*width
+             # -> batch_size*num_heads, hidden_dim, height, width
+             value_l_ = (
+                 value_list[level_id]
+                 .flatten(2)
+                 .transpose(1, 2)
+                 .reshape(batch_size * num_heads, hidden_dim, height, width)
+             )
+             # batch_size, num_queries, num_heads, num_points, 2
+             # -> batch_size, num_heads, num_queries, num_points, 2
+             # -> batch_size*num_heads, num_queries, num_points, 2
+             sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+             # batch_size*num_heads, hidden_dim, num_queries, num_points
+             sampling_value_l_ = nn.functional.grid_sample(
+                 value_l_,
+                 sampling_grid_l_,
+                 mode="bilinear",
+                 padding_mode="zeros",
+                 align_corners=False,
+             )
+             sampling_value_list.append(sampling_value_l_)
+         # (batch_size, num_queries, num_heads, num_levels, num_points)
+         # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+         # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+         attention_weights = attention_weights.transpose(1, 2).reshape(
+             batch_size * num_heads, 1, num_queries, num_levels * num_points
+         )
+         output = (
+             (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+             .sum(-1)
+             .view(batch_size, num_heads * hidden_dim, num_queries)
+         )
+         return output.transpose(1, 2).contiguous()
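
A minimal sketch exercising the pure-PyTorch fallback above (the `@use_kernel_forward_from_hub` decorator may substitute a hub kernel at runtime); tensor sizes are illustrative, and the unused `value_spatial_shapes`/`level_start_index` arguments are passed as placeholders:

    import torch

    batch_size, num_heads, hidden_dim = 2, 4, 16          # illustrative sizes
    spatial_shapes_list = [(8, 8), (4, 4)]                # two feature levels
    num_queries, num_levels, num_points = 10, 2, 4
    seq_len = sum(h * w for h, w in spatial_shapes_list)  # 64 + 16 = 80

    value = torch.randn(batch_size, seq_len, num_heads, hidden_dim)
    sampling_locations = torch.rand(batch_size, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.softmax(
        torch.randn(batch_size, num_queries, num_heads, num_levels * num_points), dim=-1
    ).view(batch_size, num_queries, num_heads, num_levels, num_points)

    attn = MultiScaleDeformableAttention()
    out = attn(value, None, spatial_shapes_list, None, sampling_locations, attention_weights, im2col_step=64)
    assert out.shape == (batch_size, num_queries, num_heads * hidden_dim)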


  class DeformableDetrFrozenBatchNorm2d(nn.Module):
      """
      BatchNorm2d where the batch statistics and the affine parameters are fixed.
@@ -281,7 +260,6 @@ class DeformableDetrFrozenBatchNorm2d(nn.Module):
          return x * scale + bias
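
The frozen variant reduces batch norm to a fixed per-channel affine map; a short sketch of the folding it performs, assuming the conventional eps of 1e-5 (the constant used by the class may differ):

    import torch

    weight, bias = torch.ones(3), torch.zeros(3)              # frozen affine parameters
    running_mean, running_var = torch.zeros(3), torch.ones(3)  # frozen batch statistics
    scale = weight * (running_var + 1e-5).rsqrt()              # fold stats into a scale...
    shift = bias - running_mean * scale                        # ...and a bias
    x = torch.randn(2, 3, 4, 4)
    y = x * scale[None, :, None, None] + shift[None, :, None, None]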

- # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DeformableDetr
  def replace_batch_norm(model):
      r"""
      Recursively replace all `torch.nn.BatchNorm2d` with `DeformableDetrFrozenBatchNorm2d`.
@@ -319,57 +297,36 @@ class DeformableDetrConvEncoder(nn.Module):

          self.config = config

-         # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
-         if config.use_timm_backbone:
-             # We default to values which were previously hard-coded. This enables configurability from the config
-             # using backbone arguments, while keeping the default behavior the same.
-             requires_backends(self, ["timm"])
-             kwargs = getattr(config, "backbone_kwargs", {})
-             kwargs = {} if kwargs is None else kwargs.copy()
-             out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
-             num_channels = kwargs.pop("in_chans", config.num_channels)
-             if config.dilation:
-                 kwargs["output_stride"] = kwargs.get("output_stride", 16)
-             backbone = create_model(
-                 config.backbone,
-                 pretrained=config.use_pretrained_backbone,
-                 features_only=True,
-                 out_indices=out_indices,
-                 in_chans=num_channels,
-                 **kwargs,
-             )
-         else:
-             backbone = load_backbone(config)
+         backbone = load_backbone(config)
+         self.intermediate_channel_sizes = backbone.channels

          # replace batch norm by frozen batch norm
          with torch.no_grad():
              replace_batch_norm(backbone)
-         self.model = backbone
-         self.intermediate_channel_sizes = (
-             self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
-         )

-         backbone_model_type = None
-         if config.backbone is not None:
-             backbone_model_type = config.backbone
-         elif config.backbone_config is not None:
-             backbone_model_type = config.backbone_config.model_type
-         else:
-             raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")
+         # We used to load with timm library directly instead of the AutoBackbone API
+         # so we need to unwrap the `backbone._backbone` module to load weights without mismatch
+         is_timm_model = False
+         if hasattr(backbone, "_backbone"):
+             backbone = backbone._backbone
+             is_timm_model = True
+         self.model = backbone

+         backbone_model_type = config.backbone_config.model_type
          if "resnet" in backbone_model_type:
              for name, parameter in self.model.named_parameters():
-                 if config.use_timm_backbone:
+                 if is_timm_model:
                      if "layer2" not in name and "layer3" not in name and "layer4" not in name:
                          parameter.requires_grad_(False)
                  else:
                      if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
                          parameter.requires_grad_(False)

-     # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder.forward with Detr->DeformableDetr
      def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
          # send pixel_values through the model to get list of feature maps
-         features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
+         features = self.model(pixel_values)
+         if isinstance(features, dict):
+             features = features.feature_maps

          out = []
          for feature_map in features:
@@ -379,67 +336,58 @@ class DeformableDetrConvEncoder(nn.Module):
          return out


- # Copied from transformers.models.detr.modeling_detr.DetrConvModel with Detr->DeformableDetr
- class DeformableDetrConvModel(nn.Module):
-     """
-     This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder.
-     """
-
-     def __init__(self, conv_encoder, position_embedding):
-         super().__init__()
-         self.conv_encoder = conv_encoder
-         self.position_embedding = position_embedding
-
-     def forward(self, pixel_values, pixel_mask):
-         # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
-         out = self.conv_encoder(pixel_values, pixel_mask)
-         pos = []
-         for feature_map, mask in out:
-             # position encoding
-             pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype))
-
-         return out, pos
-
-
  class DeformableDetrSinePositionEmbedding(nn.Module):
      """
      This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
      need paper, generalized to work on images.
      """

-     def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
+     def __init__(
+         self,
+         num_position_features: int = 64,
+         temperature: int = 10000,
+         normalize: bool = False,
+         scale: float | None = None,
+     ):
          super().__init__()
-         self.embedding_dim = embedding_dim
-         self.temperature = temperature
-         self.normalize = normalize
          if scale is not None and normalize is False:
              raise ValueError("normalize should be True if scale is passed")
-         if scale is None:
-             scale = 2 * math.pi
-         self.scale = scale
+         self.num_position_features = num_position_features
+         self.temperature = temperature
+         self.normalize = normalize
+         self.scale = 2 * math.pi if scale is None else scale

-     def forward(self, pixel_values, pixel_mask):
-         if pixel_mask is None:
-             raise ValueError("No pixel mask provided")
-         y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
-         x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype)
+     @compile_compatible_method_lru_cache(maxsize=1)
+     def forward(
+         self,
+         shape: torch.Size,
+         device: torch.device | str,
+         dtype: torch.dtype,
+         mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         if mask is None:
+             mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+         y_embed = mask.cumsum(1, dtype=dtype)
+         x_embed = mask.cumsum(2, dtype=dtype)
          if self.normalize:
              eps = 1e-6
              y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
              x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

-         dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
-         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
+         dim_t = torch.arange(self.num_position_features, dtype=torch.int64, device=device).to(dtype)
+         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_position_features)

          pos_x = x_embed[:, :, :, None] / dim_t
          pos_y = y_embed[:, :, :, None] / dim_t
          pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
          pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
          pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+         # expected by the encoder
+         pos = pos.flatten(2).permute(0, 2, 1)
          return pos
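
A shape check for the refactored sine embedding above; it is called without a mask here so the lru-cached forward receives only hashable arguments, and `num_position_features=128` mirrors the usual `d_model=256` configuration:

    import torch

    emb = DeformableDetrSinePositionEmbedding(num_position_features=128, normalize=True)
    pos = emb(torch.Size([2, 256, 32, 32]), device="cpu", dtype=torch.float32)
    # spatial dims are flattened and moved last: (batch, H*W, 2 * num_position_features)
    assert pos.shape == (2, 32 * 32, 256)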


- # Copied from transformers.models.detr.modeling_detr.DetrLearnedPositionEmbedding
  class DeformableDetrLearnedPositionEmbedding(nn.Module):
      """
      This module learns positional embeddings up to a fixed maximum size.
@@ -450,31 +398,122 @@ class DeformableDetrLearnedPositionEmbedding(nn.Module):
          self.row_embeddings = nn.Embedding(50, embedding_dim)
          self.column_embeddings = nn.Embedding(50, embedding_dim)

-     def forward(self, pixel_values, pixel_mask=None):
-         height, width = pixel_values.shape[-2:]
-         width_values = torch.arange(width, device=pixel_values.device)
-         height_values = torch.arange(height, device=pixel_values.device)
+     @compile_compatible_method_lru_cache(maxsize=1)
+     def forward(
+         self,
+         shape: torch.Size,
+         device: torch.device | str,
+         dtype: torch.dtype,
+         mask: torch.Tensor | None = None,
+     ):
+         height, width = shape[-2:]
+         width_values = torch.arange(width, device=device)
+         height_values = torch.arange(height, device=device)
          x_emb = self.column_embeddings(width_values)
          y_emb = self.row_embeddings(height_values)
          pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
          pos = pos.permute(2, 0, 1)
          pos = pos.unsqueeze(0)
-         pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
+         pos = pos.repeat(shape[0], 1, 1, 1)
+         # Flatten spatial dimensions and permute to (batch_size, sequence_length, hidden_size) format
+         # expected by the encoder
+         pos = pos.flatten(2).permute(0, 2, 1)
          return pos


- # Copied from transformers.models.detr.modeling_detr.build_position_encoding with Detr->DeformableDetr
- def build_position_encoding(config):
-     n_steps = config.d_model // 2
-     if config.position_embedding_type == "sine":
-         # TODO find a better way of exposing other arguments
-         position_embedding = DeformableDetrSinePositionEmbedding(n_steps, normalize=True)
-     elif config.position_embedding_type == "learned":
-         position_embedding = DeformableDetrLearnedPositionEmbedding(n_steps)
-     else:
-         raise ValueError(f"Not supported {config.position_embedding_type}")
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor | None,
+     scaling: float | None = None,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     if scaling is None:
+         scaling = query.size(-1) ** -0.5
+
+     # Take the dot product between "query" and "key" to get the raw attention scores.
+     attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+     if attention_mask is not None:
+         attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+         attn_weights = attn_weights + attention_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+     attn_output = torch.matmul(attn_weights, value)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
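
The eager path serves as the fallback when no optimized attention backend is selected; a minimal sketch of its input/output layout (only `.training` is read from the module argument):

    import torch
    from torch import nn

    module = nn.Module()                      # stands in for the calling attention module
    q = k = v = torch.randn(2, 8, 10, 32)     # (batch, heads, seq, head_dim)
    out, weights = eager_attention_forward(module, q, k, v, attention_mask=None)
    assert out.shape == (2, 10, 8, 32)        # transposed back to (batch, seq, heads, head_dim)
    assert weights.shape == (2, 8, 10, 10)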
+
+
+ class DeformableDetrSelfAttention(nn.Module):
+     """
+     Multi-headed self-attention from 'Attention Is All You Need' paper.
+
+     In Deformable DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
+     """
+
+     def __init__(
+         self,
+         config: DeformableDetrConfig,
+         hidden_size: int,
+         num_attention_heads: int,
+         dropout: float = 0.0,
+         bias: bool = True,
+     ):
+         super().__init__()
+         self.config = config
+         self.head_dim = hidden_size // num_attention_heads
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = dropout
+         self.is_causal = False
+
+         self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+         self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+         self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+         self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         position_embeddings: torch.Tensor | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Position embeddings are added to both queries and keys (but not values).
+         """
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
+
+         query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+         key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+             self.config._attn_implementation, eager_attention_forward
+         )
+
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             **kwargs,
+         )

-     return position_embedding
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights
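
The design point worth calling out: position embeddings enter the queries and keys only, so the values stay position-free. A condensed, hypothetical single-head restatement (not the module above, which shards into heads and dispatches through the attention interface):

    import torch
    from torch import nn

    h, pos = torch.randn(2, 300, 256), torch.randn(2, 300, 256)
    w_q, w_k, w_v = nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
    q, k, v = w_q(h + pos), w_k(h + pos), w_v(h)   # pos added to q/k, not v
    attn = torch.softmax(q @ k.transpose(1, 2) * 256**-0.5, dim=-1) @ v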


  class DeformableDetrMultiscaleDeformableAttention(nn.Module):
@@ -514,33 +553,30 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):

          self.disable_custom_kernels = config.disable_custom_kernels

-     def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
-         return tensor if position_embeddings is None else tensor + position_embeddings
-
      def forward(
          self,
          hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
+         attention_mask: torch.Tensor | None = None,
          encoder_hidden_states=None,
          encoder_attention_mask=None,
-         position_embeddings: Optional[torch.Tensor] = None,
+         position_embeddings: torch.Tensor | None = None,
          reference_points=None,
          spatial_shapes=None,
          spatial_shapes_list=None,
          level_start_index=None,
-         output_attentions: bool = False,
-     ):
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
          # add position embeddings to the hidden states before projecting to queries and keys
          if position_embeddings is not None:
-             hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+             hidden_states = hidden_states + position_embeddings

          batch_size, num_queries, _ = hidden_states.shape
          batch_size, sequence_length, _ = encoder_hidden_states.shape
          total_elements = sum(height * width for height, width in spatial_shapes_list)
-         if total_elements != sequence_length:
-             raise ValueError(
-                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
-             )
+         torch_compilable_check(
+             total_elements == sequence_length,
+             "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
+         )

          value = self.value_proj(encoder_hidden_states)
          if attention_mask is not None:
@@ -587,159 +623,48 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
          return output, attention_weights
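
`torch_compilable_check` replaces the previous data-dependent `raise`, which would break a `torch.compile` graph. A hypothetical sketch of the general pattern it stands for (not the transformers implementation):

    import torch

    def compilable_check_sketch(condition: bool, message: str) -> None:
        if torch.compiler.is_compiling():
            torch._check(condition, lambda: message)   # graph-friendly assertion
        elif not condition:
            raise ValueError(message)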


- class DeformableDetrMultiheadAttention(nn.Module):
-     """
-     Multi-headed attention from 'Attention Is All You Need' paper.
-
-     Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
-     """
-
-     def __init__(
-         self,
-         embed_dim: int,
-         num_heads: int,
-         dropout: float = 0.0,
-         bias: bool = True,
-     ):
+ class DeformableDetrMLP(nn.Module):
+     def __init__(self, config: DeformableDetrConfig, hidden_size: int, intermediate_size: int):
          super().__init__()
-         self.embed_dim = embed_dim
-         self.num_heads = num_heads
-         self.dropout = dropout
-         self.head_dim = embed_dim // num_heads
-         if self.head_dim * num_heads != self.embed_dim:
-             raise ValueError(
-                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                 f" {num_heads})."
-             )
-         self.scaling = self.head_dim**-0.5
-
-         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-
-     def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-         return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-     def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
-         return tensor if position_embeddings is None else tensor + position_embeddings
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embeddings: Optional[torch.Tensor] = None,
-         output_attentions: bool = False,
-     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
-         """Input shape: Batch x Time x Channel"""
-
-         batch_size, target_len, embed_dim = hidden_states.size()
-         # add position embeddings to the hidden states before projecting to queries and keys
-         if position_embeddings is not None:
-             hidden_states_original = hidden_states
-             hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
-
-         # get queries, keys and values
-         query_states = self.q_proj(hidden_states) * self.scaling
-         key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
-         value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
-
-         proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
-         query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
-         key_states = key_states.view(*proj_shape)
-         value_states = value_states.view(*proj_shape)
-
-         source_len = key_states.size(1)
-
-         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
-         if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
-             raise ValueError(
-                 f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
-                 f" {attn_weights.size()}"
-             )
-
-         # expand attention_mask
-         if attention_mask is not None:
-             # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-             attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-
-         if attention_mask is not None:
-             if attention_mask.size() != (batch_size, 1, target_len, source_len):
-                 raise ValueError(
-                     f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
-                     f" {attention_mask.size()}"
-                 )
-             if attention_mask.dtype == torch.bool:
-                 attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
-                     attention_mask, -torch.inf
-                 )
-             attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
-             attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
-
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-         if output_attentions:
-             # this operation is a bit awkward, but it's required to
-             # make sure that attn_weights keeps its gradient.
-             # In order to do so, attn_weights have to reshaped
-             # twice and have to be reused in the following
-             attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
-             attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
-         else:
-             attn_weights_reshaped = None
-
-         attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
-         attn_output = torch.bmm(attn_probs, value_states)
-
-         if attn_output.size() != (
-             batch_size * self.num_heads,
-             target_len,
-             self.head_dim,
-         ):
-             raise ValueError(
-                 f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
-                 f" {attn_output.size()}"
-             )
-
-         attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
-         attn_output = attn_output.transpose(1, 2)
-         attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
-
-         attn_output = self.out_proj(attn_output)
+         self.fc1 = nn.Linear(hidden_size, intermediate_size)
+         self.fc2 = nn.Linear(intermediate_size, hidden_size)
+         self.activation_fn = ACT2FN[config.activation_function]
+         self.activation_dropout = config.activation_dropout
+         self.dropout = config.dropout

-         return attn_output, attn_weights_reshaped
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+         return hidden_states
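
A smoke test for the factored-out feed-forward block, assuming `DeformableDetrMLP` can be imported from the modeling module and that a default config carries the fields it reads (`activation_function`, `activation_dropout`, `dropout`):

    import torch
    from transformers import DeformableDetrConfig

    mlp = DeformableDetrMLP(DeformableDetrConfig(), hidden_size=256, intermediate_size=1024)
    assert mlp(torch.randn(2, 300, 256)).shape == (2, 300, 256)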


  class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
      def __init__(self, config: DeformableDetrConfig):
          super().__init__()
-         self.embed_dim = config.d_model
+         self.hidden_size = config.d_model
          self.self_attn = DeformableDetrMultiscaleDeformableAttention(
              config,
              num_heads=config.encoder_attention_heads,
              n_points=config.encoder_n_points,
          )
-         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
          self.dropout = config.dropout
-         self.activation_fn = ACT2FN[config.activation_function]
-         self.activation_dropout = config.activation_dropout
-         self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
-         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.mlp = DeformableDetrMLP(config, self.hidden_size, config.encoder_ffn_dim)
+         self.final_layer_norm = nn.LayerNorm(self.hidden_size)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor,
-         position_embeddings: Optional[torch.Tensor] = None,
+         spatial_position_embeddings: torch.Tensor | None = None,
          reference_points=None,
          spatial_shapes=None,
          spatial_shapes_list=None,
          level_start_index=None,
-         output_attentions: bool = False,
-     ):
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> torch.Tensor:
          """
          Args:
              hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -754,24 +679,18 @@ class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
                  Spatial shapes of the backbone feature maps.
              level_start_index (`torch.LongTensor`, *optional*):
                  Level start index.
-             output_attentions (`bool`, *optional*):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
          """
          residual = hidden_states
-
-         # Apply Multi-scale Deformable Attention Module on the multi-scale feature maps.
-         hidden_states, attn_weights = self.self_attn(
+         hidden_states, _ = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              encoder_hidden_states=hidden_states,
              encoder_attention_mask=attention_mask,
-             position_embeddings=position_embeddings,
+             position_embeddings=spatial_position_embeddings,
              reference_points=reference_points,
              spatial_shapes=spatial_shapes,
              spatial_shapes_list=spatial_shapes_list,
              level_start_index=level_start_index,
-             output_attentions=output_attentions,
          )

          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -779,12 +698,7 @@ class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
          hidden_states = self.self_attn_layer_norm(hidden_states)

          residual = hidden_states
-         hidden_states = self.activation_fn(self.fc1(hidden_states))
-         hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
-         hidden_states = self.fc2(hidden_states)
-         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
+         hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
          hidden_states = self.final_layer_norm(hidden_states)

@@ -793,54 +707,44 @@ class DeformableDetrEncoderLayer(GradientCheckpointingLayer):
          clamp_value = torch.finfo(hidden_states.dtype).max - 1000
          hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (attn_weights,)
-
-         return outputs
+         return hidden_states


  class DeformableDetrDecoderLayer(GradientCheckpointingLayer):
      def __init__(self, config: DeformableDetrConfig):
          super().__init__()
-         self.embed_dim = config.d_model
+         self.hidden_size = config.d_model

-         # self-attention
-         self.self_attn = DeformableDetrMultiheadAttention(
-             embed_dim=self.embed_dim,
-             num_heads=config.decoder_attention_heads,
+         self.self_attn = DeformableDetrSelfAttention(
+             config=config,
+             hidden_size=self.hidden_size,
+             num_attention_heads=config.decoder_attention_heads,
              dropout=config.attention_dropout,
          )
          self.dropout = config.dropout
-         self.activation_fn = ACT2FN[config.activation_function]
-         self.activation_dropout = config.activation_dropout

-         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-         # cross-attention
+         self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size)
          self.encoder_attn = DeformableDetrMultiscaleDeformableAttention(
              config,
              num_heads=config.decoder_attention_heads,
              n_points=config.decoder_n_points,
          )
-         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-         # feedforward neural networks
-         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
-         self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
-         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+         self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size)
+         self.mlp = DeformableDetrMLP(config, self.hidden_size, config.decoder_ffn_dim)
+         self.final_layer_norm = nn.LayerNorm(self.hidden_size)

      def forward(
          self,
          hidden_states: torch.Tensor,
-         position_embeddings: Optional[torch.Tensor] = None,
+         object_queries_position_embeddings: torch.Tensor | None = None,
          reference_points=None,
          spatial_shapes=None,
          spatial_shapes_list=None,
          level_start_index=None,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         encoder_attention_mask: Optional[torch.Tensor] = None,
-         output_attentions: Optional[bool] = False,
-     ):
+         encoder_hidden_states: torch.Tensor | None = None,
+         encoder_attention_mask: torch.Tensor | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> torch.Tensor:
          """
          Args:
              hidden_states (`torch.FloatTensor`):
@@ -858,60 +762,47 @@ class DeformableDetrDecoderLayer(GradientCheckpointingLayer):
              encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                  values.
-             output_attentions (`bool`, *optional*):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
          """
          residual = hidden_states

          # Self Attention
-         hidden_states, self_attn_weights = self.self_attn(
+         hidden_states, _ = self.self_attn(
              hidden_states=hidden_states,
-             position_embeddings=position_embeddings,
-             output_attentions=output_attentions,
+             position_embeddings=object_queries_position_embeddings,
+             **kwargs,
          )

          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
          hidden_states = residual + hidden_states
          hidden_states = self.self_attn_layer_norm(hidden_states)

-         second_residual = hidden_states
+         residual = hidden_states

          # Cross-Attention
-         cross_attn_weights = None
-         hidden_states, cross_attn_weights = self.encoder_attn(
+         hidden_states, _ = self.encoder_attn(
              hidden_states=hidden_states,
              attention_mask=encoder_attention_mask,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
-             position_embeddings=position_embeddings,
+             position_embeddings=object_queries_position_embeddings,
              reference_points=reference_points,
              spatial_shapes=spatial_shapes,
              spatial_shapes_list=spatial_shapes_list,
              level_start_index=level_start_index,
-             output_attentions=output_attentions,
          )

          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-         hidden_states = second_residual + hidden_states
+         hidden_states = residual + hidden_states

          hidden_states = self.encoder_attn_layer_norm(hidden_states)

          # Fully Connected
          residual = hidden_states
-         hidden_states = self.activation_fn(self.fc1(hidden_states))
-         hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-         hidden_states = self.fc2(hidden_states)
-         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+         hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
          hidden_states = self.final_layer_norm(hidden_states)

-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (self_attn_weights, cross_attn_weights)
-
-         return outputs
+         return hidden_states


  @auto_docstring
@@ -926,6 +817,13 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
          r"DeformableDetrEncoderLayer",
          r"DeformableDetrDecoderLayer",
      ]
+     _supports_sdpa = True
+     _supports_flash_attn = True
+     _supports_attention_backend = True
+     _supports_flex_attn = True
+     _keys_to_ignore_on_load_unexpected = [
+         r"detr\.model\.backbone\.model\.layer\d+\.0\.downsample\.1\.num_batches_tracked"
+     ]

      @torch.no_grad()
      def _init_weights(self, module):
@@ -983,9 +881,13 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
      config: DeformableDetrConfig
      """

+     _can_record_outputs = {
+         "hidden_states": DeformableDetrEncoderLayer,
+         "attentions": OutputRecorder(DeformableDetrMultiscaleDeformableAttention, layer_name="self_attn", index=1),
+     }
+
      def __init__(self, config: DeformableDetrConfig):
          super().__init__(config)
-         self.gradient_checkpointing = False

          self.dropout = config.dropout
          self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
@@ -993,51 +895,18 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
          # Initialize weights and apply final processing
          self.post_init()

-     @staticmethod
-     def get_reference_points(spatial_shapes, valid_ratios, device):
-         """
-         Get reference points for each feature map. Used in decoder.
-
-         Args:
-             spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
-                 Spatial shapes of each feature map.
-             valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
-                 Valid ratios of each feature map.
-             device (`torch.device`):
-                 Device on which to create the tensors.
-         Returns:
-             `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
-         """
-         reference_points_list = []
-         for level, (height, width) in enumerate(spatial_shapes):
-             ref_y, ref_x = meshgrid(
-                 torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
-                 torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
-                 indexing="ij",
-             )
-             # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
-             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
-             ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
-             ref = torch.stack((ref_x, ref_y), -1)
-             reference_points_list.append(ref)
-         reference_points = torch.cat(reference_points_list, 1)
-         reference_points = reference_points[:, :, None] * valid_ratios[:, None]
-         return reference_points
-
+     @check_model_inputs()
      def forward(
          self,
          inputs_embeds=None,
          attention_mask=None,
-         position_embeddings=None,
+         spatial_position_embeddings=None,
          spatial_shapes=None,
          spatial_shapes_list=None,
          level_start_index=None,
          valid_ratios=None,
-         output_attentions=None,
-         output_hidden_states=None,
-         return_dict=None,
-         **kwargs,
-     ):
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> BaseModelOutput:
          r"""
          Args:
              inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -1047,66 +916,72 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
                  - 1 for pixel features that are real (i.e. **not masked**),
                  - 0 for pixel features that are padding (i.e. **masked**).
                  [What are attention masks?](../glossary#attention-mask)
-             position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-                 Position embeddings that are added to the queries and keys in each self-attention layer.
+             spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                 Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer.
              spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                  Spatial shapes of each feature map.
              level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
                  Starting index of each feature map.
              valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
                  Ratio of valid area in each feature level.
-             output_attentions (`bool`, *optional*):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
-             output_hidden_states (`bool`, *optional*):
-                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                 for more detail.
-             return_dict (`bool`, *optional*):
-                 Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
          """
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
          hidden_states = inputs_embeds
          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

          spatial_shapes_tuple = tuple(spatial_shapes_list)
          reference_points = self.get_reference_points(spatial_shapes_tuple, valid_ratios, device=inputs_embeds.device)

-         encoder_states = () if output_hidden_states else None
-         all_attentions = () if output_attentions else None
-         for i, encoder_layer in enumerate(self.layers):
-             if output_hidden_states:
-                 encoder_states = encoder_states + (hidden_states,)
-             layer_outputs = encoder_layer(
+         for encoder_layer in self.layers:
+             hidden_states = encoder_layer(
                  hidden_states,
                  attention_mask,
-                 position_embeddings=position_embeddings,
+                 spatial_position_embeddings=spatial_position_embeddings,
                  reference_points=reference_points,
                  spatial_shapes=spatial_shapes,
                  spatial_shapes_list=spatial_shapes_list,
                  level_start_index=level_start_index,
-                 output_attentions=output_attentions,
+                 **kwargs,
              )

-             hidden_states = layer_outputs[0]
+         return BaseModelOutput(last_hidden_state=hidden_states)

-             if output_attentions:
-                 all_attentions = all_attentions + (layer_outputs[1],)
+     @staticmethod
+     def get_reference_points(spatial_shapes_list, valid_ratios, device):
+         """
+         Get reference points for each feature map. Used in decoder.
+
+         Args:
+             spatial_shapes_list (`list[tuple[int, int]]`):
+                 Spatial shapes of each feature map.
+             valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
+                 Valid ratios of each feature map.
+             device (`torch.device`):
+                 Device on which to create the tensors.
+         Returns:
+             `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
+         """
+         reference_points_list = []
+         for level, (height, width) in enumerate(spatial_shapes_list):
+             ref_y, ref_x = meshgrid(
+                 torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
+                 torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
+                 indexing="ij",
+             )
+             # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
+             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
+             ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
+             ref = torch.stack((ref_x, ref_y), -1)
+             reference_points_list.append(ref)
+         reference_points = torch.cat(reference_points_list, 1)
+         reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+         return reference_points
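
A worked example of the grid this staticmethod builds, assuming the class above is importable: one 2x2 feature level with fully valid ratios yields pixel centers normalized to [0, 1]:

    import torch

    valid_ratios = torch.ones(1, 1, 2)   # (batch, num_levels, 2), no padding anywhere
    pts = DeformableDetrEncoder.get_reference_points([(2, 2)], valid_ratios, device="cpu")
    assert pts.shape == (1, 4, 1, 2)
    # x-coordinates of the four centers: 0.25, 0.75, 0.25, 0.75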

-         if output_hidden_states:
-             encoder_states = encoder_states + (hidden_states,)

-         if not return_dict:
-             return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
-         return BaseModelOutput(
-             last_hidden_state=hidden_states,
-             hidden_states=encoder_states,
-             attentions=all_attentions,
-         )
+ def inverse_sigmoid(x, eps=1e-5):
+     x = x.clamp(min=0, max=1)
+     x1 = x.clamp(min=eps)
+     x2 = (1 - x).clamp(min=eps)
+     return torch.log(x1 / x2)
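
`inverse_sigmoid` is a clamped logit, the exact inverse of `torch.sigmoid` away from the saturated ends:

    import torch

    x = torch.tensor([0.25, 0.5, 0.75])
    assert torch.allclose(torch.sigmoid(inverse_sigmoid(x)), x, atol=1e-5)
    # inverse_sigmoid(0.5) == log(0.5 / 0.5) == 0.0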
1110
985
 
1111
986
 
1112
987
  class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
@@ -1124,12 +999,19 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
1124
999
  config: DeformableDetrConfig
1125
1000
  """
1126
1001
 
1002
+ _can_record_outputs = {
1003
+ "hidden_states": DeformableDetrDecoderLayer,
1004
+ "attentions": OutputRecorder(DeformableDetrSelfAttention, layer_name="self_attn", index=1),
1005
+ "cross_attentions": OutputRecorder(
1006
+ DeformableDetrMultiscaleDeformableAttention, layer_name="encoder_attn", index=1
1007
+ ),
1008
+ }
1009
+
1127
1010
  def __init__(self, config: DeformableDetrConfig):
1128
1011
  super().__init__(config)
1129
1012
 
1130
1013
  self.dropout = config.dropout
1131
1014
  self.layers = nn.ModuleList([DeformableDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
1132
- self.gradient_checkpointing = False
1133
1015
 
1134
1016
  # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
1135
1017
  self.bbox_embed = None
@@ -1138,21 +1020,19 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
1138
1020
  # Initialize weights and apply final processing
1139
1021
  self.post_init()
1140
1022
 
1023
+ @check_model_inputs()
1141
1024
  def forward(
1142
1025
  self,
1143
1026
  inputs_embeds=None,
1144
1027
  encoder_hidden_states=None,
1145
1028
  encoder_attention_mask=None,
1146
- position_embeddings=None,
1029
+ object_queries_position_embeddings=None,
1147
1030
  reference_points=None,
1148
1031
  spatial_shapes=None,
1149
1032
  spatial_shapes_list=None,
1150
1033
  level_start_index=None,
1151
1034
  valid_ratios=None,
1152
- output_attentions=None,
1153
- output_hidden_states=None,
1154
- return_dict=None,
1155
- **kwargs,
1035
+ **kwargs: Unpack[TransformersKwargs],
1156
1036
  ):
1157
1037
  r"""
1158
1038
  Args:
@@ -1166,8 +1046,8 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
1166
1046
  in `[0, 1]`:
1167
1047
  - 1 for pixels that are real (i.e. **not masked**),
1168
1048
  - 0 for pixels that are padding (i.e. **masked**).
1169
- position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1170
- Position embeddings that are added to the queries and keys in each self-attention layer.
1049
+ object_queries_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1050
+ Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer.
1171
1051
  reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*):
1172
1052
  Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
1173
1053
  spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
@@ -1177,28 +1057,11 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
1177
1057
  valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
1178
1058
  Ratio of valid area in each feature level.
1179
1059
 
1180
- output_attentions (`bool`, *optional*):
1181
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1182
- returned tensors for more detail.
1183
- output_hidden_states (`bool`, *optional*):
1184
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1185
- for more detail.
1186
- return_dict (`bool`, *optional*):
1187
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1188
1060
  """
1189
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1190
- output_hidden_states = (
1191
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1192
- )
1193
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1194
-
1195
1061
  if inputs_embeds is not None:
1196
1062
  hidden_states = inputs_embeds
1197
1063
 
1198
1064
  # decoder layers
1199
- all_hidden_states = () if output_hidden_states else None
1200
- all_self_attns = () if output_attentions else None
1201
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1202
1065
  intermediate = ()
1203
1066
  intermediate_reference_points = ()
1204
1067
 
@@ -1213,23 +1076,18 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
1213
1076
  else:
1214
1077
  raise ValueError("Reference points' last dimension must be of size 2")
1215
1078
 
1216
- if output_hidden_states:
1217
- all_hidden_states += (hidden_states,)
1218
-
1219
- layer_outputs = decoder_layer(
1079
+ hidden_states = decoder_layer(
1220
1080
  hidden_states,
1221
- position_embeddings,
1081
+ object_queries_position_embeddings,
1222
1082
  reference_points_input,
1223
1083
  spatial_shapes,
1224
1084
  spatial_shapes_list,
1225
1085
  level_start_index,
1226
1086
  encoder_hidden_states, # as a positional argument for gradient checkpointing
1227
1087
  encoder_attention_mask,
1228
- output_attentions,
1088
+ **kwargs,
1229
1089
  )
1230
1090
 
1231
- hidden_states = layer_outputs[0]
1232
-
1233
1091
  # hack implementation for iterative bounding box refinement
1234
1092
  if self.bbox_embed is not None:
1235
1093
  tmp = self.bbox_embed[idx](hidden_states)
@@ -1250,40 +1108,14 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
             intermediate += (hidden_states,)
             intermediate_reference_points += (reference_points,)
 
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
         # Keep batch_size as first dimension
         intermediate = torch.stack(intermediate, dim=1)
         intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
 
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [
-                    hidden_states,
-                    intermediate,
-                    intermediate_reference_points,
-                    all_hidden_states,
-                    all_self_attns,
-                    all_cross_attentions,
-                ]
-                if v is not None
-            )
         return DeformableDetrDecoderOutput(
             last_hidden_state=hidden_states,
             intermediate_hidden_states=intermediate,
             intermediate_reference_points=intermediate_reference_points,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
         )
 
 
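The two decoder hunks above drop the hand-threaded `output_attentions`/`output_hidden_states`/`return_dict` plumbing: flags now arrive through `**kwargs` and the decoder always builds a `DeformableDetrDecoderOutput`, so callers read named dataclass fields instead of indexing positional tuple slots. A minimal usage sketch from the caller's side (assumes the public `SenseTime/deformable-detr` checkpoint and network access; not taken from this diff):

```python
# Output flags now travel through **kwargs rather than explicit keyword parameters.
import torch
from transformers import DeformableDetrModel

model = DeformableDetrModel.from_pretrained("SenseTime/deformable-detr")
pixel_values = torch.zeros(1, 3, 800, 800)  # stand-in for a preprocessed image

with torch.no_grad():
    outputs = model(pixel_values=pixel_values, output_attentions=True)

print(outputs.last_hidden_state.shape)           # e.g. torch.Size([1, 300, 256])
print(outputs.intermediate_hidden_states.shape)  # per-layer decoder states
```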
@@ -1297,17 +1129,23 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)
 
-        # Create backbone + positional encoding
-        backbone = DeformableDetrConvEncoder(config)
-        position_embeddings = build_position_encoding(config)
-        self.backbone = DeformableDetrConvModel(backbone, position_embeddings)
+        # Create backbone
+        self.backbone = DeformableDetrConvEncoder(config)
+
+        # Create positional encoding
+        if config.position_embedding_type == "sine":
+            self.position_embedding = DeformableDetrSinePositionEmbedding(config.d_model // 2, normalize=True)
+        elif config.position_embedding_type == "learned":
+            self.position_embedding = DeformableDetrLearnedPositionEmbedding(config.d_model // 2)
+        else:
+            raise ValueError(f"Not supported {config.position_embedding_type}")
 
         # Create input projection layers
         if config.num_feature_levels > 1:
-            num_backbone_outs = len(backbone.intermediate_channel_sizes)
+            num_backbone_outs = len(self.backbone.intermediate_channel_sizes)
             input_proj_list = []
             for _ in range(num_backbone_outs):
-                in_channels = backbone.intermediate_channel_sizes[_]
+                in_channels = self.backbone.intermediate_channel_sizes[_]
                 input_proj_list.append(
                     nn.Sequential(
                         nn.Conv2d(in_channels, config.d_model, kernel_size=1),
@@ -1334,7 +1172,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
                 [
                     nn.Sequential(
                         nn.Conv2d(
-                            backbone.intermediate_channel_sizes[-1],
+                            self.backbone.intermediate_channel_sizes[-1],
                             config.d_model,
                             kernel_size=1,
                         ),
@@ -1362,11 +1200,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         self.post_init()
 
     def freeze_backbone(self):
-        for name, param in self.backbone.conv_encoder.model.named_parameters():
+        for name, param in self.backbone.model.named_parameters():
             param.requires_grad_(False)
 
     def unfreeze_backbone(self):
-        for name, param in self.backbone.conv_encoder.model.named_parameters():
+        for name, param in self.backbone.model.named_parameters():
             param.requires_grad_(True)
 
     def get_valid_ratio(self, mask, dtype=torch.float32):
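With the `DeformableDetrConvModel` wrapper gone, the model owns the backbone and the position embedding directly, so `freeze_backbone` reaches `self.backbone.model` without the old `conv_encoder` indirection. A sketch of the new wiring from the user's side (assumes `timm` is installed for the default ResNet backbone; whether pretrained backbone weights are fetched depends on the config defaults):

```python
# The position embedding flavor is picked from config.position_embedding_type,
# and freeze/unfreeze act on self.backbone.model directly.
from transformers import DeformableDetrConfig, DeformableDetrModel

config = DeformableDetrConfig(position_embedding_type="sine")  # or "learned"
model = DeformableDetrModel(config)

model.freeze_backbone()
assert all(not p.requires_grad for p in model.backbone.model.parameters())
model.unfreeze_backbone()
```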
@@ -1387,15 +1225,18 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         temperature = 10000
         scale = 2 * math.pi
 
-        dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
+        # Compute position embeddings in float32 to avoid overflow with large temperature values in fp16
+        proposals_dtype = proposals.dtype
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
         dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
         # batch_size, num_queries, 4
-        proposals = proposals.sigmoid() * scale
+        proposals = proposals.sigmoid().to(torch.float32) * scale
         # batch_size, num_queries, 4, 128
         pos = proposals[:, :, :, None] / dim_t
         # batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512
         pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
-        return pos
+        # Convert back to target dtype after all computations are done
+        return pos.to(proposals_dtype)
 
     def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
         """Generate the encoder output proposals from encoded enc_output.
@@ -1459,19 +1300,17 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         return object_query, output_proposals
 
     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
-        pixel_mask: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_outputs: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.FloatTensor], DeformableDetrModelOutput]:
+        pixel_mask: torch.LongTensor | None = None,
+        decoder_attention_mask: torch.FloatTensor | None = None,
+        encoder_outputs: torch.FloatTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        decoder_inputs_embeds: torch.FloatTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.FloatTensor] | DeformableDetrModelOutput:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
             Not used by default. Can be used to mask object queries.
@@ -1503,12 +1342,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         >>> list(last_hidden_states.shape)
         [1, 300, 256]
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         batch_size, num_channels, height, width = pixel_values.shape
         device = pixel_values.device
 
@@ -1518,16 +1351,22 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         # Extract multi-scale feature maps of same resolution `config.d_model` (cf Figure 4 in paper)
         # First, sent pixel_values + pixel_mask through Backbone to obtain the features
         # which is a list of tuples
-        features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)
+        features = self.backbone(pixel_values, pixel_mask)
 
         # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
         sources = []
         masks = []
+        position_embeddings_list = []
         for level, (source, mask) in enumerate(features):
             sources.append(self.input_proj[level](source))
             masks.append(mask)
             if mask is None:
                 raise ValueError("No attention mask was provided")
+            # Generate position embeddings for this feature level
+            pos = self.position_embedding(shape=source.shape, device=device, dtype=pixel_values.dtype, mask=mask).to(
+                source.dtype
+            )
+            position_embeddings_list.append(pos)
 
         # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
         if self.config.num_feature_levels > len(sources):
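Position embeddings are now generated per feature level inside `forward`, since the backbone no longer returns them. For intuition, a simplified stand-alone version of the 2D sine embedding such a call produces (a sketch only; the real `DeformableDetrSinePositionEmbedding` takes `shape`/`device`/`dtype`/`mask` and differs in details such as normalization):

```python
import math
import torch

def sine_position_embedding(mask: torch.Tensor, num_pos_feats: int = 128) -> torch.Tensor:
    # mask: (batch, H, W) booleans, True on valid (non-padded) pixels
    y_embed = mask.cumsum(1, dtype=torch.float32)
    x_embed = mask.cumsum(2, dtype=torch.float32)
    scale = 2 * math.pi
    y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * scale
    x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    # Interleave sin/cos over the feature dimension, then stack y before x
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)  # (batch, 2*num_pos_feats, H, W)

mask = torch.ones(1, 25, 34, dtype=torch.bool)
print(sine_position_embedding(mask).shape)  # torch.Size([1, 256, 25, 34])
```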
@@ -1540,7 +1379,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
                 mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to(
                     torch.bool
                 )[0]
-                pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
+                pos_l = self.position_embedding(
+                    shape=source.shape, device=device, dtype=pixel_values.dtype, mask=mask
+                ).to(source.dtype)
                 sources.append(source)
                 masks.append(mask)
                 position_embeddings_list.append(pos_l)
@@ -1561,7 +1402,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             spatial_shapes_list.append(spatial_shape)
             source = source.flatten(2).transpose(1, 2)
             mask = mask.flatten(1)
-            pos_embed = pos_embed.flatten(2).transpose(1, 2)
             lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
             lvl_pos_embed_flatten.append(lvl_pos_embed)
             source_flatten.append(source)
@@ -1579,21 +1419,12 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             encoder_outputs = self.encoder(
                 inputs_embeds=source_flatten,
                 attention_mask=mask_flatten,
-                position_embeddings=lvl_pos_embed_flatten,
+                spatial_position_embeddings=lvl_pos_embed_flatten,
                 spatial_shapes=spatial_shapes,
                 spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 valid_ratios=valid_ratios,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+                **kwargs,
             )
 
         # Fifth, prepare decoder inputs
@@ -1636,7 +1467,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
 
         decoder_outputs = self.decoder(
             inputs_embeds=target,
-            position_embeddings=query_embed,
+            object_queries_position_embeddings=query_embed,
             encoder_hidden_states=encoder_outputs[0],
             encoder_attention_mask=mask_flatten,
             reference_points=reference_points,
@@ -1644,17 +1475,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             valid_ratios=valid_ratios,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
 
-        if not return_dict:
-            enc_outputs = tuple(value for value in [enc_outputs_class, enc_outputs_coord_logits] if value is not None)
-            tuple_outputs = (init_reference_points,) + decoder_outputs + encoder_outputs + enc_outputs
-
-            return tuple_outputs
-
         return DeformableDetrModelOutput(
             init_reference_points=init_reference_points,
             last_hidden_state=decoder_outputs.last_hidden_state,
@@ -1671,14 +1494,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         )
 
 
-# Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead
 class DeformableDetrMLPPredictionHead(nn.Module):
     """
     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
     height and width of a bounding box w.r.t. an image.
 
-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-
     """
 
     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
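Only the `# Copied from` markers are removed here; the head itself is unchanged. For reference, a stand-alone sketch of what this kind of MLP prediction head computes (hypothetical class name; same shape behavior assumed, e.g. mapping `d_model` features to 4 box coordinates):

```python
import torch
from torch import nn

class MLPPredictionHead(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int):
        super().__init__()
        # num_layers Linear layers with ReLU between all but the last
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(nn.Linear(i, o) for i, o in zip(dims[:-1], dims[1:]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i != len(self.layers) - 1:
                x = torch.relu(x)
        return x

head = MLPPredictionHead(256, 256, 4, num_layers=3)
print(head(torch.randn(1, 300, 256)).shape)  # torch.Size([1, 300, 4])
```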
@@ -1727,29 +1547,29 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
                 for _ in range(num_pred)
             ]
         )
+        # Convert to instance attribute before modifying
+        self._tied_weights_keys = self._tied_weights_keys.copy()
         if config.with_box_refine:
             self.model.decoder.bbox_embed = self.bbox_embed
-            self._tied_weights_keys["model.decoder.bbox_embed"] = "bbox_embed"
+            self._tied_weights_keys["bbox_embed"] = "model.decoder.bbox_embed"
         if config.two_stage:
             self.model.decoder.class_embed = self.class_embed
-            self._tied_weights_keys["model.decoder.class_embed"] = "class_embed"
+            self._tied_weights_keys["class_embed"] = "model.decoder.class_embed"
         self.post_init()
 
     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
-        pixel_mask: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_outputs: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[list[dict]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[tuple[torch.FloatTensor], DeformableDetrObjectDetectionOutput]:
+        pixel_mask: torch.LongTensor | None = None,
+        decoder_attention_mask: torch.FloatTensor | None = None,
+        encoder_outputs: torch.FloatTensor | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        decoder_inputs_embeds: torch.FloatTensor | None = None,
+        labels: list[dict] | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.FloatTensor] | DeformableDetrObjectDetectionOutput:
        r"""
        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Not used by default. Can be used to mask object queries.
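The new `.copy()` matters because `_tied_weights_keys` starts life as a class-level dict: mutating it in `__init__` without copying first would leak entries across every instance that shares the attribute. A generic repro of the pitfall, unrelated to any transformers internals:

```python
class Head:
    tied = {}  # class attribute shared by all instances

    def __init__(self, name, leak):
        if not leak:
            self.tied = self.tied.copy()  # shadow with an instance-local dict
        self.tied[name] = name

a, b = Head("a", leak=True), Head("b", leak=True)
print(Head.tied)  # {'a': 'a', 'b': 'b'}  <- both writes hit the shared class dict

Head.tied = {}
c, d = Head("c", leak=False), Head("d", leak=False)
print(Head.tied, c.tied, d.tied)  # {} {'c': 'c'} {'d': 'd'}
```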
@@ -1796,8 +1616,6 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
         Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
         Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
         ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         # First, sent images through DETR base model to obtain encoder + decoder outputs
         outputs = self.model(
             pixel_values,
@@ -1806,14 +1624,12 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )
 
-        hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2]
-        init_reference = outputs.init_reference_points if return_dict else outputs[0]
-        inter_references = outputs.intermediate_reference_points if return_dict else outputs[3]
+        hidden_states = outputs.intermediate_hidden_states
+        init_reference = outputs.init_reference_points
+        inter_references = outputs.intermediate_reference_points
 
         # class logits + predicted bounding boxes
         outputs_classes = []
@@ -1854,16 +1670,8 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
                 outputs_class,
                 outputs_coord,
             )
-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes) + auxiliary_outputs + outputs
-            else:
-                output = (logits, pred_boxes) + outputs
-            tuple_outputs = ((loss, loss_dict) + output) if loss is not None else output
 
-            return tuple_outputs
-
-        dict_outputs = DeformableDetrObjectDetectionOutput(
+        return DeformableDetrObjectDetectionOutput(
             loss=loss,
             loss_dict=loss_dict,
             logits=logits,
@@ -1883,11 +1691,5 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
             enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
         )
 
-        return dict_outputs
-
 
-__all__ = [
-    "DeformableDetrForObjectDetection",
-    "DeformableDetrModel",
-    "DeformableDetrPreTrainedModel",
-]
+__all__ = ["DeformableDetrForObjectDetection", "DeformableDetrModel", "DeformableDetrPreTrainedModel"]
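
Throughout the file, the hand-written `if not return_dict:` tuple branches are replaced by the `@can_return_tuple` decorator, so each `forward` body only constructs its output dataclass. A minimal sketch of the mechanism (illustrative only, not the transformers implementation):

```python
import functools
from dataclasses import astuple, dataclass

def can_return_tuple_sketch(forward):
    # Pop return_dict at the wrapper level; convert the dataclass to a tuple
    # of its non-None fields when the caller asked for one.
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=True, **kwargs):
        output = forward(self, *args, **kwargs)
        if return_dict:
            return output
        return tuple(v for v in astuple(output) if v is not None)
    return wrapper

@dataclass
class Output:
    last_hidden_state: int
    attentions: tuple | None = None

class Model:
    @can_return_tuple_sketch
    def forward(self):
        return Output(last_hidden_state=1)

m = Model()
print(m.forward())                   # Output(last_hidden_state=1, attentions=None)
print(m.forward(return_dict=False))  # (1,)
```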