transformers 5.0.0rc3__py3-none-any.whl → 5.1.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the versions exactly as they appear in their respective public registries.
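
For reference, the version bump summarized here can be reproduced locally with a standard pip upgrade; the lines below are a minimal sketch (the pip command and plain pip-managed environment are assumptions, not part of this diff), followed by a check of the version string that transformers exposes.

    # Minimal sketch: move from the 5.0.0rc3 pre-release wheel to the 5.1.0 release
    # and confirm which version is active. Assumes a plain pip-managed environment.
    #
    #   pip install --upgrade "transformers==5.1.0"
    #
    import transformers
    print(transformers.__version__)  # expected to print "5.1.0"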
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,9 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/rt_detr/modular_rt_detr.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_rt_detr.py file directly. One of our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  # Copyright 2024 Baidu Inc and The HuggingFace Inc. team.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,10 +17,9 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """PyTorch RT-DETR model."""
-
  import math
  import warnings
+ from collections.abc import Callable
  from dataclasses import dataclass

  import torch
@@ -23,83 +28,18 @@ from torch import Tensor, nn

  from ... import initialization as init
  from ...activations import ACT2CLS, ACT2FN
+ from ...backbone_utils import load_backbone
  from ...image_transforms import center_to_corners_format, corners_to_center_format
  from ...integrations import use_kernel_forward_from_hub
  from ...modeling_outputs import BaseModelOutput
- from ...modeling_utils import PreTrainedModel
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from ...processing_utils import Unpack
  from ...pytorch_utils import compile_compatible_method_lru_cache
- from ...utils import (
-     ModelOutput,
-     auto_docstring,
-     logging,
-     torch_int,
- )
- from ...utils.backbone_utils import load_backbone
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int
+ from ...utils.generic import can_return_tuple, check_model_inputs
  from .configuration_rt_detr import RTDetrConfig


- logger = logging.get_logger(__name__)
-
-
- # TODO: Replace all occurrences of the checkpoint with the final one
-
-
- @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
- # Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
- class MultiScaleDeformableAttention(nn.Module):
-     def forward(
-         self,
-         value: Tensor,
-         value_spatial_shapes: Tensor,
-         value_spatial_shapes_list: list[tuple],
-         level_start_index: Tensor,
-         sampling_locations: Tensor,
-         attention_weights: Tensor,
-         im2col_step: int,
-     ):
-         batch_size, _, num_heads, hidden_dim = value.shape
-         _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-         value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
-         sampling_grids = 2 * sampling_locations - 1
-         sampling_value_list = []
-         for level_id, (height, width) in enumerate(value_spatial_shapes_list):
-             # batch_size, height*width, num_heads, hidden_dim
-             # -> batch_size, height*width, num_heads*hidden_dim
-             # -> batch_size, num_heads*hidden_dim, height*width
-             # -> batch_size*num_heads, hidden_dim, height, width
-             value_l_ = (
-                 value_list[level_id]
-                 .flatten(2)
-                 .transpose(1, 2)
-                 .reshape(batch_size * num_heads, hidden_dim, height, width)
-             )
-             # batch_size, num_queries, num_heads, num_points, 2
-             # -> batch_size, num_heads, num_queries, num_points, 2
-             # -> batch_size*num_heads, num_queries, num_points, 2
-             sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-             # batch_size*num_heads, hidden_dim, num_queries, num_points
-             sampling_value_l_ = nn.functional.grid_sample(
-                 value_l_,
-                 sampling_grid_l_,
-                 mode="bilinear",
-                 padding_mode="zeros",
-                 align_corners=False,
-             )
-             sampling_value_list.append(sampling_value_l_)
-         # (batch_size, num_queries, num_heads, num_levels, num_points)
-         # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-         # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
-         attention_weights = attention_weights.transpose(1, 2).reshape(
-             batch_size * num_heads, 1, num_queries, num_levels * num_points
-         )
-         output = (
-             (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-             .sum(-1)
-             .view(batch_size, num_heads * hidden_dim, num_queries)
-         )
-         return output.transpose(1, 2).contiguous()
-
-
  @dataclass
  @auto_docstring(
      custom_intro="""
@@ -274,19 +214,23 @@ class RTDetrObjectDetectionOutput(ModelOutput):
      denoising_meta_values: dict | None = None


- def _get_clones(partial_module, N):
-     return nn.ModuleList([partial_module() for i in range(N)])
-
+ class RTDetrMLP(nn.Module):
+     def __init__(self, config: RTDetrConfig, hidden_size: int, intermediate_size: int, activation_function: str):
+         super().__init__()
+         self.fc1 = nn.Linear(hidden_size, intermediate_size)
+         self.fc2 = nn.Linear(intermediate_size, hidden_size)
+         self.activation_fn = ACT2FN[activation_function]
+         self.activation_dropout = config.activation_dropout
+         self.dropout = config.dropout

- # Copied from transformers.models.conditional_detr.modeling_conditional_detr.inverse_sigmoid
- def inverse_sigmoid(x, eps=1e-5):
-     x = x.clamp(min=0, max=1)
-     x1 = x.clamp(min=eps)
-     x2 = (1 - x).clamp(min=eps)
-     return torch.log(x1 / x2)
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+         return hidden_states


- # Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->RTDetr
  class RTDetrFrozenBatchNorm2d(nn.Module):
      """
      BatchNorm2d where the batch statistics and the affine parameters are fixed.
@@ -326,152 +270,123 @@ class RTDetrFrozenBatchNorm2d(nn.Module):
          return x * scale + bias


- # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->RTDetr
- def replace_batch_norm(model):
-     r"""
-     Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrFrozenBatchNorm2d`.
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor | None,
+     scaling: float | None = None,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     if scaling is None:
+         scaling = query.size(-1) ** -0.5

-     Args:
-         model (torch.nn.Module):
-             input model
-     """
-     for name, module in model.named_children():
-         if isinstance(module, nn.BatchNorm2d):
-             new_module = RTDetrFrozenBatchNorm2d(module.num_features)
+     # Take the dot product between "query" and "key" to get the raw attention scores.
+     attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-             if module.weight.device != torch.device("meta"):
-                 new_module.weight.copy_(module.weight)
-                 new_module.bias.copy_(module.bias)
-                 new_module.running_mean.copy_(module.running_mean)
-                 new_module.running_var.copy_(module.running_var)
+     if attention_mask is not None:
+         attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+         attn_weights = attn_weights + attention_mask

-             model._modules[name] = new_module
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

-         if len(list(module.children())) > 0:
-             replace_batch_norm(module)
+     attn_output = torch.matmul(attn_weights, value)
+     attn_output = attn_output.transpose(1, 2).contiguous()

+     return attn_output, attn_weights

- def get_contrastive_denoising_training_group(
-     targets,
-     num_classes,
-     num_queries,
-     class_embed,
-     num_denoising_queries=100,
-     label_noise_ratio=0.5,
-     box_noise_scale=1.0,
- ):
+
+ class RTDetrSelfAttention(nn.Module):
      """
-     Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.
+     Multi-headed self-attention from 'Attention Is All You Need' paper.

-     Args:
-         targets (`list[dict]`):
-             The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
-         num_classes (`int`):
-             Total number of classes in the dataset.
-         num_queries (`int`):
-             Number of query slots in the transformer.
-         class_embed (`callable`):
-             A function or a model layer to embed class labels.
-         num_denoising_queries (`int`, *optional*, defaults to 100):
-             Number of denoising queries.
-         label_noise_ratio (`float`, *optional*, defaults to 0.5):
-             Ratio of noise applied to labels.
-         box_noise_scale (`float`, *optional*, defaults to 1.0):
-             Scale of noise applied to bounding boxes.
-     Returns:
-         `tuple` comprising various elements:
-         - **input_query_class** (`torch.FloatTensor`) --
-           Class queries with applied label noise.
-         - **input_query_bbox** (`torch.FloatTensor`) --
-           Bounding box queries with applied box noise.
-         - **attn_mask** (`torch.FloatTensor`) --
-           Attention mask for separating denoising and reconstruction queries.
-         - **denoising_meta_values** (`dict`) --
-           Metadata including denoising positive indices, number of groups, and split sizes.
+     In RT_DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
      """

-     if num_denoising_queries <= 0:
-         return None, None, None, None
+     def __init__(
+         self,
+         config: RTDetrConfig,
+         hidden_size: int,
+         num_attention_heads: int,
+         dropout: float = 0.0,
+         bias: bool = True,
+     ):
+         super().__init__()
+         self.config = config
+         self.head_dim = hidden_size // num_attention_heads
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = dropout
+         self.is_causal = False

-     num_ground_truths = [len(t["class_labels"]) for t in targets]
-     device = targets[0]["class_labels"].device
+         self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+         self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+         self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+         self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

-     max_gt_num = max(num_ground_truths)
-     if max_gt_num == 0:
-         return None, None, None, None
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         position_embeddings: torch.Tensor | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Position embeddings are added to both queries and keys (but not values).
+         """
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)

-     num_groups_denoising_queries = num_denoising_queries // max_gt_num
-     num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
-     # pad gt to max_num of a batch
-     batch_size = len(num_ground_truths)
+         query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states

-     input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
-     input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
-     pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
+         query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+         key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

-     for i in range(batch_size):
-         num_gt = num_ground_truths[i]
-         if num_gt > 0:
-             input_query_class[i, :num_gt] = targets[i]["class_labels"]
-             input_query_bbox[i, :num_gt] = targets[i]["boxes"]
-             pad_gt_mask[i, :num_gt] = 1
-     # each group has positive and negative queries.
-     input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
-     input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
-     pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
-     # positive and negative mask
-     negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
-     negative_gt_mask[:, max_gt_num:] = 1
-     negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
-     positive_gt_mask = 1 - negative_gt_mask
-     # contrastive denoising training positive index
-     positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
-     denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
-     denoise_positive_idx = torch.split(
-         denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
-     )
-     # total denoising queries
-     num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
+         attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+             self.config._attn_implementation, eager_attention_forward
+         )

-     if label_noise_ratio > 0:
-         mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
-         # randomly put a new one here
-         new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
-         input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
+         attn_output, attn_weights = attention_interface(
+             self,
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             dropout=0.0 if not self.training else self.attention_dropout,
+             scaling=self.scaling,
+             **kwargs,
+         )

-     if box_noise_scale > 0:
-         known_bbox = center_to_corners_format(input_query_bbox)
-         diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
-         rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
-         rand_part = torch.rand_like(input_query_bbox)
-         rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
-         rand_part *= rand_sign
-         known_bbox += rand_part * diff
-         known_bbox.clip_(min=0.0, max=1.0)
-         input_query_bbox = corners_to_center_format(known_bbox)
-         input_query_bbox = inverse_sigmoid(input_query_bbox)
+         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+         attn_output = self.o_proj(attn_output)
+         return attn_output, attn_weights

-     input_query_class = class_embed(input_query_class)

-     target_size = num_denoising_queries + num_queries
-     attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
-     # match query cannot see the reconstruction
-     attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
+ def replace_batch_norm(model):
+     r"""
+     Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrFrozenBatchNorm2d`.

-     # reconstructions cannot see each other
-     for i in range(num_groups_denoising_queries):
-         idx_block_start = max_gt_num * 2 * i
-         idx_block_end = max_gt_num * 2 * (i + 1)
-         attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
-         attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
+     Args:
+         model (torch.nn.Module):
+             input model
+     """
+     for name, module in model.named_children():
+         if isinstance(module, nn.BatchNorm2d):
+             new_module = RTDetrFrozenBatchNorm2d(module.num_features)

-     denoising_meta_values = {
-         "dn_positive_idx": denoise_positive_idx,
-         "dn_num_group": num_groups_denoising_queries,
-         "dn_num_split": [num_denoising_queries, num_queries],
-     }
+             if module.weight.device != torch.device("meta"):
+                 new_module.weight.copy_(module.weight)
+                 new_module.bias.copy_(module.bias)
+                 new_module.running_mean.copy_(module.running_mean)
+                 new_module.running_var.copy_(module.running_var)

-     return input_query_class, input_query_bbox, attn_mask, denoising_meta_values
+             model._modules[name] = new_module
+
+         if len(list(module.children())) > 0:
+             replace_batch_norm(module)


  class RTDetrConvEncoder(nn.Module):
@@ -531,50 +446,46 @@ class RTDetrEncoderLayer(nn.Module):
      def __init__(self, config: RTDetrConfig):
          super().__init__()
          self.normalize_before = config.normalize_before
+         self.hidden_size = config.encoder_hidden_dim

          # self-attention
-         self.self_attn = RTDetrMultiheadAttention(
-             embed_dim=config.encoder_hidden_dim,
-             num_heads=config.num_attention_heads,
+         self.self_attn = RTDetrSelfAttention(
+             config=config,
+             hidden_size=self.hidden_size,
+             num_attention_heads=config.num_attention_heads,
              dropout=config.dropout,
          )
-         self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
+         self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
          self.dropout = config.dropout
-         self.activation_fn = ACT2FN[config.encoder_activation_function]
-         self.activation_dropout = config.activation_dropout
-         self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
-         self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
-         self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
+         self.mlp = RTDetrMLP(config, self.hidden_size, config.encoder_ffn_dim, config.encoder_activation_function)
+         self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor,
-         position_embeddings: torch.Tensor | None = None,
-         output_attentions: bool = False,
-         **kwargs,
-     ):
+         spatial_position_embeddings: torch.Tensor | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> torch.Tensor:
          """
          Args:
-             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
              attention_mask (`torch.FloatTensor`): attention mask of size
                  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                  values.
-             position_embeddings (`torch.FloatTensor`, *optional*):
-                 Object queries (also called content embeddings), to be added to the hidden states.
-             output_attentions (`bool`, *optional*):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
+             spatial_position_embeddings (`torch.FloatTensor`, *optional*):
+                 Spatial position embeddings (2D positional encodings of image locations), to be added to both
+                 the queries and keys in self-attention (but not to values).
          """
          residual = hidden_states
          if self.normalize_before:
              hidden_states = self.self_attn_layer_norm(hidden_states)

-         hidden_states, attn_weights = self.self_attn(
+         hidden_states, _ = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
-             position_embeddings=position_embeddings,
-             output_attentions=output_attentions,
+             position_embeddings=spatial_position_embeddings,
+             **kwargs,
          )

          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -586,12 +497,7 @@ class RTDetrEncoderLayer(nn.Module):
              hidden_states = self.final_layer_norm(hidden_states)
          residual = hidden_states

-         hidden_states = self.activation_fn(self.fc1(hidden_states))
-         hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
-         hidden_states = self.fc2(hidden_states)
-
-         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+         hidden_states = self.mlp(hidden_states)

          hidden_states = residual + hidden_states
          if not self.normalize_before:
@@ -602,12 +508,7 @@ class RTDetrEncoderLayer(nn.Module):
                  clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                  hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (attn_weights,)
-
-         return outputs
+         return hidden_states


  class RTDetrRepVggBlock(nn.Module):
@@ -658,7 +559,61 @@ class RTDetrCSPRepLayer(nn.Module):
          return self.conv3(hidden_state_1 + hidden_state_2)


- # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr
+ @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+ class MultiScaleDeformableAttention(nn.Module):
+     def forward(
+         self,
+         value: Tensor,
+         value_spatial_shapes: Tensor,
+         value_spatial_shapes_list: list[tuple],
+         level_start_index: Tensor,
+         sampling_locations: Tensor,
+         attention_weights: Tensor,
+         im2col_step: int,
+     ):
+         batch_size, _, num_heads, hidden_dim = value.shape
+         _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+         value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
+         sampling_grids = 2 * sampling_locations - 1
+         sampling_value_list = []
+         for level_id, (height, width) in enumerate(value_spatial_shapes_list):
+             # batch_size, height*width, num_heads, hidden_dim
+             # -> batch_size, height*width, num_heads*hidden_dim
+             # -> batch_size, num_heads*hidden_dim, height*width
+             # -> batch_size*num_heads, hidden_dim, height, width
+             value_l_ = (
+                 value_list[level_id]
+                 .flatten(2)
+                 .transpose(1, 2)
+                 .reshape(batch_size * num_heads, hidden_dim, height, width)
+             )
+             # batch_size, num_queries, num_heads, num_points, 2
+             # -> batch_size, num_heads, num_queries, num_points, 2
+             # -> batch_size*num_heads, num_queries, num_points, 2
+             sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+             # batch_size*num_heads, hidden_dim, num_queries, num_points
+             sampling_value_l_ = nn.functional.grid_sample(
+                 value_l_,
+                 sampling_grid_l_,
+                 mode="bilinear",
+                 padding_mode="zeros",
+                 align_corners=False,
+             )
+             sampling_value_list.append(sampling_value_l_)
+         # (batch_size, num_queries, num_heads, num_levels, num_points)
+         # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+         # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+         attention_weights = attention_weights.transpose(1, 2).reshape(
+             batch_size * num_heads, 1, num_queries, num_levels * num_points
+         )
+         output = (
+             (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+             .sum(-1)
+             .view(batch_size, num_heads * hidden_dim, num_queries)
+         )
+         return output.transpose(1, 2).contiguous()
+
+
  class RTDetrMultiscaleDeformableAttention(nn.Module):
      """
      Multiscale deformable attention as proposed in Deformable DETR.
@@ -696,9 +651,6 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):

          self.disable_custom_kernels = config.disable_custom_kernels

-     def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
-         return tensor if position_embeddings is None else tensor + position_embeddings
-
      def forward(
          self,
          hidden_states: torch.Tensor,
@@ -710,19 +662,19 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
          spatial_shapes=None,
          spatial_shapes_list=None,
          level_start_index=None,
-         output_attentions: bool = False,
-     ):
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
          # add position embeddings to the hidden states before projecting to queries and keys
          if position_embeddings is not None:
-             hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+             hidden_states = hidden_states + position_embeddings

          batch_size, num_queries, _ = hidden_states.shape
          batch_size, sequence_length, _ = encoder_hidden_states.shape
          total_elements = sum(height * width for height, width in spatial_shapes_list)
-         if total_elements != sequence_length:
-             raise ValueError(
-                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
-             )
+         torch_compilable_check(
+             total_elements == sequence_length,
+             "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
+         )

          value = self.value_proj(encoder_hidden_states)
          if attention_mask is not None:
@@ -769,235 +721,218 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
          return output, attention_weights


- class RTDetrMultiheadAttention(nn.Module):
-     """
-     Multi-headed attention from 'Attention Is All You Need' paper.
-
-     Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
-     """
-
-     def __init__(
-         self,
-         embed_dim: int,
-         num_heads: int,
-         dropout: float = 0.0,
-         bias: bool = True,
-     ):
+ class RTDetrDecoderLayer(nn.Module):
+     def __init__(self, config: RTDetrConfig):
          super().__init__()
-         self.embed_dim = embed_dim
-         self.num_heads = num_heads
-         self.dropout = dropout
-         self.head_dim = embed_dim // num_heads
-         if self.head_dim * num_heads != self.embed_dim:
-             raise ValueError(
-                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                 f" {num_heads})."
-             )
-         self.scaling = self.head_dim**-0.5
-
-         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+         self.hidden_size = config.d_model

-     def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
-         return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+         # self-attention
+         self.self_attn = RTDetrSelfAttention(
+             config=config,
+             hidden_size=self.hidden_size,
+             num_attention_heads=config.decoder_attention_heads,
+             dropout=config.attention_dropout,
+         )
+         self.dropout = config.dropout

-     def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
-         return tensor if position_embeddings is None else tensor + position_embeddings
+         self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
+         # cross-attention
+         self.encoder_attn = RTDetrMultiscaleDeformableAttention(
+             config,
+             num_heads=config.decoder_attention_heads,
+             n_points=config.decoder_n_points,
+         )
+         self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
+         # feedforward neural networks
+         self.mlp = RTDetrMLP(config, self.hidden_size, config.decoder_ffn_dim, config.decoder_activation_function)
+         self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)

      def forward(
          self,
          hidden_states: torch.Tensor,
-         attention_mask: torch.Tensor | None = None,
-         position_embeddings: torch.Tensor | None = None,
-         output_attentions: bool = False,
-     ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
-         """Input shape: Batch x Time x Channel"""
+         object_queries_position_embeddings: torch.Tensor | None = None,
+         reference_points=None,
+         spatial_shapes=None,
+         spatial_shapes_list=None,
+         level_start_index=None,
+         encoder_hidden_states: torch.Tensor | None = None,
+         encoder_attention_mask: torch.Tensor | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> torch.Tensor:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`):
+                 Input to the layer of shape `(batch, seq_len, hidden_size)`.
+             object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
+                 Position embeddings for the object query slots. These are added to both queries and keys
+                 in the self-attention layer (not values).
+             reference_points (`torch.FloatTensor`, *optional*):
+                 Reference points.
+             spatial_shapes (`torch.LongTensor`, *optional*):
+                 Spatial shapes.
+             level_start_index (`torch.LongTensor`, *optional*):
+                 Level start index.
+             encoder_hidden_states (`torch.FloatTensor`):
+                 cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
+             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+                 values.
+         """
+         residual = hidden_states

-         batch_size, target_len, embed_dim = hidden_states.size()
-         # add position embeddings to the hidden states before projecting to queries and keys
-         if position_embeddings is not None:
-             hidden_states_original = hidden_states
-             hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+         # Self Attention
+         hidden_states, _ = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=encoder_attention_mask,
+             position_embeddings=object_queries_position_embeddings,
+             **kwargs,
+         )

-         # get queries, keys and values
-         query_states = self.q_proj(hidden_states) * self.scaling
-         key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
-         value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)
+         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+         hidden_states = residual + hidden_states
+         hidden_states = self.self_attn_layer_norm(hidden_states)

-         proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
-         query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
-         key_states = key_states.view(*proj_shape)
-         value_states = value_states.view(*proj_shape)
+         residual = hidden_states

-         source_len = key_states.size(1)
+         # Cross-Attention
+         hidden_states, _ = self.encoder_attn(
+             hidden_states=hidden_states,
+             encoder_hidden_states=encoder_hidden_states,
+             position_embeddings=object_queries_position_embeddings,
+             reference_points=reference_points,
+             spatial_shapes=spatial_shapes,
+             spatial_shapes_list=spatial_shapes_list,
+             level_start_index=level_start_index,
+             **kwargs,
+         )

-         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+         hidden_states = residual + hidden_states

-         if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
-             raise ValueError(
-                 f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
-                 f" {attn_weights.size()}"
-             )
+         hidden_states = self.encoder_attn_layer_norm(hidden_states)

-         # expand attention_mask
-         if attention_mask is not None:
-             # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-             attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+         hidden_states = self.final_layer_norm(hidden_states)
+
+         return hidden_states

-         if attention_mask is not None:
-             if attention_mask.size() != (batch_size, 1, target_len, source_len):
-                 raise ValueError(
-                     f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
-                     f" {attention_mask.size()}"
-                 )
-             if attention_mask.dtype == torch.bool:
-                 attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
-                     attention_mask, -torch.inf
-                 )
-             attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
-             attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
-
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-         if output_attentions:
-             # this operation is a bit awkward, but it's required to
-             # make sure that attn_weights keeps its gradient.
-             # In order to do so, attn_weights have to reshaped
-             # twice and have to be reused in the following
-             attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
-             attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
-         else:
-             attn_weights_reshaped = None

-         attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ class RTDetrSinePositionEmbedding(nn.Module):
+     """
+     2D sinusoidal position embedding used in RT-DETR hybrid encoder.
+     """

-         attn_output = torch.bmm(attn_probs, value_states)
+     def __init__(self, embed_dim: int = 256, temperature: int = 10000):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.temperature = temperature

-         if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
-             raise ValueError(
-                 f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
-                 f" {attn_output.size()}"
-             )
+     @compile_compatible_method_lru_cache(maxsize=32)
+     def forward(
+         self,
+         width: int,
+         height: int,
+         device: torch.device | str,
+         dtype: torch.dtype,
+     ) -> torch.Tensor:
+         """
+         Generate 2D sinusoidal position embeddings.

-         attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
-         attn_output = attn_output.transpose(1, 2)
-         attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
+         Returns:
+             Position embeddings of shape (1, height*width, embed_dim)
+         """
+         grid_w = torch.arange(torch_int(width), device=device).to(dtype)
+         grid_h = torch.arange(torch_int(height), device=device).to(dtype)
+         grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
+         if self.embed_dim % 4 != 0:
+             raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
+         pos_dim = self.embed_dim // 4
+         omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
+         omega = 1.0 / (self.temperature**omega)

-         attn_output = self.out_proj(attn_output)
+         out_w = grid_w.flatten()[..., None] @ omega[None]
+         out_h = grid_h.flatten()[..., None] @ omega[None]

-         return attn_output, attn_weights_reshaped
+         return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]


- class RTDetrDecoderLayer(nn.Module):
+ class RTDetrAIFILayer(nn.Module):
+     """
+     AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder.
+     """
+
      def __init__(self, config: RTDetrConfig):
          super().__init__()
-         # self-attention
-         self.self_attn = RTDetrMultiheadAttention(
-             embed_dim=config.d_model,
-             num_heads=config.decoder_attention_heads,
-             dropout=config.attention_dropout,
-         )
-         self.dropout = config.dropout
-         self.activation_fn = ACT2FN[config.decoder_activation_function]
-         self.activation_dropout = config.activation_dropout
+         self.config = config
+         self.encoder_hidden_dim = config.encoder_hidden_dim
+         self.eval_size = config.eval_size

-         self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
-         # cross-attention
-         self.encoder_attn = RTDetrMultiscaleDeformableAttention(
-             config,
-             num_heads=config.decoder_attention_heads,
-             n_points=config.decoder_n_points,
+         self.position_embedding = RTDetrSinePositionEmbedding(
+             embed_dim=self.encoder_hidden_dim,
+             temperature=config.positional_encoding_temperature,
          )
-         self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
-         # feedforward neural networks
-         self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
-         self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
-         self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
+         self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])

      def forward(
-         self,
-         hidden_states: torch.Tensor,
-         position_embeddings: torch.Tensor | None = None,
-         reference_points=None,
-         spatial_shapes=None,
-         spatial_shapes_list=None,
-         level_start_index=None,
-         encoder_hidden_states: torch.Tensor | None = None,
-         encoder_attention_mask: torch.Tensor | None = None,
-         output_attentions: bool | None = False,
-     ):
+         self,
+         hidden_states: torch.Tensor,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> torch.Tensor:
          """
          Args:
-             hidden_states (`torch.FloatTensor`):
-                 Input to the layer of shape `(seq_len, batch, embed_dim)`.
-             position_embeddings (`torch.FloatTensor`, *optional*):
-                 Position embeddings that are added to the queries and keys in the self-attention layer.
-             reference_points (`torch.FloatTensor`, *optional*):
-                 Reference points.
-             spatial_shapes (`torch.LongTensor`, *optional*):
-                 Spatial shapes.
-             level_start_index (`torch.LongTensor`, *optional*):
-                 Level start index.
-             encoder_hidden_states (`torch.FloatTensor`):
-                 cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
-             encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
-                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
-                 values.
-             output_attentions (`bool`, *optional*):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
+             hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                 Feature map to process.
          """
-         residual = hidden_states
+         batch_size = hidden_states.shape[0]
+         height, width = hidden_states.shape[2:]

-         # Self Attention
-         hidden_states, self_attn_weights = self.self_attn(
-             hidden_states=hidden_states,
-             attention_mask=encoder_attention_mask,
-             position_embeddings=position_embeddings,
-             output_attentions=output_attentions,
-         )
+         hidden_states = hidden_states.flatten(2).permute(0, 2, 1)

-         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-         hidden_states = residual + hidden_states
-         hidden_states = self.self_attn_layer_norm(hidden_states)
+         if self.training or self.eval_size is None:
+             pos_embed = self.position_embedding(
+                 width=width,
+                 height=height,
+                 device=hidden_states.device,
+                 dtype=hidden_states.dtype,
+             )
+         else:
+             pos_embed = None

-         second_residual = hidden_states
+         for layer in self.layers:
+             hidden_states = layer(
+                 hidden_states,
+                 attention_mask=None,
+                 spatial_position_embeddings=pos_embed,
+                 **kwargs,
+             )

-         # Cross-Attention
-         cross_attn_weights = None
-         hidden_states, cross_attn_weights = self.encoder_attn(
-             hidden_states=hidden_states,
-             encoder_hidden_states=encoder_hidden_states,
-             position_embeddings=position_embeddings,
-             reference_points=reference_points,
-             spatial_shapes=spatial_shapes,
-             spatial_shapes_list=spatial_shapes_list,
-             level_start_index=level_start_index,
-             output_attentions=output_attentions,
+         hidden_states = (
+             hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous()
          )

-         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-         hidden_states = second_residual + hidden_states
+         return hidden_states

-         hidden_states = self.encoder_attn_layer_norm(hidden_states)

-         # Fully Connected
-         residual = hidden_states
-         hidden_states = self.activation_fn(self.fc1(hidden_states))
-         hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-         hidden_states = self.fc2(hidden_states)
-         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-         hidden_states = residual + hidden_states
-         hidden_states = self.final_layer_norm(hidden_states)
+ class RTDetrMLPPredictionHead(nn.Module):
+     """
+     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
+     height and width of a bounding box w.r.t. an image.

-         outputs = (hidden_states,)
+     """

-         if output_attentions:
-             outputs += (self_attn_weights, cross_attn_weights)
+     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+         super().__init__()
+         self.num_layers = num_layers
+         h = [hidden_dim] * (num_layers - 1)
+         self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

-         return outputs
+     def forward(self, x):
+         for i, layer in enumerate(self.layers):
+             x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+         return x


  @auto_docstring
@@ -1007,6 +942,10 @@ class RTDetrPreTrainedModel(PreTrainedModel):
      main_input_name = "pixel_values"
      input_modalities = ("image",)
      _no_split_modules = [r"RTDetrHybridEncoder", r"RTDetrDecoderLayer"]
+     _supports_sdpa = True
+     _supports_flash_attn = True
+     _supports_attention_backend = True
+     _supports_flex_attn = True

      @torch.no_grad()
      def _init_weights(self, module):
@@ -1072,35 +1011,23 @@ class RTDetrPreTrainedModel(PreTrainedModel):
              init.xavier_uniform_(module.denoising_class_embed.weight)


- class RTDetrEncoder(nn.Module):
-     def __init__(self, config: RTDetrConfig):
-         super().__init__()
-
-         self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
-
-     def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
-         hidden_states = src
-         for layer in self.layers:
-             hidden_states = layer(
-                 hidden_states,
-                 attention_mask=src_mask,
-                 position_embeddings=pos_embed,
-                 output_attentions=output_attentions,
-             )
-         return hidden_states
-
-
- class RTDetrHybridEncoder(nn.Module):
+ class RTDetrHybridEncoder(RTDetrPreTrainedModel):
      """
-     Decoder consisting of a projection layer, a set of `RTDetrEncoder`, a top-down Feature Pyramid Network
-     (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://huggingface.co/papers/2304.08069
+     Hybrid encoder consisting of AIFI (Attention-based Intra-scale Feature Interaction) layers,
+     a top-down Feature Pyramid Network (FPN) and a bottom-up Path Aggregation Network (PAN).
+     More details on the paper: https://huggingface.co/papers/2304.08069

      Args:
          config: RTDetrConfig
      """

+     _can_record_outputs = {
+         "hidden_states": RTDetrAIFILayer,
+         "attentions": RTDetrSelfAttention,
+     }
+
      def __init__(self, config: RTDetrConfig):
-         super().__init__()
+         super().__init__(config)
          self.config = config
          self.in_channels = config.encoder_in_channels
          self.feat_strides = config.feat_strides
@@ -1112,10 +1039,9 @@ class RTDetrHybridEncoder(nn.Module):
         self.out_strides = self.feat_strides
         self.num_fpn_stages = len(self.in_channels) - 1
         self.num_pan_stages = len(self.in_channels) - 1
-        activation = config.activation_function

-        # encoder transformer
-        self.encoder = nn.ModuleList([RTDetrEncoder(config) for _ in range(len(self.encode_proj_layers))])
+        # AIFI (Attention-based Intra-scale Feature Interaction) layers
+        self.aifi = nn.ModuleList([RTDetrAIFILayer(config) for _ in range(len(self.encode_proj_layers))])

         # top-down FPN
         self.lateral_convs = nn.ModuleList()
@@ -1127,7 +1053,7 @@ class RTDetrHybridEncoder(nn.Module):
                 out_channels=self.encoder_hidden_dim,
                 kernel_size=1,
                 stride=1,
-                activation=activation,
+                activation=config.activation_function,
             )
             fpn_block = RTDetrCSPRepLayer(config)
             self.lateral_convs.append(lateral_conv)
@@ -1143,118 +1069,36 @@ class RTDetrHybridEncoder(nn.Module):
                 out_channels=self.encoder_hidden_dim,
                 kernel_size=3,
                 stride=2,
-                activation=activation,
+                activation=config.activation_function,
             )
             pan_block = RTDetrCSPRepLayer(config)
             self.downsample_convs.append(downsample_conv)
             self.pan_blocks.append(pan_block)

-    @staticmethod
-    def build_2d_sincos_position_embedding(
-        width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
-    ):
-        grid_w = torch.arange(torch_int(width), device=device).to(dtype)
-        grid_h = torch.arange(torch_int(height), device=device).to(dtype)
-        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
-        if embed_dim % 4 != 0:
-            raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
-        pos_dim = embed_dim // 4
-        omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
-        omega = 1.0 / (temperature**omega)
-
-        out_w = grid_w.flatten()[..., None] @ omega[None]
-        out_h = grid_h.flatten()[..., None] @ omega[None]
-
-        return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
+        self.post_init()

+    @check_model_inputs(tie_last_hidden_states=False)
     def forward(
         self,
         inputs_embeds=None,
-        attention_mask=None,
-        position_embeddings=None,
-        spatial_shapes=None,
-        level_start_index=None,
-        valid_ratios=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
-            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
-                - 1 for pixel features that are real (i.e. **not masked**),
-                - 0 for pixel features that are padding (i.e. **masked**).
-                [What are attention masks?](../glossary#attention-mask)
-            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
-                Position embeddings that are added to the queries and keys in each self-attention layer.
-            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
-                Spatial shapes of each feature map.
-            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
-                Starting index of each feature map.
-            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
-                Ratio of valid area in each feature level.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        hidden_states = inputs_embeds
+        feature_maps = inputs_embeds

-        encoder_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
-
-        # encoder
+        # AIFI: Apply transformer encoder to specified feature levels
         if self.config.encoder_layers > 0:
             for i, enc_ind in enumerate(self.encode_proj_layers):
-                if output_hidden_states:
-                    encoder_states = encoder_states + (hidden_states[enc_ind],)
-                height, width = hidden_states[enc_ind].shape[2:]
-                # flatten [batch, channel, height, width] to [batch, height*width, channel]
-                src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
-                if self.training or self.eval_size is None:
-                    pos_embed = self.build_2d_sincos_position_embedding(
-                        width,
-                        height,
-                        self.encoder_hidden_dim,
-                        self.positional_encoding_temperature,
-                        device=src_flatten.device,
-                        dtype=src_flatten.dtype,
-                    )
-                else:
-                    pos_embed = None
-
-                layer_outputs = self.encoder[i](
-                    src_flatten,
-                    pos_embed=pos_embed,
-                    output_attentions=output_attentions,
-                )
-                hidden_states[enc_ind] = (
-                    layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
-                )
-
-                if output_attentions:
-                    all_attentions = all_attentions + (layer_outputs[1],)
-
-                if output_hidden_states:
-                    encoder_states = encoder_states + (hidden_states[enc_ind],)
+                feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs)

         # top-down FPN
-        fpn_feature_maps = [hidden_states[-1]]
+        fpn_feature_maps = [feature_maps[-1]]
         for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
-            backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
+            backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1]
             top_fpn_feature_map = fpn_feature_maps[-1]
             # apply lateral block
             top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
@@ -1277,20 +1121,29 @@ class RTDetrHybridEncoder(nn.Module):
             new_pan_feature_map = pan_block(fused_feature_map)
             pan_feature_maps.append(new_pan_feature_map)

-        if not return_dict:
-            return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
-        )
+        return BaseModelOutput(last_hidden_state=pan_feature_maps)
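Editorial aside, not part of the diff: the top-down FPN loop above follows the usual upsample, concatenate, fuse pattern. The sketch below shows that pattern with plain tensor ops; channel width, spatial sizes, nearest-neighbor upsampling and the single shared 1x1 fusion conv are assumptions for illustration (the real module fuses with per-stage `RTDetrCSPRepLayer` blocks).

import torch
from torch import nn
import torch.nn.functional as F

# Three projected feature maps at strides 8/16/32, coarsest last (shapes are example values).
feature_maps = [torch.randn(1, 256, 80, 80), torch.randn(1, 256, 40, 40), torch.randn(1, 256, 20, 20)]
fuse = nn.Conv2d(512, 256, kernel_size=1)  # stand-in for the CSP fusion block

# Top-down pass: upsample the running coarse map, concatenate with the next finer map, fuse.
fpn_feature_maps = [feature_maps[-1]]
for level in range(len(feature_maps) - 2, -1, -1):
    upsampled = F.interpolate(fpn_feature_maps[-1], scale_factor=2.0, mode="nearest")
    fpn_feature_maps.append(fuse(torch.cat([upsampled, feature_maps[level]], dim=1)))

print([tuple(f.shape) for f in fpn_feature_maps])  # [(1, 256, 20, 20), (1, 256, 40, 40), (1, 256, 80, 80)]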
+
+
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1 / x2)
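Editorial note, not part of the diff: `inverse_sigmoid` is the clamped logit function, log(x / (1 - x)), so away from the clamping boundaries it exactly undoes `torch.sigmoid`. A quick numeric check:

import torch

x = torch.tensor([0.25, 0.5, 0.9])
logits = torch.log(x / (1 - x))  # what inverse_sigmoid computes when no clamping kicks in
assert torch.allclose(torch.sigmoid(logits), x)
# e.g. inverse_sigmoid(0.25) = log(0.25 / 0.75) ≈ -1.0986; eps only matters for inputs at 0 or 1.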


 class RTDetrDecoder(RTDetrPreTrainedModel):
+    _can_record_outputs = {
+        "hidden_states": RTDetrDecoderLayer,
+        "attentions": RTDetrSelfAttention,
+        "cross_attentions": RTDetrMultiscaleDeformableAttention,
+    }
+
     def __init__(self, config: RTDetrConfig):
         super().__init__(config)

         self.dropout = config.dropout
         self.layers = nn.ModuleList([RTDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.query_pos_head = RTDetrMLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2)
+        self.query_pos_head = RTDetrMLPPredictionHead(4, 2 * config.d_model, config.d_model, num_layers=2)

         # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
         self.bbox_embed = None
@@ -1299,21 +1152,17 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    @check_model_inputs()
     def forward(
         self,
         inputs_embeds=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
-        position_embeddings=None,
         reference_points=None,
         spatial_shapes=None,
         spatial_shapes_list=None,
         level_start_index=None,
-        valid_ratios=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ):
         r"""
         Args:
@@ -1327,39 +1176,17 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
                 in `[0, 1]`:
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).
-            position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
-                Position embeddings that are added to the queries and keys in each self-attention layer.
             reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
                 Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
             spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
                 Spatial shapes of the feature maps.
             level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
                 Indexes for the start of each feature level. In range `[0, sequence_length]`.
-            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
-                Ratio of valid area in each feature level.
-
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if inputs_embeds is not None:
             hidden_states = inputs_embeds

         # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
         intermediate = ()
         intermediate_reference_points = ()
         intermediate_logits = ()
@@ -1369,25 +1196,20 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L252
         for idx, decoder_layer in enumerate(self.layers):
             reference_points_input = reference_points.unsqueeze(2)
-            position_embeddings = self.query_pos_head(reference_points)
-
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
+            object_queries_position_embeddings = self.query_pos_head(reference_points)

-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
-                position_embeddings=position_embeddings,
+                object_queries_position_embeddings=object_queries_position_embeddings,
                 encoder_hidden_states=encoder_hidden_states,
                 reference_points=reference_points_input,
                 spatial_shapes=spatial_shapes,
                 spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 encoder_attention_mask=encoder_attention_mask,
-                output_attentions=output_attentions,
+                **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
             # hack implementation for iterative bounding box refinement
             if self.bbox_embed is not None:
                 predicted_corners = self.bbox_embed[idx](hidden_states)
@@ -1403,68 +1225,141 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
                     logits = self.class_embed[idx](hidden_states)
                     intermediate_logits += (logits,)

-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-                if encoder_hidden_states is not None:
-                    all_cross_attentions += (layer_outputs[2],)
-
         # Keep batch_size as first dimension
         intermediate = torch.stack(intermediate, dim=1)
         intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
         if self.class_embed is not None:
             intermediate_logits = torch.stack(intermediate_logits, dim=1)

-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [
-                    hidden_states,
-                    intermediate,
-                    intermediate_logits,
-                    intermediate_reference_points,
-                    all_hidden_states,
-                    all_self_attns,
-                    all_cross_attentions,
-                ]
-                if v is not None
-            )
         return RTDetrDecoderOutput(
             last_hidden_state=hidden_states,
             intermediate_hidden_states=intermediate,
             intermediate_logits=intermediate_logits,
             intermediate_reference_points=intermediate_reference_points,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
         )


-# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-class RTDetrMLPPredictionHead(nn.Module):
+def get_contrastive_denoising_training_group(
+    targets,
+    num_classes,
+    num_queries,
+    class_embed,
+    num_denoising_queries=100,
+    label_noise_ratio=0.5,
+    box_noise_scale=1.0,
+):
     """
-    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
-    height and width of a bounding box w.r.t. an image.
-
-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-    Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_paddle/ppdet/modeling/transformers/utils.py#L453
+    Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.

+    Args:
+        targets (`list[dict]`):
+            The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
+        num_classes (`int`):
+            Total number of classes in the dataset.
+        num_queries (`int`):
+            Number of query slots in the transformer.
+        class_embed (`callable`):
+            A function or a model layer to embed class labels.
+        num_denoising_queries (`int`, *optional*, defaults to 100):
+            Number of denoising queries.
+        label_noise_ratio (`float`, *optional*, defaults to 0.5):
+            Ratio of noise applied to labels.
+        box_noise_scale (`float`, *optional*, defaults to 1.0):
+            Scale of noise applied to bounding boxes.
+    Returns:
+        `tuple` comprising various elements:
+        - **input_query_class** (`torch.FloatTensor`) --
+          Class queries with applied label noise.
+        - **input_query_bbox** (`torch.FloatTensor`) --
+          Bounding box queries with applied box noise.
+        - **attn_mask** (`torch.FloatTensor`) --
+          Attention mask for separating denoising and reconstruction queries.
+        - **denoising_meta_values** (`dict`) --
+          Metadata including denoising positive indices, number of groups, and split sizes.
     """

-    def __init__(self, config, input_dim, d_model, output_dim, num_layers):
-        super().__init__()
-        self.num_layers = num_layers
-        h = [d_model] * (num_layers - 1)
-        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+    if num_denoising_queries <= 0:
+        return None, None, None, None

-    def forward(self, x):
-        for i, layer in enumerate(self.layers):
-            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
-        return x
+    num_ground_truths = [len(t["class_labels"]) for t in targets]
+    device = targets[0]["class_labels"].device
+
+    max_gt_num = max(num_ground_truths)
+    if max_gt_num == 0:
+        return None, None, None, None
+
+    num_groups_denoising_queries = num_denoising_queries // max_gt_num
+    num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
+    # pad gt to max_num of a batch
+    batch_size = len(num_ground_truths)
+
+    input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
+    input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
+    pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
+
+    for i in range(batch_size):
+        num_gt = num_ground_truths[i]
+        if num_gt > 0:
+            input_query_class[i, :num_gt] = targets[i]["class_labels"]
+            input_query_bbox[i, :num_gt] = targets[i]["boxes"]
+            pad_gt_mask[i, :num_gt] = 1
+    # each group has positive and negative queries.
+    input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
+    input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
+    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
+    # positive and negative mask
+    negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
+    negative_gt_mask[:, max_gt_num:] = 1
+    negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
+    positive_gt_mask = 1 - negative_gt_mask
+    # contrastive denoising training positive index
+    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
+    denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
+    denoise_positive_idx = torch.split(
+        denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
+    )
+    # total denoising queries
+    num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
+
+    if label_noise_ratio > 0:
+        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
+        # randomly put a new one here
+        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
+        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
+
+    if box_noise_scale > 0:
+        known_bbox = center_to_corners_format(input_query_bbox)
+        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
+        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
+        rand_part = torch.rand_like(input_query_bbox)
+        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
+        rand_part *= rand_sign
+        known_bbox += rand_part * diff
+        known_bbox.clip_(min=0.0, max=1.0)
+        input_query_bbox = corners_to_center_format(known_bbox)
+        input_query_bbox = inverse_sigmoid(input_query_bbox)
+
+    input_query_class = class_embed(input_query_class)
+
+    target_size = num_denoising_queries + num_queries
+    attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
+    # match query cannot see the reconstruction
+    attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
+
+    # reconstructions cannot see each other
+    for i in range(num_groups_denoising_queries):
+        idx_block_start = max_gt_num * 2 * i
+        idx_block_end = max_gt_num * 2 * (i + 1)
+        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
+        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
+
+    denoising_meta_values = {
+        "dn_positive_idx": denoise_positive_idx,
+        "dn_num_group": num_groups_denoising_queries,
+        "dn_num_split": [num_denoising_queries, num_queries],
+    }
+
+    return input_query_class, input_query_bbox, attn_mask, denoising_meta_values


 @auto_docstring(
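Editorial sketch, not part of the diff: a minimal call of the helper above with toy targets, assuming it is importable from `transformers.models.rt_detr.modeling_rt_detr` (the module this diff appears to patch). Class count, hidden size and box values are made up for illustration.

import torch
from torch import nn
from transformers.models.rt_detr.modeling_rt_detr import get_contrastive_denoising_training_group

targets = [
    {"class_labels": torch.tensor([1, 7]), "boxes": torch.rand(2, 4)},  # image with 2 boxes (cx, cy, w, h in [0, 1])
    {"class_labels": torch.tensor([3]), "boxes": torch.rand(1, 4)},     # image with 1 box
]
num_classes, num_queries = 80, 300
class_embed = nn.Embedding(num_classes + 1, 256)  # extra row covers the padding / "no object" index

input_query_class, input_query_bbox, attn_mask, meta = get_contrastive_denoising_training_group(
    targets, num_classes, num_queries, class_embed, num_denoising_queries=100
)
# 100 // max_gt_num(=2) = 50 groups; each group holds a positive and a negative copy of the 2 padded
# GT slots, so 50 * 2 * 2 = 200 denoising queries sit ahead of the 300 regular queries, and attn_mask
# keeps the regular queries from attending to them while isolating each denoising group.
print(input_query_class.shape, input_query_bbox.shape, attn_mask.shape, meta["dn_num_group"])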
@@ -1484,8 +1379,8 @@ class RTDetrModel(RTDetrPreTrainedModel):
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L212
         num_backbone_outs = len(intermediate_channel_sizes)
         encoder_input_proj_list = []
-        for _ in range(num_backbone_outs):
-            in_channels = intermediate_channel_sizes[_]
+        for i in range(num_backbone_outs):
+            in_channels = intermediate_channel_sizes[i]
             encoder_input_proj_list.append(
                 nn.Sequential(
                     nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
@@ -1513,7 +1408,7 @@ class RTDetrModel(RTDetrPreTrainedModel):
             nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
         )
         self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
-        self.enc_bbox_head = RTDetrMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3)
+        self.enc_bbox_head = RTDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)

         # init encoder output anchors and valid_mask
         if config.anchor_image_size:
@@ -1523,8 +1418,8 @@ class RTDetrModel(RTDetrPreTrainedModel):
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
         num_backbone_outs = len(config.decoder_in_channels)
         decoder_input_proj_list = []
-        for _ in range(num_backbone_outs):
-            in_channels = config.decoder_in_channels[_]
+        for i in range(num_backbone_outs):
+            in_channels = config.decoder_in_channels[i]
             decoder_input_proj_list.append(
                 nn.Sequential(
                     nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
@@ -1584,26 +1479,20 @@ class RTDetrModel(RTDetrPreTrainedModel):
         return anchors, valid_mask

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
         pixel_mask: torch.LongTensor | None = None,
         encoder_outputs: torch.FloatTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
-        decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | RTDetrModelOutput:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
             can choose to directly pass a flattened representation of an image.
-        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
-            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
             following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
@@ -1631,53 +1520,46 @@ class RTDetrModel(RTDetrPreTrainedModel):
         >>> list(last_hidden_states.shape)
         [1, 300, 256]
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        batch_size, num_channels, height, width = pixel_values.shape
-        device = pixel_values.device
-
-        if pixel_mask is None:
-            pixel_mask = torch.ones(((batch_size, height, width)), device=device)
-
-        features = self.backbone(pixel_values, pixel_mask)
-
-        proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+        if pixel_values is None and inputs_embeds is None:
+            raise ValueError("You have to specify either pixel_values or inputs_embeds")
+
+        if inputs_embeds is None:
+            batch_size, num_channels, height, width = pixel_values.shape
+            device = pixel_values.device
+            if pixel_mask is None:
+                pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+            features = self.backbone(pixel_values, pixel_mask)
+            proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+        else:
+            batch_size = inputs_embeds.shape[0]
+            device = inputs_embeds.device
+            proj_feats = inputs_embeds

         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 proj_feats,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
+                **kwargs,
             )
-        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+        elif not isinstance(encoder_outputs, BaseModelOutput):
             encoder_outputs = BaseModelOutput(
                 last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if output_hidden_states else None,
-                attentions=encoder_outputs[2]
-                if len(encoder_outputs) > 2
-                else encoder_outputs[1]
-                if output_attentions
-                else None,
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
             )

         # Equivalent to def _get_encoder_input
         # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
         sources = []
-        for level, source in enumerate(encoder_outputs[0]):
+        for level, source in enumerate(encoder_outputs.last_hidden_state):
             sources.append(self.decoder_input_proj[level](source))

         # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
         if self.config.num_feature_levels > len(sources):
             _len_sources = len(sources)
-            sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1])
+            sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1])
             for i in range(_len_sources + 1, self.config.num_feature_levels):
-                sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))
+                sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1]))

         # Prepare encoder inputs (by flattening)
         source_flatten = []
@@ -1769,22 +1651,9 @@ class RTDetrModel(RTDetrPreTrainedModel):
             spatial_shapes=spatial_shapes,
             spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        if not return_dict:
-            enc_outputs = tuple(
-                value
-                for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
-                if value is not None
-            )
-            dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
-            tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
-
-            return tuple_outputs
-
         return RTDetrModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
@@ -1826,7 +1695,7 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
             [torch.nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]
         )
         self.model.decoder.bbox_embed = nn.ModuleList(
-            [RTDetrMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3) for _ in range(num_pred)]
+            [RTDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) for _ in range(num_pred)]
         )
         # if two-stage, the last class_embed and bbox_embed is for region proposal generation
         self.post_init()
@@ -1835,26 +1704,20 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
         return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]

     @auto_docstring
+    @can_return_tuple
     def forward(
         self,
         pixel_values: torch.FloatTensor,
         pixel_mask: torch.LongTensor | None = None,
         encoder_outputs: torch.FloatTensor | None = None,
         inputs_embeds: torch.FloatTensor | None = None,
-        decoder_inputs_embeds: torch.FloatTensor | None = None,
         labels: list[dict] | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor] | RTDetrObjectDetectionOutput:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
             can choose to directly pass a flattened representation of an image.
-        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
-            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
             following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
@@ -1907,40 +1770,29 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
         Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48]
         Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99]
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             pixel_values,
             pixel_mask=pixel_mask,
             encoder_outputs=encoder_outputs,
             inputs_embeds=inputs_embeds,
-            decoder_inputs_embeds=decoder_inputs_embeds,
             labels=labels,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            **kwargs,
         )

-        denoising_meta_values = (
-            outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
-        )
+        denoising_meta_values = outputs.denoising_meta_values if self.training else None

-        outputs_class = outputs.intermediate_logits if return_dict else outputs[2]
-        outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3]
-        predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4]
-        initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5]
+        outputs_class = outputs.intermediate_logits
+        outputs_coord = outputs.intermediate_reference_points
+        predicted_corners = outputs.intermediate_predicted_corners
+        initial_reference_points = outputs.initial_reference_points

         logits = outputs_class[:, -1]
         pred_boxes = outputs_coord[:, -1]

         loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
         if labels is not None:
-            enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
-            enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
+            enc_topk_logits = outputs.enc_topk_logits
+            enc_topk_bboxes = outputs.enc_topk_bboxes
             loss, loss_dict, auxiliary_outputs = self.loss_function(
                 logits,
                 labels,
@@ -1957,13 +1809,6 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
                 **kwargs,
             )

-        if not return_dict:
-            if auxiliary_outputs is not None:
-                output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
-            else:
-                output = (logits, pred_boxes) + outputs
-            return ((loss, loss_dict) + output) if loss is not None else output
-
         return RTDetrObjectDetectionOutput(
             loss=loss,
             loss_dict=loss_dict,
@@ -1991,8 +1836,4 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
         )


-__all__ = [
-    "RTDetrForObjectDetection",
-    "RTDetrModel",
-    "RTDetrPreTrainedModel",
-]
+__all__ = ["RTDetrForObjectDetection", "RTDetrModel", "RTDetrPreTrainedModel"]