transformers 5.0.0rc2-py3-none-any.whl → 5.1.0-py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
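For reference, a file-level listing like the one below can be approximated locally. The following is a minimal sketch, assuming both wheels have already been downloaded into the working directory (for example with "pip download transformers==5.0.0rc2 --no-deps" and "pip download transformers==5.1.0 --no-deps"); the wheel paths and the size-based "modified" check are assumptions, the latter being a cheap stand-in for a real per-file content diff.

# Hypothetical sketch: compare the contents of two locally downloaded wheels
# to reproduce a file-level "Files changed" listing.
import zipfile

OLD_WHEEL = "transformers-5.0.0rc2-py3-none-any.whl"  # assumed local path
NEW_WHEEL = "transformers-5.1.0-py3-none-any.whl"     # assumed local path

def wheel_files(path: str) -> dict[str, int]:
    """Map each file inside the wheel archive to its uncompressed size in bytes."""
    with zipfile.ZipFile(path) as wheel:
        return {info.filename: info.file_size for info in wheel.infolist()}

old_files = wheel_files(OLD_WHEEL)
new_files = wheel_files(NEW_WHEEL)

added = sorted(set(new_files) - set(old_files))
removed = sorted(set(old_files) - set(new_files))
changed = sorted(
    name for name in set(old_files) & set(new_files)
    if old_files[name] != new_files[name]  # size change as a proxy for a content change
)

print(f"Files changed ({len(added) + len(removed) + len(changed)})")
for name in added:
    print(f"  added    {name}")
for name in removed:
    print(f"  removed  {name}")
for name in changed:
    print(f"  modified {name}")
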
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,9 @@
- # coding=utf-8
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/rt_detr/modular_rt_detr.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_rt_detr.py file directly. One of our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  # Copyright 2024 Baidu Inc and The HuggingFace Inc. team.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,12 +17,10 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """PyTorch RT-DETR model."""
-
  import math
  import warnings
+ from collections.abc import Callable
  from dataclasses import dataclass
- from typing import Optional, Union

  import torch
  import torch.nn.functional as F
@@ -25,83 +28,18 @@ from torch import Tensor, nn

  from ... import initialization as init
  from ...activations import ACT2CLS, ACT2FN
+ from ...backbone_utils import load_backbone
  from ...image_transforms import center_to_corners_format, corners_to_center_format
  from ...integrations import use_kernel_forward_from_hub
  from ...modeling_outputs import BaseModelOutput
- from ...modeling_utils import PreTrainedModel
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from ...processing_utils import Unpack
  from ...pytorch_utils import compile_compatible_method_lru_cache
- from ...utils import (
-     ModelOutput,
-     auto_docstring,
-     logging,
-     torch_int,
- )
- from ...utils.backbone_utils import load_backbone
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int
+ from ...utils.generic import can_return_tuple, check_model_inputs
  from .configuration_rt_detr import RTDetrConfig


- logger = logging.get_logger(__name__)
-
-
- # TODO: Replace all occurrences of the checkpoint with the final one
-
-
- @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
- # Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
- class MultiScaleDeformableAttention(nn.Module):
-     def forward(
-         self,
-         value: Tensor,
-         value_spatial_shapes: Tensor,
-         value_spatial_shapes_list: list[tuple],
-         level_start_index: Tensor,
-         sampling_locations: Tensor,
-         attention_weights: Tensor,
-         im2col_step: int,
-     ):
-         batch_size, _, num_heads, hidden_dim = value.shape
-         _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-         value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
-         sampling_grids = 2 * sampling_locations - 1
-         sampling_value_list = []
-         for level_id, (height, width) in enumerate(value_spatial_shapes_list):
-             # batch_size, height*width, num_heads, hidden_dim
-             # -> batch_size, height*width, num_heads*hidden_dim
-             # -> batch_size, num_heads*hidden_dim, height*width
-             # -> batch_size*num_heads, hidden_dim, height, width
-             value_l_ = (
-                 value_list[level_id]
-                 .flatten(2)
-                 .transpose(1, 2)
-                 .reshape(batch_size * num_heads, hidden_dim, height, width)
-             )
-             # batch_size, num_queries, num_heads, num_points, 2
-             # -> batch_size, num_heads, num_queries, num_points, 2
-             # -> batch_size*num_heads, num_queries, num_points, 2
-             sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-             # batch_size*num_heads, hidden_dim, num_queries, num_points
-             sampling_value_l_ = nn.functional.grid_sample(
-                 value_l_,
-                 sampling_grid_l_,
-                 mode="bilinear",
-                 padding_mode="zeros",
-                 align_corners=False,
-             )
-             sampling_value_list.append(sampling_value_l_)
-         # (batch_size, num_queries, num_heads, num_levels, num_points)
-         # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-         # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
-         attention_weights = attention_weights.transpose(1, 2).reshape(
-             batch_size * num_heads, 1, num_queries, num_levels * num_points
-         )
-         output = (
-             (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-             .sum(-1)
-             .view(batch_size, num_heads * hidden_dim, num_queries)
-         )
-         return output.transpose(1, 2).contiguous()
-
-
  @dataclass
  @auto_docstring(
      custom_intro="""
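This hunk removes the file's hand-written MultiScaleDeformableAttention module and module-level logger, imports ALL_ATTENTION_FUNCTIONS from modeling_utils instead, and sources load_backbone from the new top-level transformers/backbone_utils.py module rather than transformers/utils/backbone_utils.py (both files appear in the list above). A minimal sketch of the dispatch pattern that the ALL_ATTENTION_FUNCTIONS registry is used for elsewhere in transformers is shown below; the eager fallback and the pick_attention helper are illustrative assumptions, not code from this diff.

import torch
from torch import nn
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def eager_attention_forward(module, query, key, value, attention_mask=None, scaling=1.0, dropout=0.0, **kwargs):
    # Plain softmax attention, used when no optimized backend (sdpa, flash attention, ...) is requested.
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights


def pick_attention(config):
    # Resolve the attention callable from the registry keyed by config._attn_implementation,
    # falling back to the eager implementation above.
    if getattr(config, "_attn_implementation", "eager") == "eager":
        return eager_attention_forward
    return ALL_ATTENTION_FUNCTIONS[config._attn_implementation]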
@@ -129,15 +67,15 @@ class RTDetrDecoderOutput(ModelOutput):
          used to compute the weighted average in the cross-attention heads.
      """

-     last_hidden_state: Optional[torch.FloatTensor] = None
-     intermediate_hidden_states: Optional[torch.FloatTensor] = None
-     intermediate_logits: Optional[torch.FloatTensor] = None
-     intermediate_reference_points: Optional[torch.FloatTensor] = None
-     intermediate_predicted_corners: Optional[torch.FloatTensor] = None
-     initial_reference_points: Optional[torch.FloatTensor] = None
-     hidden_states: Optional[tuple[torch.FloatTensor]] = None
-     attentions: Optional[tuple[torch.FloatTensor]] = None
-     cross_attentions: Optional[tuple[torch.FloatTensor]] = None
+     last_hidden_state: torch.FloatTensor | None = None
+     intermediate_hidden_states: torch.FloatTensor | None = None
+     intermediate_logits: torch.FloatTensor | None = None
+     intermediate_reference_points: torch.FloatTensor | None = None
+     intermediate_predicted_corners: torch.FloatTensor | None = None
+     initial_reference_points: torch.FloatTensor | None = None
+     hidden_states: tuple[torch.FloatTensor] | None = None
+     attentions: tuple[torch.FloatTensor] | None = None
+     cross_attentions: tuple[torch.FloatTensor] | None = None


  @dataclass
@@ -178,24 +116,24 @@ class RTDetrModelOutput(ModelOutput):
178
116
  Extra dictionary for the denoising related values.
179
117
  """
180
118
 
181
- last_hidden_state: Optional[torch.FloatTensor] = None
182
- intermediate_hidden_states: Optional[torch.FloatTensor] = None
183
- intermediate_logits: Optional[torch.FloatTensor] = None
184
- intermediate_reference_points: Optional[torch.FloatTensor] = None
185
- intermediate_predicted_corners: Optional[torch.FloatTensor] = None
186
- initial_reference_points: Optional[torch.FloatTensor] = None
187
- decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
188
- decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
189
- cross_attentions: Optional[tuple[torch.FloatTensor]] = None
190
- encoder_last_hidden_state: Optional[torch.FloatTensor] = None
191
- encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
192
- encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
193
- init_reference_points: Optional[torch.FloatTensor] = None
194
- enc_topk_logits: Optional[torch.FloatTensor] = None
195
- enc_topk_bboxes: Optional[torch.FloatTensor] = None
196
- enc_outputs_class: Optional[torch.FloatTensor] = None
197
- enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
198
- denoising_meta_values: Optional[dict] = None
119
+ last_hidden_state: torch.FloatTensor | None = None
120
+ intermediate_hidden_states: torch.FloatTensor | None = None
121
+ intermediate_logits: torch.FloatTensor | None = None
122
+ intermediate_reference_points: torch.FloatTensor | None = None
123
+ intermediate_predicted_corners: torch.FloatTensor | None = None
124
+ initial_reference_points: torch.FloatTensor | None = None
125
+ decoder_hidden_states: tuple[torch.FloatTensor] | None = None
126
+ decoder_attentions: tuple[torch.FloatTensor] | None = None
127
+ cross_attentions: tuple[torch.FloatTensor] | None = None
128
+ encoder_last_hidden_state: torch.FloatTensor | None = None
129
+ encoder_hidden_states: tuple[torch.FloatTensor] | None = None
130
+ encoder_attentions: tuple[torch.FloatTensor] | None = None
131
+ init_reference_points: torch.FloatTensor | None = None
132
+ enc_topk_logits: torch.FloatTensor | None = None
133
+ enc_topk_bboxes: torch.FloatTensor | None = None
134
+ enc_outputs_class: torch.FloatTensor | None = None
135
+ enc_outputs_coord_logits: torch.FloatTensor | None = None
136
+ denoising_meta_values: dict | None = None
199
137
 
200
138
 
201
139
  @dataclass
@@ -251,44 +189,48 @@ class RTDetrObjectDetectionOutput(ModelOutput):
251
189
  Extra dictionary for the denoising related values
252
190
  """
253
191
 
254
- loss: Optional[torch.FloatTensor] = None
255
- loss_dict: Optional[dict] = None
256
- logits: Optional[torch.FloatTensor] = None
257
- pred_boxes: Optional[torch.FloatTensor] = None
258
- auxiliary_outputs: Optional[list[dict]] = None
259
- last_hidden_state: Optional[torch.FloatTensor] = None
260
- intermediate_hidden_states: Optional[torch.FloatTensor] = None
261
- intermediate_logits: Optional[torch.FloatTensor] = None
262
- intermediate_reference_points: Optional[torch.FloatTensor] = None
263
- intermediate_predicted_corners: Optional[torch.FloatTensor] = None
264
- initial_reference_points: Optional[torch.FloatTensor] = None
265
- decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
266
- decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
267
- cross_attentions: Optional[tuple[torch.FloatTensor]] = None
268
- encoder_last_hidden_state: Optional[torch.FloatTensor] = None
269
- encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
270
- encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
271
- init_reference_points: Optional[tuple[torch.FloatTensor]] = None
272
- enc_topk_logits: Optional[torch.FloatTensor] = None
273
- enc_topk_bboxes: Optional[torch.FloatTensor] = None
274
- enc_outputs_class: Optional[torch.FloatTensor] = None
275
- enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
276
- denoising_meta_values: Optional[dict] = None
277
-
278
-
279
- def _get_clones(partial_module, N):
280
- return nn.ModuleList([partial_module() for i in range(N)])
281
-
282
-
283
- # Copied from transformers.models.conditional_detr.modeling_conditional_detr.inverse_sigmoid
284
- def inverse_sigmoid(x, eps=1e-5):
285
- x = x.clamp(min=0, max=1)
286
- x1 = x.clamp(min=eps)
287
- x2 = (1 - x).clamp(min=eps)
288
- return torch.log(x1 / x2)
192
+ loss: torch.FloatTensor | None = None
193
+ loss_dict: dict | None = None
194
+ logits: torch.FloatTensor | None = None
195
+ pred_boxes: torch.FloatTensor | None = None
196
+ auxiliary_outputs: list[dict] | None = None
197
+ last_hidden_state: torch.FloatTensor | None = None
198
+ intermediate_hidden_states: torch.FloatTensor | None = None
199
+ intermediate_logits: torch.FloatTensor | None = None
200
+ intermediate_reference_points: torch.FloatTensor | None = None
201
+ intermediate_predicted_corners: torch.FloatTensor | None = None
202
+ initial_reference_points: torch.FloatTensor | None = None
203
+ decoder_hidden_states: tuple[torch.FloatTensor] | None = None
204
+ decoder_attentions: tuple[torch.FloatTensor] | None = None
205
+ cross_attentions: tuple[torch.FloatTensor] | None = None
206
+ encoder_last_hidden_state: torch.FloatTensor | None = None
207
+ encoder_hidden_states: tuple[torch.FloatTensor] | None = None
208
+ encoder_attentions: tuple[torch.FloatTensor] | None = None
209
+ init_reference_points: tuple[torch.FloatTensor] | None = None
210
+ enc_topk_logits: torch.FloatTensor | None = None
211
+ enc_topk_bboxes: torch.FloatTensor | None = None
212
+ enc_outputs_class: torch.FloatTensor | None = None
213
+ enc_outputs_coord_logits: torch.FloatTensor | None = None
214
+ denoising_meta_values: dict | None = None
215
+
216
+
217
+ class RTDetrMLP(nn.Module):
218
+ def __init__(self, config: RTDetrConfig, hidden_size: int, intermediate_size: int, activation_function: str):
219
+ super().__init__()
220
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
221
+ self.fc2 = nn.Linear(intermediate_size, hidden_size)
222
+ self.activation_fn = ACT2FN[activation_function]
223
+ self.activation_dropout = config.activation_dropout
224
+ self.dropout = config.dropout
225
+
226
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
227
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
228
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
229
+ hidden_states = self.fc2(hidden_states)
230
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
231
+ return hidden_states
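# Illustrative sketch (not part of the diff): the feed-forward pattern RTDetrMLP factors out of the
# encoder/decoder layers. Names below are hypothetical stand-ins, not the library's API.
import torch
from torch import nn

class TinyMLP(nn.Module):
    def __init__(self, hidden_size=256, intermediate_size=1024, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.act = nn.ReLU()
        self.dropout = dropout

    def forward(self, x):
        # expand -> activate -> (dropout) -> project back -> (dropout)
        x = self.act(self.fc1(x))
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.fc2(x)
        return nn.functional.dropout(x, p=self.dropout, training=self.training)

x = torch.randn(2, 10, 256)
assert TinyMLP()(x).shape == x.shape  # the MLP preserves the hidden size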
289
232
 
290
233
 
291
- # Copied from transformers.models.detr.modeling_detr.DetrFrozenBatchNorm2d with Detr->RTDetr
292
234
  class RTDetrFrozenBatchNorm2d(nn.Module):
293
235
  """
294
236
  BatchNorm2d where the batch statistics and the affine parameters are fixed.
@@ -328,152 +270,123 @@ class RTDetrFrozenBatchNorm2d(nn.Module):
328
270
  return x * scale + bias
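# Illustrative sketch (not part of the diff): what "frozen" batch norm computes. The running
# statistics are treated as constants and folded into a per-channel affine transform, so the op
# matches eval-mode BatchNorm2d.
import torch

def frozen_bn(x, weight, bias, running_mean, running_var, eps=1e-5):
    scale = weight * (running_var + eps).rsqrt()   # per-channel scale
    shift = bias - running_mean * scale            # per-channel bias
    return x * scale[None, :, None, None] + shift[None, :, None, None]

c = 8
x = torch.randn(2, c, 4, 4)
bn = torch.nn.BatchNorm2d(c).eval()
expected = bn(x)
got = frozen_bn(x, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps)
assert torch.allclose(expected, got, atol=1e-6)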
329
271
 
330
272
 
331
- # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->RTDetr
332
- def replace_batch_norm(model):
333
- r"""
334
- Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrFrozenBatchNorm2d`.
273
+ def eager_attention_forward(
274
+ module: nn.Module,
275
+ query: torch.Tensor,
276
+ key: torch.Tensor,
277
+ value: torch.Tensor,
278
+ attention_mask: torch.Tensor | None,
279
+ scaling: float | None = None,
280
+ dropout: float = 0.0,
281
+ **kwargs: Unpack[TransformersKwargs],
282
+ ):
283
+ if scaling is None:
284
+ scaling = query.size(-1) ** -0.5
335
285
 
336
- Args:
337
- model (torch.nn.Module):
338
- input model
339
- """
340
- for name, module in model.named_children():
341
- if isinstance(module, nn.BatchNorm2d):
342
- new_module = RTDetrFrozenBatchNorm2d(module.num_features)
286
+ # Take the dot product between "query" and "key" to get the raw attention scores.
287
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
343
288
 
344
- if module.weight.device != torch.device("meta"):
345
- new_module.weight.copy_(module.weight)
346
- new_module.bias.copy_(module.bias)
347
- new_module.running_mean.copy_(module.running_mean)
348
- new_module.running_var.copy_(module.running_var)
289
+ if attention_mask is not None:
290
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
291
+ attn_weights = attn_weights + attention_mask
349
292
 
350
- model._modules[name] = new_module
293
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
294
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
351
295
 
352
- if len(list(module.children())) > 0:
353
- replace_batch_norm(module)
296
+ attn_output = torch.matmul(attn_weights, value)
297
+ attn_output = attn_output.transpose(1, 2).contiguous()
354
298
 
299
+ return attn_output, attn_weights
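# Illustrative sketch (not part of the diff): the eager attention path in isolation. Inputs are laid
# out as (batch, num_heads, seq_len, head_dim); the output is transposed back to
# (batch, seq_len, num_heads, head_dim) before the output projection.
import torch
from torch import nn

def eager_attention(query, key, value, attention_mask=None, scaling=None, dropout=0.0):
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

q = k = v = torch.randn(2, 8, 100, 32)  # (batch, heads, queries, head_dim)
out, w = eager_attention(q, k, v)
assert out.shape == (2, 100, 8, 32) and w.shape == (2, 8, 100, 100)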
355
300
 
356
- def get_contrastive_denoising_training_group(
357
- targets,
358
- num_classes,
359
- num_queries,
360
- class_embed,
361
- num_denoising_queries=100,
362
- label_noise_ratio=0.5,
363
- box_noise_scale=1.0,
364
- ):
301
+
302
+ class RTDetrSelfAttention(nn.Module):
365
303
  """
366
- Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.
304
+ Multi-headed self-attention from 'Attention Is All You Need' paper.
367
305
 
368
- Args:
369
- targets (`list[dict]`):
370
- The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
371
- num_classes (`int`):
372
- Total number of classes in the dataset.
373
- num_queries (`int`):
374
- Number of query slots in the transformer.
375
- class_embed (`callable`):
376
- A function or a model layer to embed class labels.
377
- num_denoising_queries (`int`, *optional*, defaults to 100):
378
- Number of denoising queries.
379
- label_noise_ratio (`float`, *optional*, defaults to 0.5):
380
- Ratio of noise applied to labels.
381
- box_noise_scale (`float`, *optional*, defaults to 1.0):
382
- Scale of noise applied to bounding boxes.
383
- Returns:
384
- `tuple` comprising various elements:
385
- - **input_query_class** (`torch.FloatTensor`) --
386
- Class queries with applied label noise.
387
- - **input_query_bbox** (`torch.FloatTensor`) --
388
- Bounding box queries with applied box noise.
389
- - **attn_mask** (`torch.FloatTensor`) --
390
- Attention mask for separating denoising and reconstruction queries.
391
- - **denoising_meta_values** (`dict`) --
392
- Metadata including denoising positive indices, number of groups, and split sizes.
306
+ In RT_DETR, position embeddings are added to both queries and keys (but not values) in self-attention.
393
307
  """
394
308
 
395
- if num_denoising_queries <= 0:
396
- return None, None, None, None
309
+ def __init__(
310
+ self,
311
+ config: RTDetrConfig,
312
+ hidden_size: int,
313
+ num_attention_heads: int,
314
+ dropout: float = 0.0,
315
+ bias: bool = True,
316
+ ):
317
+ super().__init__()
318
+ self.config = config
319
+ self.head_dim = hidden_size // num_attention_heads
320
+ self.scaling = self.head_dim**-0.5
321
+ self.attention_dropout = dropout
322
+ self.is_causal = False
397
323
 
398
- num_ground_truths = [len(t["class_labels"]) for t in targets]
399
- device = targets[0]["class_labels"].device
324
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
325
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
326
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
327
+ self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
400
328
 
401
- max_gt_num = max(num_ground_truths)
402
- if max_gt_num == 0:
403
- return None, None, None, None
329
+ def forward(
330
+ self,
331
+ hidden_states: torch.Tensor,
332
+ attention_mask: torch.Tensor | None = None,
333
+ position_embeddings: torch.Tensor | None = None,
334
+ **kwargs: Unpack[TransformersKwargs],
335
+ ) -> tuple[torch.Tensor, torch.Tensor]:
336
+ """
337
+ Position embeddings are added to both queries and keys (but not values).
338
+ """
339
+ input_shape = hidden_states.shape[:-1]
340
+ hidden_shape = (*input_shape, -1, self.head_dim)
404
341
 
405
- num_groups_denoising_queries = num_denoising_queries // max_gt_num
406
- num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
407
- # pad gt to max_num of a batch
408
- batch_size = len(num_ground_truths)
342
+ query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
409
343
 
410
- input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
411
- input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
412
- pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
344
+ query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
345
+ key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
346
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
413
347
 
414
- for i in range(batch_size):
415
- num_gt = num_ground_truths[i]
416
- if num_gt > 0:
417
- input_query_class[i, :num_gt] = targets[i]["class_labels"]
418
- input_query_bbox[i, :num_gt] = targets[i]["boxes"]
419
- pad_gt_mask[i, :num_gt] = 1
420
- # each group has positive and negative queries.
421
- input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
422
- input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
423
- pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
424
- # positive and negative mask
425
- negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
426
- negative_gt_mask[:, max_gt_num:] = 1
427
- negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
428
- positive_gt_mask = 1 - negative_gt_mask
429
- # contrastive denoising training positive index
430
- positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
431
- denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
432
- denoise_positive_idx = torch.split(
433
- denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
434
- )
435
- # total denoising queries
436
- num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
348
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
349
+ self.config._attn_implementation, eager_attention_forward
350
+ )
437
351
 
438
- if label_noise_ratio > 0:
439
- mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
440
- # randomly put a new one here
441
- new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
442
- input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
352
+ attn_output, attn_weights = attention_interface(
353
+ self,
354
+ query_states,
355
+ key_states,
356
+ value_states,
357
+ attention_mask,
358
+ dropout=0.0 if not self.training else self.attention_dropout,
359
+ scaling=self.scaling,
360
+ **kwargs,
361
+ )
443
362
 
444
- if box_noise_scale > 0:
445
- known_bbox = center_to_corners_format(input_query_bbox)
446
- diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
447
- rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
448
- rand_part = torch.rand_like(input_query_bbox)
449
- rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
450
- rand_part *= rand_sign
451
- known_bbox += rand_part * diff
452
- known_bbox.clip_(min=0.0, max=1.0)
453
- input_query_bbox = corners_to_center_format(known_bbox)
454
- input_query_bbox = inverse_sigmoid(input_query_bbox)
363
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
364
+ attn_output = self.o_proj(attn_output)
365
+ return attn_output, attn_weights
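# Illustrative sketch (not part of the diff): how RTDetrSelfAttention conditions the attention
# pattern on position. The position embedding is added before the q/k projections only; values stay
# position-free, so the output mixes raw content features.
import torch
from torch import nn

hidden_size, num_heads, head_dim = 256, 8, 32
q_proj = nn.Linear(hidden_size, hidden_size)
k_proj = nn.Linear(hidden_size, hidden_size)
v_proj = nn.Linear(hidden_size, hidden_size)

hidden_states = torch.randn(2, 300, hidden_size)       # (batch, queries, hidden)
position_embeddings = torch.randn(2, 300, hidden_size)

query_key_input = hidden_states + position_embeddings  # queries/keys see positions
query = q_proj(query_key_input).view(2, 300, num_heads, head_dim).transpose(1, 2)
key = k_proj(query_key_input).view(2, 300, num_heads, head_dim).transpose(1, 2)
value = v_proj(hidden_states).view(2, 300, num_heads, head_dim).transpose(1, 2)  # no positions

attn = torch.softmax(query @ key.transpose(-1, -2) * head_dim**-0.5, dim=-1)
context = (attn @ value).transpose(1, 2).reshape(2, 300, hidden_size)
assert context.shape == hidden_states.shape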
455
366
 
456
- input_query_class = class_embed(input_query_class)
457
367
 
458
- target_size = num_denoising_queries + num_queries
459
- attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
460
- # match query cannot see the reconstruction
461
- attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
368
+ def replace_batch_norm(model):
369
+ r"""
370
+ Recursively replace all `torch.nn.BatchNorm2d` with `RTDetrFrozenBatchNorm2d`.
462
371
 
463
- # reconstructions cannot see each other
464
- for i in range(num_groups_denoising_queries):
465
- idx_block_start = max_gt_num * 2 * i
466
- idx_block_end = max_gt_num * 2 * (i + 1)
467
- attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
468
- attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
372
+ Args:
373
+ model (torch.nn.Module):
374
+ input model
375
+ """
376
+ for name, module in model.named_children():
377
+ if isinstance(module, nn.BatchNorm2d):
378
+ new_module = RTDetrFrozenBatchNorm2d(module.num_features)
469
379
 
470
- denoising_meta_values = {
471
- "dn_positive_idx": denoise_positive_idx,
472
- "dn_num_group": num_groups_denoising_queries,
473
- "dn_num_split": [num_denoising_queries, num_queries],
474
- }
380
+ if module.weight.device != torch.device("meta"):
381
+ new_module.weight.copy_(module.weight)
382
+ new_module.bias.copy_(module.bias)
383
+ new_module.running_mean.copy_(module.running_mean)
384
+ new_module.running_var.copy_(module.running_var)
475
385
 
476
- return input_query_class, input_query_bbox, attn_mask, denoising_meta_values
386
+ model._modules[name] = new_module
387
+
388
+ if len(list(module.children())) > 0:
389
+ replace_batch_norm(module)
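# Illustrative sketch (not part of the diff): the recursive module-swapping pattern replace_batch_norm
# relies on. A stand-in frozen layer keeps the sketch self-contained; the real helper swaps in
# RTDetrFrozenBatchNorm2d and copies weights/running statistics when they are materialized.
import torch
from torch import nn

def swap_batch_norm(model, make_frozen=lambda bn: nn.Identity()):
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            model._modules[name] = make_frozen(module)
        elif len(list(module.children())) > 0:
            swap_batch_norm(module, make_frozen)  # recurse into submodules

backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Sequential(nn.BatchNorm2d(8)))
swap_batch_norm(backbone)
assert not any(isinstance(m, nn.BatchNorm2d) for m in backbone.modules())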
477
390
 
478
391
 
479
392
  class RTDetrConvEncoder(nn.Module):
@@ -533,50 +446,46 @@ class RTDetrEncoderLayer(nn.Module):
533
446
  def __init__(self, config: RTDetrConfig):
534
447
  super().__init__()
535
448
  self.normalize_before = config.normalize_before
449
+ self.hidden_size = config.encoder_hidden_dim
536
450
 
537
451
  # self-attention
538
- self.self_attn = RTDetrMultiheadAttention(
539
- embed_dim=config.encoder_hidden_dim,
540
- num_heads=config.num_attention_heads,
452
+ self.self_attn = RTDetrSelfAttention(
453
+ config=config,
454
+ hidden_size=self.hidden_size,
455
+ num_attention_heads=config.num_attention_heads,
541
456
  dropout=config.dropout,
542
457
  )
543
- self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
458
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
544
459
  self.dropout = config.dropout
545
- self.activation_fn = ACT2FN[config.encoder_activation_function]
546
- self.activation_dropout = config.activation_dropout
547
- self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
548
- self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
549
- self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
460
+ self.mlp = RTDetrMLP(config, self.hidden_size, config.encoder_ffn_dim, config.encoder_activation_function)
461
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
550
462
 
551
463
  def forward(
552
464
  self,
553
465
  hidden_states: torch.Tensor,
554
466
  attention_mask: torch.Tensor,
555
- position_embeddings: Optional[torch.Tensor] = None,
556
- output_attentions: bool = False,
557
- **kwargs,
558
- ):
467
+ spatial_position_embeddings: torch.Tensor | None = None,
468
+ **kwargs: Unpack[TransformersKwargs],
469
+ ) -> torch.Tensor:
559
470
  """
560
471
  Args:
561
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
472
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
562
473
  attention_mask (`torch.FloatTensor`): attention mask of size
563
474
  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
564
475
  values.
565
- position_embeddings (`torch.FloatTensor`, *optional*):
566
- Object queries (also called content embeddings), to be added to the hidden states.
567
- output_attentions (`bool`, *optional*):
568
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
569
- returned tensors for more detail.
476
+ spatial_position_embeddings (`torch.FloatTensor`, *optional*):
477
+ Spatial position embeddings (2D positional encodings of image locations), to be added to both
478
+ the queries and keys in self-attention (but not to values).
570
479
  """
571
480
  residual = hidden_states
572
481
  if self.normalize_before:
573
482
  hidden_states = self.self_attn_layer_norm(hidden_states)
574
483
 
575
- hidden_states, attn_weights = self.self_attn(
484
+ hidden_states, _ = self.self_attn(
576
485
  hidden_states=hidden_states,
577
486
  attention_mask=attention_mask,
578
- position_embeddings=position_embeddings,
579
- output_attentions=output_attentions,
487
+ position_embeddings=spatial_position_embeddings,
488
+ **kwargs,
580
489
  )
581
490
 
582
491
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -588,12 +497,7 @@ class RTDetrEncoderLayer(nn.Module):
588
497
  hidden_states = self.final_layer_norm(hidden_states)
589
498
  residual = hidden_states
590
499
 
591
- hidden_states = self.activation_fn(self.fc1(hidden_states))
592
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
593
-
594
- hidden_states = self.fc2(hidden_states)
595
-
596
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
500
+ hidden_states = self.mlp(hidden_states)
597
501
 
598
502
  hidden_states = residual + hidden_states
599
503
  if not self.normalize_before:
@@ -604,12 +508,7 @@ class RTDetrEncoderLayer(nn.Module):
604
508
  clamp_value = torch.finfo(hidden_states.dtype).max - 1000
605
509
  hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
606
510
 
607
- outputs = (hidden_states,)
608
-
609
- if output_attentions:
610
- outputs += (attn_weights,)
611
-
612
- return outputs
511
+ return hidden_states
613
512
 
614
513
 
615
514
  class RTDetrRepVggBlock(nn.Module):
@@ -660,7 +559,61 @@ class RTDetrCSPRepLayer(nn.Module):
660
559
  return self.conv3(hidden_state_1 + hidden_state_2)
661
560
 
662
561
 
663
- # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->RTDetr
562
+ @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
563
+ class MultiScaleDeformableAttention(nn.Module):
564
+ def forward(
565
+ self,
566
+ value: Tensor,
567
+ value_spatial_shapes: Tensor,
568
+ value_spatial_shapes_list: list[tuple],
569
+ level_start_index: Tensor,
570
+ sampling_locations: Tensor,
571
+ attention_weights: Tensor,
572
+ im2col_step: int,
573
+ ):
574
+ batch_size, _, num_heads, hidden_dim = value.shape
575
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
576
+ value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
577
+ sampling_grids = 2 * sampling_locations - 1
578
+ sampling_value_list = []
579
+ for level_id, (height, width) in enumerate(value_spatial_shapes_list):
580
+ # batch_size, height*width, num_heads, hidden_dim
581
+ # -> batch_size, height*width, num_heads*hidden_dim
582
+ # -> batch_size, num_heads*hidden_dim, height*width
583
+ # -> batch_size*num_heads, hidden_dim, height, width
584
+ value_l_ = (
585
+ value_list[level_id]
586
+ .flatten(2)
587
+ .transpose(1, 2)
588
+ .reshape(batch_size * num_heads, hidden_dim, height, width)
589
+ )
590
+ # batch_size, num_queries, num_heads, num_points, 2
591
+ # -> batch_size, num_heads, num_queries, num_points, 2
592
+ # -> batch_size*num_heads, num_queries, num_points, 2
593
+ sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
594
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
595
+ sampling_value_l_ = nn.functional.grid_sample(
596
+ value_l_,
597
+ sampling_grid_l_,
598
+ mode="bilinear",
599
+ padding_mode="zeros",
600
+ align_corners=False,
601
+ )
602
+ sampling_value_list.append(sampling_value_l_)
603
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
604
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
605
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
606
+ attention_weights = attention_weights.transpose(1, 2).reshape(
607
+ batch_size * num_heads, 1, num_queries, num_levels * num_points
608
+ )
609
+ output = (
610
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
611
+ .sum(-1)
612
+ .view(batch_size, num_heads * hidden_dim, num_queries)
613
+ )
614
+ return output.transpose(1, 2).contiguous()
615
+
616
+
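# Illustrative sketch (not part of the diff): the sampling step at the heart of multi-scale
# deformable attention, reduced to a single feature level. Each query bilinearly samples
# `num_points` locations from the value map, and the samples are mixed with learned weights.
import torch
from torch import nn

batch, heads, dim, h, w = 2, 8, 32, 20, 20
num_queries, num_points = 300, 4

value_map = torch.randn(batch * heads, dim, h, w)
sampling_locations = torch.rand(batch * heads, num_queries, num_points, 2)  # in [0, 1]
sampling_grid = 2 * sampling_locations - 1                                  # grid_sample expects [-1, 1]

sampled = nn.functional.grid_sample(
    value_map, sampling_grid, mode="bilinear", padding_mode="zeros", align_corners=False
)  # (batch*heads, dim, num_queries, num_points)

attention_weights = torch.softmax(torch.randn(batch * heads, 1, num_queries, num_points), dim=-1)
output = (sampled * attention_weights).sum(-1)          # (batch*heads, dim, num_queries)
output = output.view(batch, heads * dim, num_queries).transpose(1, 2)
assert output.shape == (batch, num_queries, heads * dim)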
664
617
  class RTDetrMultiscaleDeformableAttention(nn.Module):
665
618
  """
666
619
  Multiscale deformable attention as proposed in Deformable DETR.
@@ -698,33 +651,30 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
698
651
 
699
652
  self.disable_custom_kernels = config.disable_custom_kernels
700
653
 
701
- def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
702
- return tensor if position_embeddings is None else tensor + position_embeddings
703
-
704
654
  def forward(
705
655
  self,
706
656
  hidden_states: torch.Tensor,
707
- attention_mask: Optional[torch.Tensor] = None,
657
+ attention_mask: torch.Tensor | None = None,
708
658
  encoder_hidden_states=None,
709
659
  encoder_attention_mask=None,
710
- position_embeddings: Optional[torch.Tensor] = None,
660
+ position_embeddings: torch.Tensor | None = None,
711
661
  reference_points=None,
712
662
  spatial_shapes=None,
713
663
  spatial_shapes_list=None,
714
664
  level_start_index=None,
715
- output_attentions: bool = False,
716
- ):
665
+ **kwargs: Unpack[TransformersKwargs],
666
+ ) -> tuple[torch.Tensor, torch.Tensor]:
717
667
  # add position embeddings to the hidden states before projecting to queries and keys
718
668
  if position_embeddings is not None:
719
- hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
669
+ hidden_states = hidden_states + position_embeddings
720
670
 
721
671
  batch_size, num_queries, _ = hidden_states.shape
722
672
  batch_size, sequence_length, _ = encoder_hidden_states.shape
723
673
  total_elements = sum(height * width for height, width in spatial_shapes_list)
724
- if total_elements != sequence_length:
725
- raise ValueError(
726
- "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
727
- )
674
+ torch_compilable_check(
675
+ total_elements == sequence_length,
676
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
677
+ )
728
678
 
729
679
  value = self.value_proj(encoder_hidden_states)
730
680
  if attention_mask is not None:
@@ -771,235 +721,218 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
771
721
  return output, attention_weights
772
722
 
773
723
 
774
- class RTDetrMultiheadAttention(nn.Module):
775
- """
776
- Multi-headed attention from 'Attention Is All You Need' paper.
777
-
778
- Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
779
- """
780
-
781
- def __init__(
782
- self,
783
- embed_dim: int,
784
- num_heads: int,
785
- dropout: float = 0.0,
786
- bias: bool = True,
787
- ):
724
+ class RTDetrDecoderLayer(nn.Module):
725
+ def __init__(self, config: RTDetrConfig):
788
726
  super().__init__()
789
- self.embed_dim = embed_dim
790
- self.num_heads = num_heads
791
- self.dropout = dropout
792
- self.head_dim = embed_dim // num_heads
793
- if self.head_dim * num_heads != self.embed_dim:
794
- raise ValueError(
795
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
796
- f" {num_heads})."
797
- )
798
- self.scaling = self.head_dim**-0.5
799
-
800
- self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
801
- self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
802
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
803
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
727
+ self.hidden_size = config.d_model
804
728
 
805
- def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
806
- return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
729
+ # self-attention
730
+ self.self_attn = RTDetrSelfAttention(
731
+ config=config,
732
+ hidden_size=self.hidden_size,
733
+ num_attention_heads=config.decoder_attention_heads,
734
+ dropout=config.attention_dropout,
735
+ )
736
+ self.dropout = config.dropout
807
737
 
808
- def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
809
- return tensor if position_embeddings is None else tensor + position_embeddings
738
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
739
+ # cross-attention
740
+ self.encoder_attn = RTDetrMultiscaleDeformableAttention(
741
+ config,
742
+ num_heads=config.decoder_attention_heads,
743
+ n_points=config.decoder_n_points,
744
+ )
745
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
746
+ # feedforward neural networks
747
+ self.mlp = RTDetrMLP(config, self.hidden_size, config.decoder_ffn_dim, config.decoder_activation_function)
748
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
810
749
 
811
750
  def forward(
812
751
  self,
813
752
  hidden_states: torch.Tensor,
814
- attention_mask: Optional[torch.Tensor] = None,
815
- position_embeddings: Optional[torch.Tensor] = None,
816
- output_attentions: bool = False,
817
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
818
- """Input shape: Batch x Time x Channel"""
753
+ object_queries_position_embeddings: torch.Tensor | None = None,
754
+ reference_points=None,
755
+ spatial_shapes=None,
756
+ spatial_shapes_list=None,
757
+ level_start_index=None,
758
+ encoder_hidden_states: torch.Tensor | None = None,
759
+ encoder_attention_mask: torch.Tensor | None = None,
760
+ **kwargs: Unpack[TransformersKwargs],
761
+ ) -> torch.Tensor:
762
+ """
763
+ Args:
764
+ hidden_states (`torch.FloatTensor`):
765
+ Input to the layer of shape `(batch, seq_len, hidden_size)`.
766
+ object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
767
+ Position embeddings for the object query slots. These are added to both queries and keys
768
+ in the self-attention layer (not values).
769
+ reference_points (`torch.FloatTensor`, *optional*):
770
+ Reference points.
771
+ spatial_shapes (`torch.LongTensor`, *optional*):
772
+ Spatial shapes.
773
+ level_start_index (`torch.LongTensor`, *optional*):
774
+ Level start index.
775
+ encoder_hidden_states (`torch.FloatTensor`):
776
+ Cross-attention input to the layer of shape `(batch, seq_len, hidden_size)`.
777
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
778
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
779
+ values.
780
+ """
781
+ residual = hidden_states
819
782
 
820
- batch_size, target_len, embed_dim = hidden_states.size()
821
- # add position embeddings to the hidden states before projecting to queries and keys
822
- if position_embeddings is not None:
823
- hidden_states_original = hidden_states
824
- hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
783
+ # Self Attention
784
+ hidden_states, _ = self.self_attn(
785
+ hidden_states=hidden_states,
786
+ attention_mask=encoder_attention_mask,
787
+ position_embeddings=object_queries_position_embeddings,
788
+ **kwargs,
789
+ )
825
790
 
826
- # get queries, keys and values
827
- query_states = self.q_proj(hidden_states) * self.scaling
828
- key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
829
- value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)
791
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
792
+ hidden_states = residual + hidden_states
793
+ hidden_states = self.self_attn_layer_norm(hidden_states)
830
794
 
831
- proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
832
- query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
833
- key_states = key_states.view(*proj_shape)
834
- value_states = value_states.view(*proj_shape)
795
+ residual = hidden_states
835
796
 
836
- source_len = key_states.size(1)
797
+ # Cross-Attention
798
+ hidden_states, _ = self.encoder_attn(
799
+ hidden_states=hidden_states,
800
+ encoder_hidden_states=encoder_hidden_states,
801
+ position_embeddings=object_queries_position_embeddings,
802
+ reference_points=reference_points,
803
+ spatial_shapes=spatial_shapes,
804
+ spatial_shapes_list=spatial_shapes_list,
805
+ level_start_index=level_start_index,
806
+ **kwargs,
807
+ )
837
808
 
838
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
809
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
810
+ hidden_states = residual + hidden_states
839
811
 
840
- if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
841
- raise ValueError(
842
- f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
843
- f" {attn_weights.size()}"
844
- )
812
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
845
813
 
846
- # expand attention_mask
847
- if attention_mask is not None:
848
- # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
849
- attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())
814
+ # Fully Connected
815
+ residual = hidden_states
816
+ hidden_states = self.mlp(hidden_states)
817
+ hidden_states = residual + hidden_states
818
+ hidden_states = self.final_layer_norm(hidden_states)
850
819
 
851
- if attention_mask is not None:
852
- if attention_mask.size() != (batch_size, 1, target_len, source_len):
853
- raise ValueError(
854
- f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
855
- f" {attention_mask.size()}"
856
- )
857
- if attention_mask.dtype == torch.bool:
858
- attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
859
- attention_mask, -torch.inf
860
- )
861
- attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
862
- attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
863
-
864
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
865
-
866
- if output_attentions:
867
- # this operation is a bit awkward, but it's required to
868
- # make sure that attn_weights keeps its gradient.
869
- # In order to do so, attn_weights have to reshaped
870
- # twice and have to be reused in the following
871
- attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
872
- attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
873
- else:
874
- attn_weights_reshaped = None
820
+ return hidden_states
875
821
 
876
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
877
822
 
878
- attn_output = torch.bmm(attn_probs, value_states)
823
+ class RTDetrSinePositionEmbedding(nn.Module):
824
+ """
825
+ 2D sinusoidal position embedding used in RT-DETR hybrid encoder.
826
+ """
879
827
 
880
- if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
881
- raise ValueError(
882
- f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
883
- f" {attn_output.size()}"
884
- )
828
+ def __init__(self, embed_dim: int = 256, temperature: int = 10000):
829
+ super().__init__()
830
+ self.embed_dim = embed_dim
831
+ self.temperature = temperature
885
832
 
886
- attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
887
- attn_output = attn_output.transpose(1, 2)
888
- attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
833
+ @compile_compatible_method_lru_cache(maxsize=32)
834
+ def forward(
835
+ self,
836
+ width: int,
837
+ height: int,
838
+ device: torch.device | str,
839
+ dtype: torch.dtype,
840
+ ) -> torch.Tensor:
841
+ """
842
+ Generate 2D sinusoidal position embeddings.
889
843
 
890
- attn_output = self.out_proj(attn_output)
844
+ Returns:
845
+ Position embeddings of shape (1, height*width, embed_dim)
846
+ """
847
+ grid_w = torch.arange(torch_int(width), device=device).to(dtype)
848
+ grid_h = torch.arange(torch_int(height), device=device).to(dtype)
849
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
850
+ if self.embed_dim % 4 != 0:
851
+ raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
852
+ pos_dim = self.embed_dim // 4
853
+ omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
854
+ omega = 1.0 / (self.temperature**omega)
855
+
856
+ out_w = grid_w.flatten()[..., None] @ omega[None]
857
+ out_h = grid_h.flatten()[..., None] @ omega[None]
891
858
 
892
- return attn_output, attn_weights_reshaped
859
+ return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
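# Illustrative sketch (not part of the diff): the 2D sine-cosine position embedding formula used by
# RTDetrSinePositionEmbedding, written as a standalone function so the output shape is easy to check.
import torch

def sincos_2d(width, height, embed_dim=256, temperature=10000.0):
    assert embed_dim % 4 == 0, "embed_dim must be divisible by 4"
    grid_w = torch.arange(width, dtype=torch.float32)
    grid_h = torch.arange(height, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
    pos_dim = embed_dim // 4
    omega = 1.0 / (temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_w = grid_w.flatten()[:, None] * omega[None]      # (H*W, pos_dim)
    out_h = grid_h.flatten()[:, None] * omega[None]
    return torch.cat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None]

pos = sincos_2d(width=20, height=16, embed_dim=256)
assert pos.shape == (1, 16 * 20, 256)  # one embedding per spatial location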
893
860
 
894
861
 
895
- class RTDetrDecoderLayer(nn.Module):
862
+ class RTDetrAIFILayer(nn.Module):
863
+ """
864
+ AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder.
865
+ """
866
+
896
867
  def __init__(self, config: RTDetrConfig):
897
868
  super().__init__()
898
- # self-attention
899
- self.self_attn = RTDetrMultiheadAttention(
900
- embed_dim=config.d_model,
901
- num_heads=config.decoder_attention_heads,
902
- dropout=config.attention_dropout,
903
- )
904
- self.dropout = config.dropout
905
- self.activation_fn = ACT2FN[config.decoder_activation_function]
906
- self.activation_dropout = config.activation_dropout
869
+ self.config = config
870
+ self.encoder_hidden_dim = config.encoder_hidden_dim
871
+ self.eval_size = config.eval_size
907
872
 
908
- self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
909
- # cross-attention
910
- self.encoder_attn = RTDetrMultiscaleDeformableAttention(
911
- config,
912
- num_heads=config.decoder_attention_heads,
913
- n_points=config.decoder_n_points,
873
+ self.position_embedding = RTDetrSinePositionEmbedding(
874
+ embed_dim=self.encoder_hidden_dim,
875
+ temperature=config.positional_encoding_temperature,
914
876
  )
915
- self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
916
- # feedforward neural networks
917
- self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
918
- self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
919
- self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
877
+ self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
920
878
 
921
879
  def forward(
922
- self,
923
- hidden_states: torch.Tensor,
924
- position_embeddings: Optional[torch.Tensor] = None,
925
- reference_points=None,
926
- spatial_shapes=None,
927
- spatial_shapes_list=None,
928
- level_start_index=None,
929
- encoder_hidden_states: Optional[torch.Tensor] = None,
930
- encoder_attention_mask: Optional[torch.Tensor] = None,
931
- output_attentions: Optional[bool] = False,
932
- ):
880
+ self,
881
+ hidden_states: torch.Tensor,
882
+ **kwargs: Unpack[TransformersKwargs],
883
+ ) -> torch.Tensor:
933
884
  """
934
885
  Args:
935
- hidden_states (`torch.FloatTensor`):
936
- Input to the layer of shape `(seq_len, batch, embed_dim)`.
937
- position_embeddings (`torch.FloatTensor`, *optional*):
938
- Position embeddings that are added to the queries and keys in the self-attention layer.
939
- reference_points (`torch.FloatTensor`, *optional*):
940
- Reference points.
941
- spatial_shapes (`torch.LongTensor`, *optional*):
942
- Spatial shapes.
943
- level_start_index (`torch.LongTensor`, *optional*):
944
- Level start index.
945
- encoder_hidden_states (`torch.FloatTensor`):
946
- cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
947
- encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
948
- `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
949
- values.
950
- output_attentions (`bool`, *optional*):
951
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
952
- returned tensors for more detail.
886
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
887
+ Feature map to process.
953
888
  """
954
- residual = hidden_states
889
+ batch_size = hidden_states.shape[0]
890
+ height, width = hidden_states.shape[2:]
955
891
 
956
- # Self Attention
957
- hidden_states, self_attn_weights = self.self_attn(
958
- hidden_states=hidden_states,
959
- attention_mask=encoder_attention_mask,
960
- position_embeddings=position_embeddings,
961
- output_attentions=output_attentions,
962
- )
892
+ hidden_states = hidden_states.flatten(2).permute(0, 2, 1)
963
893
 
964
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
965
- hidden_states = residual + hidden_states
966
- hidden_states = self.self_attn_layer_norm(hidden_states)
894
+ if self.training or self.eval_size is None:
895
+ pos_embed = self.position_embedding(
896
+ width=width,
897
+ height=height,
898
+ device=hidden_states.device,
899
+ dtype=hidden_states.dtype,
900
+ )
901
+ else:
902
+ pos_embed = None
967
903
 
968
- second_residual = hidden_states
904
+ for layer in self.layers:
905
+ hidden_states = layer(
906
+ hidden_states,
907
+ attention_mask=None,
908
+ spatial_position_embeddings=pos_embed,
909
+ **kwargs,
910
+ )
969
911
 
970
- # Cross-Attention
971
- cross_attn_weights = None
972
- hidden_states, cross_attn_weights = self.encoder_attn(
973
- hidden_states=hidden_states,
974
- encoder_hidden_states=encoder_hidden_states,
975
- position_embeddings=position_embeddings,
976
- reference_points=reference_points,
977
- spatial_shapes=spatial_shapes,
978
- spatial_shapes_list=spatial_shapes_list,
979
- level_start_index=level_start_index,
980
- output_attentions=output_attentions,
912
+ hidden_states = (
913
+ hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous()
981
914
  )
982
915
 
983
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
984
- hidden_states = second_residual + hidden_states
916
+ return hidden_states
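# Illustrative sketch (not part of the diff): the flatten / restore bookkeeping the AIFI layer performs
# around its transformer encoder layers. The feature map is treated as a sequence of H*W tokens and
# reshaped back to NCHW afterwards.
import torch

batch, channels, height, width = 2, 256, 20, 20
feature_map = torch.randn(batch, channels, height, width)

tokens = feature_map.flatten(2).permute(0, 2, 1)        # (batch, H*W, channels)
# ... transformer encoder layers with 2D sine-cosine position embeddings run here ...
restored = tokens.permute(0, 2, 1).reshape(batch, channels, height, width).contiguous()

assert torch.equal(restored, feature_map)               # the round trip is lossless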
985
917
 
986
- hidden_states = self.encoder_attn_layer_norm(hidden_states)
987
918
 
988
- # Fully Connected
989
- residual = hidden_states
990
- hidden_states = self.activation_fn(self.fc1(hidden_states))
991
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
992
- hidden_states = self.fc2(hidden_states)
993
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
994
- hidden_states = residual + hidden_states
995
- hidden_states = self.final_layer_norm(hidden_states)
919
+ class RTDetrMLPPredictionHead(nn.Module):
920
+ """
921
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
922
+ height and width of a bounding box w.r.t. an image.
996
923
 
997
- outputs = (hidden_states,)
924
+ """
998
925
 
999
- if output_attentions:
1000
- outputs += (self_attn_weights, cross_attn_weights)
926
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
927
+ super().__init__()
928
+ self.num_layers = num_layers
929
+ h = [hidden_dim] * (num_layers - 1)
930
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
1001
931
 
1002
- return outputs
932
+ def forward(self, x):
933
+ for i, layer in enumerate(self.layers):
934
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
935
+ return x
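# Illustrative sketch (not part of the diff): how the prediction-head constructor chains layer widths.
# With input_dim=4, hidden_dim=512, output_dim=256 and num_layers=2 (the query_pos_head configuration
# above, assuming d_model=256), the zip below yields Linear(4, 512) -> ReLU -> Linear(512, 256).
import torch
from torch import nn

input_dim, hidden_dim, output_dim, num_layers = 4, 512, 256, 2
h = [hidden_dim] * (num_layers - 1)
layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

x = torch.rand(2, 300, input_dim)                       # e.g. normalized reference boxes
for i, layer in enumerate(layers):
    x = nn.functional.relu(layer(x)) if i < num_layers - 1 else layer(x)
assert x.shape == (2, 300, output_dim)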
1003
936
 
1004
937
 
1005
938
  @auto_docstring
@@ -1009,6 +942,10 @@ class RTDetrPreTrainedModel(PreTrainedModel):
1009
942
  main_input_name = "pixel_values"
1010
943
  input_modalities = ("image",)
1011
944
  _no_split_modules = [r"RTDetrHybridEncoder", r"RTDetrDecoderLayer"]
945
+ _supports_sdpa = True
946
+ _supports_flash_attn = True
947
+ _supports_attention_backend = True
948
+ _supports_flex_attn = True
1012
949
 
1013
950
  @torch.no_grad()
1014
951
  def _init_weights(self, module):
@@ -1074,35 +1011,23 @@ class RTDetrPreTrainedModel(PreTrainedModel):
1074
1011
  init.xavier_uniform_(module.denoising_class_embed.weight)
1075
1012
 
1076
1013
 
1077
- class RTDetrEncoder(nn.Module):
1078
- def __init__(self, config: RTDetrConfig):
1079
- super().__init__()
1080
-
1081
- self.layers = nn.ModuleList([RTDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
1082
-
1083
- def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
1084
- hidden_states = src
1085
- for layer in self.layers:
1086
- hidden_states = layer(
1087
- hidden_states,
1088
- attention_mask=src_mask,
1089
- position_embeddings=pos_embed,
1090
- output_attentions=output_attentions,
1091
- )
1092
- return hidden_states
1093
-
1094
-
1095
- class RTDetrHybridEncoder(nn.Module):
1014
+ class RTDetrHybridEncoder(RTDetrPreTrainedModel):
1096
1015
  """
1097
- Decoder consisting of a projection layer, a set of `RTDetrEncoder`, a top-down Feature Pyramid Network
1098
- (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://huggingface.co/papers/2304.08069
1016
+ Hybrid encoder consisting of AIFI (Attention-based Intra-scale Feature Interaction) layers,
1017
+ a top-down Feature Pyramid Network (FPN) and a bottom-up Path Aggregation Network (PAN).
1018
+ More details in the paper: https://huggingface.co/papers/2304.08069
1099
1019
 
1100
1020
  Args:
1101
1021
  config: RTDetrConfig
1102
1022
  """
1103
1023
 
1024
+ _can_record_outputs = {
1025
+ "hidden_states": RTDetrAIFILayer,
1026
+ "attentions": RTDetrSelfAttention,
1027
+ }
1028
+
1104
1029
  def __init__(self, config: RTDetrConfig):
1105
- super().__init__()
1030
+ super().__init__(config)
1106
1031
  self.config = config
1107
1032
  self.in_channels = config.encoder_in_channels
1108
1033
  self.feat_strides = config.feat_strides
@@ -1114,10 +1039,9 @@ class RTDetrHybridEncoder(nn.Module):
1114
1039
  self.out_strides = self.feat_strides
1115
1040
  self.num_fpn_stages = len(self.in_channels) - 1
1116
1041
  self.num_pan_stages = len(self.in_channels) - 1
1117
- activation = config.activation_function
1118
1042
 
1119
- # encoder transformer
1120
- self.encoder = nn.ModuleList([RTDetrEncoder(config) for _ in range(len(self.encode_proj_layers))])
1043
+ # AIFI (Attention-based Intra-scale Feature Interaction) layers
1044
+ self.aifi = nn.ModuleList([RTDetrAIFILayer(config) for _ in range(len(self.encode_proj_layers))])
1121
1045
 
1122
1046
  # top-down FPN
1123
1047
  self.lateral_convs = nn.ModuleList()
@@ -1129,7 +1053,7 @@ class RTDetrHybridEncoder(nn.Module):
1129
1053
  out_channels=self.encoder_hidden_dim,
1130
1054
  kernel_size=1,
1131
1055
  stride=1,
1132
- activation=activation,
1056
+ activation=config.activation_function,
1133
1057
  )
1134
1058
  fpn_block = RTDetrCSPRepLayer(config)
1135
1059
  self.lateral_convs.append(lateral_conv)
@@ -1145,118 +1069,36 @@ class RTDetrHybridEncoder(nn.Module):
1145
1069
  out_channels=self.encoder_hidden_dim,
1146
1070
  kernel_size=3,
1147
1071
  stride=2,
1148
- activation=activation,
1072
+ activation=config.activation_function,
1149
1073
  )
1150
1074
  pan_block = RTDetrCSPRepLayer(config)
1151
1075
  self.downsample_convs.append(downsample_conv)
1152
1076
  self.pan_blocks.append(pan_block)
1153
1077
 
1154
- @staticmethod
1155
- def build_2d_sincos_position_embedding(
1156
- width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
1157
- ):
1158
- grid_w = torch.arange(torch_int(width), device=device).to(dtype)
1159
- grid_h = torch.arange(torch_int(height), device=device).to(dtype)
1160
- grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
1161
- if embed_dim % 4 != 0:
1162
- raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
1163
- pos_dim = embed_dim // 4
1164
- omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
1165
- omega = 1.0 / (temperature**omega)
1166
-
1167
- out_w = grid_w.flatten()[..., None] @ omega[None]
1168
- out_h = grid_h.flatten()[..., None] @ omega[None]
1169
-
1170
- return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
1078
+ self.post_init()
1171
1079
 
1080
+ @check_model_inputs(tie_last_hidden_states=False)
1172
1081
  def forward(
1173
1082
  self,
1174
1083
  inputs_embeds=None,
1175
- attention_mask=None,
1176
- position_embeddings=None,
1177
- spatial_shapes=None,
1178
- level_start_index=None,
1179
- valid_ratios=None,
1180
- output_attentions=None,
1181
- output_hidden_states=None,
1182
- return_dict=None,
1183
- ):
1084
+ **kwargs: Unpack[TransformersKwargs],
1085
+ ) -> BaseModelOutput:
1184
1086
  r"""
1185
1087
  Args:
1186
1088
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1187
1089
  Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
1188
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1189
- Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
1190
- - 1 for pixel features that are real (i.e. **not masked**),
1191
- - 0 for pixel features that are padding (i.e. **masked**).
1192
- [What are attention masks?](../glossary#attention-mask)
1193
- position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1194
- Position embeddings that are added to the queries and keys in each self-attention layer.
1195
- spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
1196
- Spatial shapes of each feature map.
1197
- level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
1198
- Starting index of each feature map.
1199
- valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
1200
- Ratio of valid area in each feature level.
1201
- output_attentions (`bool`, *optional*):
1202
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1203
- returned tensors for more detail.
1204
- output_hidden_states (`bool`, *optional*):
1205
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1206
- for more detail.
1207
- return_dict (`bool`, *optional*):
1208
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
1209
1090
  """
1210
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1211
- output_hidden_states = (
1212
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1213
- )
1214
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1215
-
1216
- hidden_states = inputs_embeds
1091
+ feature_maps = inputs_embeds
1217
1092
 
1218
- encoder_states = () if output_hidden_states else None
1219
- all_attentions = () if output_attentions else None
1220
-
1221
- # encoder
1093
+ # AIFI: Apply transformer encoder to specified feature levels
1222
1094
  if self.config.encoder_layers > 0:
1223
1095
  for i, enc_ind in enumerate(self.encode_proj_layers):
1224
- if output_hidden_states:
1225
- encoder_states = encoder_states + (hidden_states[enc_ind],)
1226
- height, width = hidden_states[enc_ind].shape[2:]
1227
- # flatten [batch, channel, height, width] to [batch, height*width, channel]
1228
- src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
1229
- if self.training or self.eval_size is None:
1230
- pos_embed = self.build_2d_sincos_position_embedding(
1231
- width,
1232
- height,
1233
- self.encoder_hidden_dim,
1234
- self.positional_encoding_temperature,
1235
- device=src_flatten.device,
1236
- dtype=src_flatten.dtype,
1237
- )
1238
- else:
1239
- pos_embed = None
1240
-
1241
- layer_outputs = self.encoder[i](
1242
- src_flatten,
1243
- pos_embed=pos_embed,
1244
- output_attentions=output_attentions,
1245
- )
1246
- hidden_states[enc_ind] = (
1247
- layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
1248
- )
1249
-
1250
- if output_attentions:
1251
- all_attentions = all_attentions + (layer_outputs[1],)
1252
-
1253
- if output_hidden_states:
1254
- encoder_states = encoder_states + (hidden_states[enc_ind],)
1096
+ feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs)
1255
1097
 
1256
1098
  # top-down FPN
1257
- fpn_feature_maps = [hidden_states[-1]]
1099
+ fpn_feature_maps = [feature_maps[-1]]
1258
1100
  for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
1259
- backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
1101
+ backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1]
1260
1102
  top_fpn_feature_map = fpn_feature_maps[-1]
1261
1103
  # apply lateral block
1262
1104
  top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
@@ -1279,20 +1121,29 @@ class RTDetrHybridEncoder(nn.Module):
1279
1121
  new_pan_feature_map = pan_block(fused_feature_map)
1280
1122
  pan_feature_maps.append(new_pan_feature_map)
1281
1123
 
1282
- if not return_dict:
1283
- return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
1284
- return BaseModelOutput(
1285
- last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
1286
- )
1124
+ return BaseModelOutput(last_hidden_state=pan_feature_maps)
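# Illustrative sketch (not part of the diff): the top-down / bottom-up fusion order the hybrid encoder
# iterates over. The fusion blocks here are placeholders (a shared 1x1 conv), not the library's
# RTDetrCSPRepLayer / lateral / downsample modules.
import torch
from torch import nn

channels = 256
feature_maps = [torch.randn(1, channels, s, s) for s in (80, 40, 20)]  # strides 8 / 16 / 32

fuse = nn.Conv2d(2 * channels, channels, kernel_size=1)

# top-down FPN: start from the coarsest map, upsample, fuse with the next finer map
fpn = [feature_maps[-1]]
for finer in reversed(feature_maps[:-1]):
    top = nn.functional.interpolate(fpn[-1], scale_factor=2.0, mode="nearest")
    fpn.append(fuse(torch.cat([top, finer], dim=1)))
fpn = fpn[::-1]                                         # finest -> coarsest

# bottom-up PAN: walk back down, downsampling and fusing again
pan = [fpn[0]]
for coarser in fpn[1:]:
    down = nn.functional.max_pool2d(pan[-1], kernel_size=2)
    pan.append(fuse(torch.cat([down, coarser], dim=1)))

assert [p.shape[-1] for p in pan] == [80, 40, 20]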
1125
+
1126
+
1127
+ def inverse_sigmoid(x, eps=1e-5):
1128
+ x = x.clamp(min=0, max=1)
1129
+ x1 = x.clamp(min=eps)
1130
+ x2 = (1 - x).clamp(min=eps)
1131
+ return torch.log(x1 / x2)
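# Illustrative sketch (not part of the diff): inverse_sigmoid is the logit function with clamping for
# numerical safety, so it undoes torch.sigmoid away from the saturated ends.
import torch

def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

x = torch.linspace(-4, 4, steps=9)
assert torch.allclose(inverse_sigmoid(torch.sigmoid(x)), x, atol=1e-4)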
1287
1132
 
1288
1133
 
1289
1134
  class RTDetrDecoder(RTDetrPreTrainedModel):
1135
+ _can_record_outputs = {
1136
+ "hidden_states": RTDetrDecoderLayer,
1137
+ "attentions": RTDetrSelfAttention,
1138
+ "cross_attentions": RTDetrMultiscaleDeformableAttention,
1139
+ }
1140
+
1290
1141
  def __init__(self, config: RTDetrConfig):
1291
1142
  super().__init__(config)
1292
1143
 
1293
1144
  self.dropout = config.dropout
1294
1145
  self.layers = nn.ModuleList([RTDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
1295
- self.query_pos_head = RTDetrMLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2)
1146
+ self.query_pos_head = RTDetrMLPPredictionHead(4, 2 * config.d_model, config.d_model, num_layers=2)
1296
1147
 
1297
1148
  # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
1298
1149
  self.bbox_embed = None
@@ -1301,21 +1152,17 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
  # Initialize weights and apply final processing
  self.post_init()

+ @check_model_inputs()
  def forward(
  self,
  inputs_embeds=None,
  encoder_hidden_states=None,
  encoder_attention_mask=None,
- position_embeddings=None,
  reference_points=None,
  spatial_shapes=None,
  spatial_shapes_list=None,
  level_start_index=None,
- valid_ratios=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- **kwargs,
+ **kwargs: Unpack[TransformersKwargs],
  ):
  r"""
  Args:
@@ -1329,39 +1176,17 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
  in `[0, 1]`:
  - 1 for pixels that are real (i.e. **not masked**),
  - 0 for pixels that are padding (i.e. **masked**).
- position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
- Position embeddings that are added to the queries and keys in each self-attention layer.
  reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*):
  Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
  spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
  Spatial shapes of the feature maps.
  level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
  Indexes for the start of each feature level. In range `[0, sequence_length]`.
- valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
- Ratio of valid area in each feature level.
-
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
  """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
  if inputs_embeds is not None:
  hidden_states = inputs_embeds

  # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  intermediate = ()
  intermediate_reference_points = ()
  intermediate_logits = ()
@@ -1371,25 +1196,20 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
  # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L252
  for idx, decoder_layer in enumerate(self.layers):
  reference_points_input = reference_points.unsqueeze(2)
- position_embeddings = self.query_pos_head(reference_points)
-
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
+ object_queries_position_embeddings = self.query_pos_head(reference_points)

- layer_outputs = decoder_layer(
+ hidden_states = decoder_layer(
  hidden_states,
- position_embeddings=position_embeddings,
+ object_queries_position_embeddings=object_queries_position_embeddings,
  encoder_hidden_states=encoder_hidden_states,
  reference_points=reference_points_input,
  spatial_shapes=spatial_shapes,
  spatial_shapes_list=spatial_shapes_list,
  level_start_index=level_start_index,
  encoder_attention_mask=encoder_attention_mask,
- output_attentions=output_attentions,
+ **kwargs,
  )

- hidden_states = layer_outputs[0]
-
  # hack implementation for iterative bounding box refinement
  if self.bbox_embed is not None:
  predicted_corners = self.bbox_embed[idx](hidden_states)
@@ -1405,68 +1225,141 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
  logits = self.class_embed[idx](hidden_states)
  intermediate_logits += (logits,)

- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- if encoder_hidden_states is not None:
- all_cross_attentions += (layer_outputs[2],)
-
  # Keep batch_size as first dimension
  intermediate = torch.stack(intermediate, dim=1)
  intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
  if self.class_embed is not None:
  intermediate_logits = torch.stack(intermediate_logits, dim=1)

- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if not return_dict:
- return tuple(
- v
- for v in [
- hidden_states,
- intermediate,
- intermediate_logits,
- intermediate_reference_points,
- all_hidden_states,
- all_self_attns,
- all_cross_attentions,
- ]
- if v is not None
- )
  return RTDetrDecoderOutput(
  last_hidden_state=hidden_states,
  intermediate_hidden_states=intermediate,
  intermediate_logits=intermediate_logits,
  intermediate_reference_points=intermediate_reference_points,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- cross_attentions=all_cross_attentions,
  )


- # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
- class RTDetrMLPPredictionHead(nn.Module):
+ def get_contrastive_denoising_training_group(
+ targets,
+ num_classes,
+ num_queries,
+ class_embed,
+ num_denoising_queries=100,
+ label_noise_ratio=0.5,
+ box_noise_scale=1.0,
+ ):
  """
- Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
- height and width of a bounding box w.r.t. an image.
-
- Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
- Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_paddle/ppdet/modeling/transformers/utils.py#L453
+ Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.

+ Args:
+ targets (`list[dict]`):
+ The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
+ num_classes (`int`):
+ Total number of classes in the dataset.
+ num_queries (`int`):
+ Number of query slots in the transformer.
+ class_embed (`callable`):
+ A function or a model layer to embed class labels.
+ num_denoising_queries (`int`, *optional*, defaults to 100):
+ Number of denoising queries.
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
+ Ratio of noise applied to labels.
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
+ Scale of noise applied to bounding boxes.
+ Returns:
+ `tuple` comprising various elements:
+ - **input_query_class** (`torch.FloatTensor`) --
+ Class queries with applied label noise.
+ - **input_query_bbox** (`torch.FloatTensor`) --
+ Bounding box queries with applied box noise.
+ - **attn_mask** (`torch.FloatTensor`) --
+ Attention mask for separating denoising and reconstruction queries.
+ - **denoising_meta_values** (`dict`) --
+ Metadata including denoising positive indices, number of groups, and split sizes.
  """

- def __init__(self, config, input_dim, d_model, output_dim, num_layers):
- super().__init__()
- self.num_layers = num_layers
- h = [d_model] * (num_layers - 1)
- self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ if num_denoising_queries <= 0:
+ return None, None, None, None

- def forward(self, x):
- for i, layer in enumerate(self.layers):
- x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
- return x
+ num_ground_truths = [len(t["class_labels"]) for t in targets]
+ device = targets[0]["class_labels"].device
+
+ max_gt_num = max(num_ground_truths)
+ if max_gt_num == 0:
+ return None, None, None, None
+
+ num_groups_denoising_queries = num_denoising_queries // max_gt_num
+ num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
+ # pad gt to max_num of a batch
+ batch_size = len(num_ground_truths)
+
+ input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
+ input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
+ pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
+
+ for i in range(batch_size):
+ num_gt = num_ground_truths[i]
+ if num_gt > 0:
+ input_query_class[i, :num_gt] = targets[i]["class_labels"]
+ input_query_bbox[i, :num_gt] = targets[i]["boxes"]
+ pad_gt_mask[i, :num_gt] = 1
+ # each group has positive and negative queries.
+ input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
+ input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
+ pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
+ # positive and negative mask
+ negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
+ negative_gt_mask[:, max_gt_num:] = 1
+ negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
+ positive_gt_mask = 1 - negative_gt_mask
+ # contrastive denoising training positive index
+ positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
+ denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
+ denoise_positive_idx = torch.split(
+ denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
+ )
+ # total denoising queries
+ num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
+
+ if label_noise_ratio > 0:
+ mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
+ # randomly put a new one here
+ new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
+ input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
+
+ if box_noise_scale > 0:
+ known_bbox = center_to_corners_format(input_query_bbox)
+ diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
+ rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
+ rand_part = torch.rand_like(input_query_bbox)
+ rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
+ rand_part *= rand_sign
+ known_bbox += rand_part * diff
+ known_bbox.clip_(min=0.0, max=1.0)
+ input_query_bbox = corners_to_center_format(known_bbox)
+ input_query_bbox = inverse_sigmoid(input_query_bbox)
+
+ input_query_class = class_embed(input_query_class)
+
+ target_size = num_denoising_queries + num_queries
+ attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
+ # match query cannot see the reconstruction
+ attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
+
+ # reconstructions cannot see each other
+ for i in range(num_groups_denoising_queries):
+ idx_block_start = max_gt_num * 2 * i
+ idx_block_end = max_gt_num * 2 * (i + 1)
+ attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
+ attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
+
+ denoising_meta_values = {
+ "dn_positive_idx": denoise_positive_idx,
+ "dn_num_group": num_groups_denoising_queries,
+ "dn_num_split": [num_denoising_queries, num_queries],
+ }
+
+ return input_query_class, input_query_bbox, attn_mask, denoising_meta_values


  @auto_docstring(
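The docstring above spells out the inputs and outputs of `get_contrastive_denoising_training_group`; a hedged usage sketch with toy targets follows (not part of the diff). It assumes the helper stays importable from the internal `transformers.models.rt_detr.modeling_rt_detr` module, and it mirrors the model's own `num_labels + 1` denoising class embedding.

```python
import torch
from torch import nn

# Internal helper; the import path is an assumption based on this diff and may change.
from transformers.models.rt_detr.modeling_rt_detr import get_contrastive_denoising_training_group

num_classes, num_queries, d_model = 80, 300, 256
targets = [
    {"class_labels": torch.tensor([3, 7]), "boxes": torch.rand(2, 4)},  # image with 2 ground-truth boxes
    {"class_labels": torch.tensor([1]), "boxes": torch.rand(1, 4)},     # image with 1 ground-truth box
]
# Extra row acts as the "no object" / padding class, as in the model's denoising embedding.
class_embed = nn.Embedding(num_classes + 1, d_model, padding_idx=num_classes)

query_class, query_bbox, attn_mask, meta = get_contrastive_denoising_training_group(
    targets, num_classes, num_queries, class_embed, num_denoising_queries=100
)
# 100 // max_gt_num = 50 groups, each holding a positive and a negative copy of the 2 padded GTs.
print(query_class.shape)     # torch.Size([2, 200, 256])
print(query_bbox.shape)      # torch.Size([2, 200, 4])
print(attn_mask.shape)       # torch.Size([500, 500]) == (200 + num_queries,) * 2
print(meta["dn_num_split"])  # [200, 300]
```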
@@ -1486,8 +1379,8 @@ class RTDetrModel(RTDetrPreTrainedModel):
  # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L212
  num_backbone_outs = len(intermediate_channel_sizes)
  encoder_input_proj_list = []
- for _ in range(num_backbone_outs):
- in_channels = intermediate_channel_sizes[_]
+ for i in range(num_backbone_outs):
+ in_channels = intermediate_channel_sizes[i]
  encoder_input_proj_list.append(
  nn.Sequential(
  nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
@@ -1515,7 +1408,7 @@ class RTDetrModel(RTDetrPreTrainedModel):
  nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
  )
  self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
- self.enc_bbox_head = RTDetrMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3)
+ self.enc_bbox_head = RTDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)

  # init encoder output anchors and valid_mask
  if config.anchor_image_size:
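The construction sites above now call `RTDetrMLPPredictionHead` without the unused `config` argument. For reference, a standalone sketch of the head with the new signature, reconstructed from the class body removed earlier in this diff (illustrative, not necessarily the library's exact definition):

```python
from torch import nn

class RTDetrMLPPredictionHead(nn.Module):
    # Simple MLP/FFN predicting e.g. normalized box center, height and width.
    def __init__(self, input_dim, d_model, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [d_model] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

# e.g. the encoder box head: 256 -> 256 -> 256 -> 4
head = RTDetrMLPPredictionHead(256, 256, 4, num_layers=3)
```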
@@ -1525,8 +1418,8 @@ class RTDetrModel(RTDetrPreTrainedModel):
  # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
  num_backbone_outs = len(config.decoder_in_channels)
  decoder_input_proj_list = []
- for _ in range(num_backbone_outs):
- in_channels = config.decoder_in_channels[_]
+ for i in range(num_backbone_outs):
+ in_channels = config.decoder_in_channels[i]
  decoder_input_proj_list.append(
  nn.Sequential(
  nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
@@ -1586,26 +1479,20 @@ class RTDetrModel(RTDetrPreTrainedModel):
  return anchors, valid_mask

  @auto_docstring
+ @can_return_tuple
  def forward(
  self,
  pixel_values: torch.FloatTensor,
- pixel_mask: Optional[torch.LongTensor] = None,
- encoder_outputs: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[list[dict]] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[tuple[torch.FloatTensor], RTDetrModelOutput]:
+ pixel_mask: torch.LongTensor | None = None,
+ encoder_outputs: torch.FloatTensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: list[dict] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.FloatTensor] | RTDetrModelOutput:
  r"""
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
  can choose to directly pass a flattened representation of an image.
- decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
- Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
- embedded representation.
  labels (`list[Dict]` of len `(batch_size,)`, *optional*):
  Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
  following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
@@ -1633,53 +1520,46 @@ class RTDetrModel(RTDetrPreTrainedModel):
  >>> list(last_hidden_states.shape)
  [1, 300, 256]
  ```"""
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- batch_size, num_channels, height, width = pixel_values.shape
- device = pixel_values.device
-
- if pixel_mask is None:
- pixel_mask = torch.ones(((batch_size, height, width)), device=device)
-
- features = self.backbone(pixel_values, pixel_mask)
-
- proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+ if pixel_values is None and inputs_embeds is None:
+ raise ValueError("You have to specify either pixel_values or inputs_embeds")
+
+ if inputs_embeds is None:
+ batch_size, num_channels, height, width = pixel_values.shape
+ device = pixel_values.device
+ if pixel_mask is None:
+ pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+ features = self.backbone(pixel_values, pixel_mask)
+ proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+ else:
+ batch_size = inputs_embeds.shape[0]
+ device = inputs_embeds.device
+ proj_feats = inputs_embeds

  if encoder_outputs is None:
  encoder_outputs = self.encoder(
  proj_feats,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ **kwargs,
  )
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+ elif not isinstance(encoder_outputs, BaseModelOutput):
  encoder_outputs = BaseModelOutput(
  last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if output_hidden_states else None,
- attentions=encoder_outputs[2]
- if len(encoder_outputs) > 2
- else encoder_outputs[1]
- if output_attentions
- else None,
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  )

  # Equivalent to def _get_encoder_input
  # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
  sources = []
- for level, source in enumerate(encoder_outputs[0]):
+ for level, source in enumerate(encoder_outputs.last_hidden_state):
  sources.append(self.decoder_input_proj[level](source))

  # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
  if self.config.num_feature_levels > len(sources):
  _len_sources = len(sources)
- sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1])
+ sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1])
  for i in range(_len_sources + 1, self.config.num_feature_levels):
- sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))
+ sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1]))

  # Prepare encoder inputs (by flattening)
  source_flatten = []
@@ -1771,22 +1651,9 @@ class RTDetrModel(RTDetrPreTrainedModel):
  spatial_shapes=spatial_shapes,
  spatial_shapes_list=spatial_shapes_list,
  level_start_index=level_start_index,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ **kwargs,
  )

- if not return_dict:
- enc_outputs = tuple(
- value
- for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
- if value is not None
- )
- dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
- tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
-
- return tuple_outputs
-
  return RTDetrModelOutput(
  last_hidden_state=decoder_outputs.last_hidden_state,
  intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
@@ -1828,7 +1695,7 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
  [torch.nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]
  )
  self.model.decoder.bbox_embed = nn.ModuleList(
- [RTDetrMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3) for _ in range(num_pred)]
+ [RTDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) for _ in range(num_pred)]
  )
  # if two-stage, the last class_embed and bbox_embed is for region proposal generation
  self.post_init()
@@ -1837,26 +1704,20 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
  return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]

  @auto_docstring
+ @can_return_tuple
  def forward(
  self,
  pixel_values: torch.FloatTensor,
- pixel_mask: Optional[torch.LongTensor] = None,
- encoder_outputs: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[list[dict]] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]:
+ pixel_mask: torch.LongTensor | None = None,
+ encoder_outputs: torch.FloatTensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: list[dict] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.FloatTensor] | RTDetrObjectDetectionOutput:
  r"""
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
  can choose to directly pass a flattened representation of an image.
- decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
- Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
- embedded representation.
  labels (`list[Dict]` of len `(batch_size,)`, *optional*):
  Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
  following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
@@ -1909,40 +1770,29 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
  Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48]
  Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99]
  ```"""
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
  outputs = self.model(
  pixel_values,
  pixel_mask=pixel_mask,
  encoder_outputs=encoder_outputs,
  inputs_embeds=inputs_embeds,
- decoder_inputs_embeds=decoder_inputs_embeds,
  labels=labels,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ **kwargs,
  )

- denoising_meta_values = (
- outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
- )
+ denoising_meta_values = outputs.denoising_meta_values if self.training else None

- outputs_class = outputs.intermediate_logits if return_dict else outputs[2]
- outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3]
- predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4]
- initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5]
+ outputs_class = outputs.intermediate_logits
+ outputs_coord = outputs.intermediate_reference_points
+ predicted_corners = outputs.intermediate_predicted_corners
+ initial_reference_points = outputs.initial_reference_points

  logits = outputs_class[:, -1]
  pred_boxes = outputs_coord[:, -1]

  loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
  if labels is not None:
- enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
- enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
+ enc_topk_logits = outputs.enc_topk_logits
+ enc_topk_bboxes = outputs.enc_topk_bboxes
  loss, loss_dict, auxiliary_outputs = self.loss_function(
  logits,
  labels,
@@ -1959,13 +1809,6 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
  **kwargs,
  )

- if not return_dict:
- if auxiliary_outputs is not None:
- output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
- else:
- output = (logits, pred_boxes) + outputs
- return ((loss, loss_dict) + output) if loss is not None else output
-
  return RTDetrObjectDetectionOutput(
  loss=loss,
  loss_dict=loss_dict,
@@ -1993,8 +1836,4 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
  )


- __all__ = [
- "RTDetrForObjectDetection",
- "RTDetrModel",
- "RTDetrPreTrainedModel",
- ]
+ __all__ = ["RTDetrForObjectDetection", "RTDetrModel", "RTDetrPreTrainedModel"]