transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
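For context (not part of the registry diff itself), a minimal sketch of checking which of the two versions named above is active in an environment, assuming a standard PyPI install of transformers:

import transformers

# Reports "5.0.0rc2" before the upgrade and "5.1.0" after it.
print(transformers.__version__)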
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/models/d_fine/modeling_d_fine.py

@@ -4,7 +4,6 @@
  # the file from the modular. If any change should be done, please apply the change to the
  # modular_d_fine.py file directly. One of our CI enforces this.
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # coding=utf-8
  # Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,24 +18,134 @@
19
18
  # See the License for the specific language governing permissions and
20
19
  # limitations under the License.
21
20
  import math
21
+ from collections.abc import Callable
22
22
  from dataclasses import dataclass
23
- from typing import Any, Optional, Union
24
23
 
25
24
  import torch
25
+ import torch.nn as nn
26
26
  import torch.nn.functional as F
27
- from torch import Tensor, nn
27
+ from torch import Tensor
28
28
 
29
29
  from ... import initialization as init
30
- from ...activations import ACT2CLS, ACT2FN
30
+ from ...activations import ACT2CLS
31
+ from ...backbone_utils import load_backbone
31
32
  from ...image_transforms import center_to_corners_format, corners_to_center_format
32
33
  from ...modeling_outputs import BaseModelOutput
33
- from ...modeling_utils import PreTrainedModel
34
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
35
+ from ...processing_utils import Unpack
34
36
  from ...pytorch_utils import compile_compatible_method_lru_cache
35
- from ...utils import ModelOutput, auto_docstring, is_torchdynamo_compiling, torch_int
36
- from ...utils.backbone_utils import load_backbone
37
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check, torch_int
38
+ from ...utils.generic import can_return_tuple, check_model_inputs
37
39
  from .configuration_d_fine import DFineConfig
38
40
 
39
41
 
42
+ @dataclass
43
+ @auto_docstring(
44
+ custom_intro="""
45
+ Base class for outputs of the DFineDecoder. This class adds two attributes to
46
+ BaseModelOutputWithCrossAttentions, namely:
47
+ - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
48
+ - a stacked tensor of intermediate reference points.
49
+ """
50
+ )
51
+ class DFineDecoderOutput(ModelOutput):
52
+ r"""
53
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
54
+ Stacked intermediate hidden states (output of each layer of the decoder).
55
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
56
+ Stacked intermediate logits (logits of each layer of the decoder).
57
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
58
+ Stacked intermediate reference points (reference points of each layer of the decoder).
59
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
60
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
61
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
62
+ Stacked initial reference points (initial reference points of each layer of the decoder).
63
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
64
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
65
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
66
+ used to compute the weighted average in the cross-attention heads.
67
+ """
68
+
69
+ last_hidden_state: torch.FloatTensor | None = None
70
+ intermediate_hidden_states: torch.FloatTensor | None = None
71
+ intermediate_logits: torch.FloatTensor | None = None
72
+ intermediate_reference_points: torch.FloatTensor | None = None
73
+ intermediate_predicted_corners: torch.FloatTensor | None = None
74
+ initial_reference_points: torch.FloatTensor | None = None
75
+ hidden_states: tuple[torch.FloatTensor] | None = None
76
+ attentions: tuple[torch.FloatTensor] | None = None
77
+ cross_attentions: tuple[torch.FloatTensor] | None = None
78
+
79
+
80
+ class DFineMLP(nn.Module):
81
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"):
82
+ super().__init__()
83
+ self.num_layers = num_layers
84
+ hidden_dims = [hidden_dim] * (num_layers - 1)
85
+ input_dims = [input_dim] + hidden_dims
86
+ output_dims = hidden_dims + [output_dim]
87
+ self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims))
88
+ self.act = ACT2CLS[act]()
89
+
90
+ def forward(self, stat_features: torch.Tensor) -> torch.Tensor:
91
+ for i, layer in enumerate(self.layers):
92
+ stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features)
93
+ return stat_features
94
+
95
+
96
+ class DFineGate(nn.Module):
97
+ def __init__(self, d_model: int):
98
+ super().__init__()
99
+ self.gate = nn.Linear(2 * d_model, 2 * d_model)
100
+ self.norm = nn.LayerNorm(d_model)
101
+
102
+ def forward(self, second_residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
103
+ gate_input = torch.cat([second_residual, hidden_states], dim=-1)
104
+ gates = torch.sigmoid(self.gate(gate_input))
105
+ gate1, gate2 = gates.chunk(2, dim=-1)
106
+ hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states)
107
+ return hidden_states
108
+
109
+
110
+ class DFineFrozenBatchNorm2d(nn.Module):
111
+ """
112
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
113
+
114
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which models other than
115
+ torchvision.models.resnet[18,34,50,101] produce nans.
116
+ """
117
+
118
+ def __init__(self, n):
119
+ super().__init__()
120
+ self.register_buffer("weight", torch.ones(n))
121
+ self.register_buffer("bias", torch.zeros(n))
122
+ self.register_buffer("running_mean", torch.zeros(n))
123
+ self.register_buffer("running_var", torch.ones(n))
124
+
125
+ def _load_from_state_dict(
126
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
127
+ ):
128
+ num_batches_tracked_key = prefix + "num_batches_tracked"
129
+ if num_batches_tracked_key in state_dict:
130
+ del state_dict[num_batches_tracked_key]
131
+
132
+ super()._load_from_state_dict(
133
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
134
+ )
135
+
136
+ def forward(self, x):
137
+ # move reshapes to the beginning
138
+ # to make it user-friendly
139
+ weight = self.weight.reshape(1, -1, 1, 1)
140
+ bias = self.bias.reshape(1, -1, 1, 1)
141
+ running_var = self.running_var.reshape(1, -1, 1, 1)
142
+ running_mean = self.running_mean.reshape(1, -1, 1, 1)
143
+ epsilon = 1e-5
144
+ scale = weight * (running_var + epsilon).rsqrt()
145
+ bias = bias - running_mean * scale
146
+ return x * scale + bias
147
+
148
+
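
The frozen batch norm above folds the running statistics and affine parameters into a single per-channel scale and shift. A quick sanity check (illustrative, not part of the diff) that this folding matches `nn.BatchNorm2d` in eval mode:

```python
import torch
from torch import nn

num_channels, eps = 8, 1e-5
bn = nn.BatchNorm2d(num_channels, eps=eps).eval()
with torch.no_grad():
    bn.weight.normal_()
    bn.bias.normal_()
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(2, num_channels, 4, 4)

# Same folding as DFineFrozenBatchNorm2d.forward: y = x * scale + shift
weight = bn.weight.reshape(1, -1, 1, 1)
bias = bn.bias.reshape(1, -1, 1, 1)
running_var = bn.running_var.reshape(1, -1, 1, 1)
running_mean = bn.running_mean.reshape(1, -1, 1, 1)
scale = weight * (running_var + eps).rsqrt()
shift = bias - running_mean * scale

with torch.no_grad():
    torch.testing.assert_close(x * scale + shift, bn(x), rtol=1e-4, atol=1e-5)
```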
40
149
  def multi_scale_deformable_attention_v2(
41
150
  value: Tensor,
42
151
  value_spatial_shapes: Tensor,
@@ -143,19 +252,20 @@ class DFineMultiscaleDeformableAttention(nn.Module):
143
252
  def forward(
144
253
  self,
145
254
  hidden_states: torch.Tensor,
146
- attention_mask: Optional[torch.Tensor] = None,
255
+ attention_mask: torch.Tensor | None = None,
147
256
  reference_points=None,
148
257
  encoder_hidden_states=None,
149
258
  spatial_shapes=None,
150
259
  spatial_shapes_list=None,
260
+ **kwargs: Unpack[TransformersKwargs],
151
261
  ) -> tuple[torch.Tensor, torch.Tensor]:
152
262
  batch_size, num_queries, _ = hidden_states.shape
153
263
  batch_size, sequence_length, _ = encoder_hidden_states.shape
154
264
 
155
- if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
156
- raise ValueError(
157
- "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
158
- )
265
+ torch_compilable_check(
266
+ (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == sequence_length,
267
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
268
+ )
159
269
 
160
270
  # Reshape for multi-head attention
161
271
  value = encoder_hidden_states.reshape(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
@@ -202,182 +312,485 @@ class DFineMultiscaleDeformableAttention(nn.Module):
202
312
  return output, attention_weights
203
313
 
204
314
 
205
- class DFineGate(nn.Module):
206
- def __init__(self, d_model: int):
315
+ class DFineConvNormLayer(nn.Module):
316
+ def __init__(
317
+ self,
318
+ config: DFineConfig,
319
+ in_channels: int,
320
+ out_channels: int,
321
+ kernel_size: int,
322
+ stride: int,
323
+ groups: int = 1,
324
+ padding: int | None = None,
325
+ activation: str | None = None,
326
+ ):
207
327
  super().__init__()
208
- self.gate = nn.Linear(2 * d_model, 2 * d_model)
209
- self.norm = nn.LayerNorm(d_model)
328
+ self.conv = nn.Conv2d(
329
+ in_channels,
330
+ out_channels,
331
+ kernel_size,
332
+ stride,
333
+ groups=groups,
334
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
335
+ bias=False,
336
+ )
337
+ self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
338
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
210
339
 
211
- def forward(self, second_residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
212
- gate_input = torch.cat([second_residual, hidden_states], dim=-1)
213
- gates = torch.sigmoid(self.gate(gate_input))
214
- gate1, gate2 = gates.chunk(2, dim=-1)
215
- hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states)
216
- return hidden_states
340
+ def forward(self, hidden_state):
341
+ hidden_state = self.conv(hidden_state)
342
+ hidden_state = self.norm(hidden_state)
343
+ hidden_state = self.activation(hidden_state)
344
+ return hidden_state
217
345
 
218
346
 
219
- class DFineMultiheadAttention(nn.Module):
347
+ class DFineRepVggBlock(nn.Module):
348
+ """
349
+ RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
220
350
  """
221
- Multi-headed attention from 'Attention Is All You Need' paper.
222
351
 
223
- Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
352
+ def __init__(self, config: DFineConfig, in_channels: int, out_channels: int):
353
+ super().__init__()
354
+
355
+ activation = config.activation_function
356
+ hidden_channels = in_channels
357
+ self.conv1 = DFineConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1)
358
+ self.conv2 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0)
359
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
360
+
361
+ def forward(self, x):
362
+ y = self.conv1(x) + self.conv2(x)
363
+ return self.activation(y)
364
+
365
+
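
The two-branch `DFineRepVggBlock` can be collapsed into a single 3x3 convolution at inference time because a 1x1 convolution is equivalent to a 3x3 convolution whose kernel is zero-padded to 3x3. A minimal demonstration of that identity (illustrative only; the BatchNorm folding step is omitted):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 16, 16)
w_1x1 = torch.randn(8, 8, 1, 1)

out_1x1 = F.conv2d(x, w_1x1, padding=0)
w_3x3 = F.pad(w_1x1, [1, 1, 1, 1])        # zero-pad the 1x1 kernel to 3x3
out_3x3 = F.conv2d(x, w_3x3, padding=1)   # padding=1 keeps the spatial size

torch.testing.assert_close(out_1x1, out_3x3, rtol=1e-5, atol=1e-5)
```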
366
+ class DFineCSPRepLayer(nn.Module):
367
+ """
368
+ Cross Stage Partial (CSP) network layer with RepVGG blocks.
224
369
  """
225
370
 
226
371
  def __init__(
227
- self,
228
- embed_dim: int,
229
- num_heads: int,
230
- dropout: float = 0.0,
231
- bias: bool = True,
372
+ self, config: DFineConfig, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0
232
373
  ):
233
374
  super().__init__()
234
- self.embed_dim = embed_dim
235
- self.num_heads = num_heads
236
- self.dropout = dropout
237
- self.head_dim = embed_dim // num_heads
238
- if self.head_dim * num_heads != self.embed_dim:
239
- raise ValueError(
240
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
241
- f" {num_heads})."
242
- )
243
- self.scaling = self.head_dim**-0.5
375
+ activation = config.activation_function
244
376
 
245
- self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
246
- self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
247
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
248
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
377
+ hidden_channels = int(out_channels * expansion)
378
+ self.conv1 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
379
+ self.conv2 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
380
+ self.bottlenecks = nn.ModuleList(
381
+ [DFineRepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)]
382
+ )
383
+ if hidden_channels != out_channels:
384
+ self.conv3 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
385
+ else:
386
+ self.conv3 = nn.Identity()
249
387
 
250
- def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
251
- return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
388
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
389
+ hidden_state_1 = self.conv1(hidden_state)
390
+ for bottleneck in self.bottlenecks:
391
+ hidden_state_1 = bottleneck(hidden_state_1)
392
+ hidden_state_2 = self.conv2(hidden_state)
393
+ hidden_state_3 = self.conv3(hidden_state_1 + hidden_state_2)
394
+ return hidden_state_3
252
395
 
253
- def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
254
- return tensor if position_embeddings is None else tensor + position_embeddings
255
396
 
256
- def forward(
257
- self,
258
- hidden_states: torch.Tensor,
259
- attention_mask: Optional[torch.Tensor] = None,
260
- position_embeddings: Optional[torch.Tensor] = None,
261
- output_attentions: bool = False,
262
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
263
- """Input shape: Batch x Time x Channel"""
397
+ class DFineRepNCSPELAN4(nn.Module):
398
+ def __init__(self, config: DFineConfig, act: str = "silu", numb_blocks: int = 3):
399
+ super().__init__()
400
+ conv1_dim = config.encoder_hidden_dim * 2
401
+ conv2_dim = config.encoder_hidden_dim
402
+ conv3_dim = config.encoder_hidden_dim * 2
403
+ conv4_dim = round(config.hidden_expansion * config.encoder_hidden_dim // 2)
404
+ self.conv_dim = conv3_dim // 2
405
+ self.conv1 = DFineConvNormLayer(config, conv1_dim, conv3_dim, 1, 1, activation=act)
406
+ self.csp_rep1 = DFineCSPRepLayer(config, conv3_dim // 2, conv4_dim, num_blocks=numb_blocks)
407
+ self.conv2 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
408
+ self.csp_rep2 = DFineCSPRepLayer(config, conv4_dim, conv4_dim, num_blocks=numb_blocks)
409
+ self.conv3 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
410
+ self.conv4 = DFineConvNormLayer(config, conv3_dim + (2 * conv4_dim), conv2_dim, 1, 1, activation=act)
264
411
 
265
- batch_size, target_len, embed_dim = hidden_states.size()
266
- # add position embeddings to the hidden states before projecting to queries and keys
267
- if position_embeddings is not None:
268
- hidden_states_original = hidden_states
269
- hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
412
+ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
413
+ # Split initial features into two branches after first convolution
414
+ split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1))
415
+
416
+ # Process branches sequentially
417
+ branch1 = self.csp_rep1(split_features[-1])
418
+ branch1 = self.conv2(branch1)
419
+ branch2 = self.csp_rep2(branch1)
420
+ branch2 = self.conv3(branch2)
270
421
 
271
- # get queries, keys and values
272
- query_states = self.q_proj(hidden_states) * self.scaling
273
- key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size)
274
- value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size)
422
+ split_features.extend([branch1, branch2])
423
+ merged_features = torch.cat(split_features, 1)
424
+ merged_features = self.conv4(merged_features)
425
+ return merged_features
275
426
 
276
- proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
277
- query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape)
278
- key_states = key_states.view(*proj_shape)
279
- value_states = value_states.view(*proj_shape)
280
427
 
281
- source_len = key_states.size(1)
428
+ class DFineSCDown(nn.Module):
429
+ def __init__(self, config: DFineConfig, kernel_size: int, stride: int):
430
+ super().__init__()
431
+ self.conv1 = DFineConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1)
432
+ self.conv2 = DFineConvNormLayer(
433
+ config,
434
+ config.encoder_hidden_dim,
435
+ config.encoder_hidden_dim,
436
+ kernel_size,
437
+ stride,
438
+ config.encoder_hidden_dim,
439
+ )
282
440
 
283
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
441
+ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
442
+ input_features = self.conv1(input_features)
443
+ input_features = self.conv2(input_features)
444
+ return input_features
284
445
 
285
- if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
286
- raise ValueError(
287
- f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
288
- f" {attn_weights.size()}"
289
- )
290
446
 
291
- # expand attention_mask
292
- if attention_mask is not None:
293
- # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
294
- attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size())
447
+ def eager_attention_forward(
448
+ module: nn.Module,
449
+ query: torch.Tensor,
450
+ key: torch.Tensor,
451
+ value: torch.Tensor,
452
+ attention_mask: torch.Tensor | None,
453
+ scaling: float | None = None,
454
+ dropout: float = 0.0,
455
+ **kwargs: Unpack[TransformersKwargs],
456
+ ):
457
+ if scaling is None:
458
+ scaling = query.size(-1) ** -0.5
295
459
 
296
- if attention_mask is not None:
297
- if attention_mask.size() != (batch_size, 1, target_len, source_len):
298
- raise ValueError(
299
- f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
300
- f" {attention_mask.size()}"
301
- )
302
- if attention_mask.dtype == torch.bool:
303
- attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_(
304
- attention_mask, -torch.inf
305
- )
306
- attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
307
- attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)
308
-
309
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
310
-
311
- if output_attentions:
312
- # this operation is a bit awkward, but it's required to
313
- # make sure that attn_weights keeps its gradient.
314
- # In order to do so, attn_weights have to reshaped
315
- # twice and have to be reused in the following
316
- attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
317
- attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
318
- else:
319
- attn_weights_reshaped = None
460
+ # Take the dot product between "query" and "key" to get the raw attention scores.
461
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
320
462
 
321
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
463
+ if attention_mask is not None:
464
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
465
+ attn_weights = attn_weights + attention_mask
322
466
 
323
- attn_output = torch.bmm(attn_probs, value_states)
467
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
468
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
324
469
 
325
- if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
326
- raise ValueError(
327
- f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
328
- f" {attn_output.size()}"
329
- )
470
+ attn_output = torch.matmul(attn_weights, value)
471
+ attn_output = attn_output.transpose(1, 2).contiguous()
330
472
 
331
- attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
332
- attn_output = attn_output.transpose(1, 2)
333
- attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
473
+ return attn_output, attn_weights
334
474
 
335
- attn_output = self.out_proj(attn_output)
336
475
 
337
- return attn_output, attn_weights_reshaped
476
+ class DFineSelfAttention(nn.Module):
477
+ """
478
+ Multi-headed self-attention from 'Attention Is All You Need' paper.
338
479
 
480
+ In D-FINE, position embeddings are added to both queries and keys (but not values) in self-attention.
481
+ """
339
482
 
340
- class DFineDecoderLayer(nn.Module):
341
- def __init__(self, config: DFineConfig):
483
+ def __init__(
484
+ self,
485
+ config: DFineConfig,
486
+ hidden_size: int,
487
+ num_attention_heads: int,
488
+ dropout: float = 0.0,
489
+ bias: bool = True,
490
+ ):
342
491
  super().__init__()
343
- # self-attention
344
- self.self_attn = DFineMultiheadAttention(
345
- embed_dim=config.d_model,
346
- num_heads=config.decoder_attention_heads,
347
- dropout=config.attention_dropout,
348
- )
349
- self.dropout = config.dropout
350
- self.activation_fn = ACT2FN[config.decoder_activation_function]
351
- self.activation_dropout = config.activation_dropout
492
+ self.config = config
493
+ self.head_dim = hidden_size // num_attention_heads
494
+ self.scaling = self.head_dim**-0.5
495
+ self.attention_dropout = dropout
496
+ self.is_causal = False
352
497
 
353
- self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
498
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
499
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
500
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
501
+ self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
354
502
 
355
- # override the encoder attention module with d-fine version
503
+ def forward(
504
+ self,
505
+ hidden_states: torch.Tensor,
506
+ attention_mask: torch.Tensor | None = None,
507
+ position_embeddings: torch.Tensor | None = None,
508
+ **kwargs: Unpack[TransformersKwargs],
509
+ ) -> tuple[torch.Tensor, torch.Tensor]:
510
+ """
511
+ Position embeddings are added to both queries and keys (but not values).
512
+ """
513
+ input_shape = hidden_states.shape[:-1]
514
+ hidden_shape = (*input_shape, -1, self.head_dim)
515
+
516
+ query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
517
+
518
+ query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
519
+ key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
520
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
521
+
522
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
523
+ self.config._attn_implementation, eager_attention_forward
524
+ )
525
+
526
+ attn_output, attn_weights = attention_interface(
527
+ self,
528
+ query_states,
529
+ key_states,
530
+ value_states,
531
+ attention_mask,
532
+ dropout=0.0 if not self.training else self.attention_dropout,
533
+ scaling=self.scaling,
534
+ **kwargs,
535
+ )
536
+
537
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
538
+ attn_output = self.o_proj(attn_output)
539
+ return attn_output, attn_weights
540
+
541
+
542
+ class DFineEncoderLayer(nn.Module):
543
+ def __init__(self, config: DFineConfig):
544
+ super().__init__()
545
+ self.normalize_before = config.normalize_before
546
+ self.hidden_size = config.encoder_hidden_dim
547
+
548
+ # self-attention
549
+ self.self_attn = DFineSelfAttention(
550
+ config=config,
551
+ hidden_size=self.hidden_size,
552
+ num_attention_heads=config.num_attention_heads,
553
+ dropout=config.dropout,
554
+ )
555
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
556
+ self.dropout = config.dropout
557
+ self.mlp = DFineMLP(
558
+ self.hidden_size, config.encoder_ffn_dim, self.hidden_size, 2, config.encoder_activation_function
559
+ )
560
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
561
+
562
+ def forward(
563
+ self,
564
+ hidden_states: torch.Tensor,
565
+ attention_mask: torch.Tensor,
566
+ spatial_position_embeddings: torch.Tensor | None = None,
567
+ **kwargs: Unpack[TransformersKwargs],
568
+ ) -> torch.Tensor:
569
+ """
570
+ Args:
571
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
572
+ attention_mask (`torch.FloatTensor`): attention mask of size
573
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
574
+ values.
575
+ spatial_position_embeddings (`torch.FloatTensor`, *optional*):
576
+ Spatial position embeddings (2D positional encodings of image locations), to be added to both
577
+ the queries and keys in self-attention (but not to values).
578
+ """
579
+ residual = hidden_states
580
+ if self.normalize_before:
581
+ hidden_states = self.self_attn_layer_norm(hidden_states)
582
+
583
+ hidden_states, _ = self.self_attn(
584
+ hidden_states=hidden_states,
585
+ attention_mask=attention_mask,
586
+ position_embeddings=spatial_position_embeddings,
587
+ **kwargs,
588
+ )
589
+
590
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
591
+ hidden_states = residual + hidden_states
592
+ if not self.normalize_before:
593
+ hidden_states = self.self_attn_layer_norm(hidden_states)
594
+
595
+ if self.normalize_before:
596
+ hidden_states = self.final_layer_norm(hidden_states)
597
+ residual = hidden_states
598
+
599
+ hidden_states = self.mlp(hidden_states)
600
+
601
+ hidden_states = residual + hidden_states
602
+ if not self.normalize_before:
603
+ hidden_states = self.final_layer_norm(hidden_states)
604
+
605
+ if self.training:
606
+ if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
607
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
608
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
609
+
610
+ return hidden_states
611
+
612
+
613
+ class DFineSinePositionEmbedding(nn.Module):
614
+ """
615
+ 2D sinusoidal position embedding used in the RT-DETR hybrid encoder.
616
+ """
617
+
618
+ def __init__(self, embed_dim: int = 256, temperature: int = 10000):
619
+ super().__init__()
620
+ self.embed_dim = embed_dim
621
+ self.temperature = temperature
622
+
623
+ @compile_compatible_method_lru_cache(maxsize=32)
624
+ def forward(
625
+ self,
626
+ width: int,
627
+ height: int,
628
+ device: torch.device | str,
629
+ dtype: torch.dtype,
630
+ ) -> torch.Tensor:
631
+ """
632
+ Generate 2D sinusoidal position embeddings.
633
+
634
+ Returns:
635
+ Position embeddings of shape (1, height*width, embed_dim)
636
+ """
637
+ grid_w = torch.arange(torch_int(width), device=device).to(dtype)
638
+ grid_h = torch.arange(torch_int(height), device=device).to(dtype)
639
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
640
+ if self.embed_dim % 4 != 0:
641
+ raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
642
+ pos_dim = self.embed_dim // 4
643
+ omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
644
+ omega = 1.0 / (self.temperature**omega)
645
+
646
+ out_w = grid_w.flatten()[..., None] @ omega[None]
647
+ out_h = grid_h.flatten()[..., None] @ omega[None]
648
+
649
+ return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
650
+
651
+
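
A small shape and property check for `DFineSinePositionEmbedding` (illustrative, assuming the class is in scope): it returns one `embed_dim`-dimensional vector per spatial location, and each sin/cos pair lies on the unit circle.

```python
import torch

pos_embed = DFineSinePositionEmbedding(embed_dim=256, temperature=10000)
emb = pos_embed(width=20, height=15, device="cpu", dtype=torch.float32)
print(emb.shape)   # torch.Size([1, 300, 256]) -> one vector per height * width position

pos_dim = 256 // 4
h_sin, h_cos = emb[0, :, :pos_dim], emb[0, :, pos_dim : 2 * pos_dim]
torch.testing.assert_close(h_sin**2 + h_cos**2, torch.ones_like(h_sin))
```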
652
+ class DFineAIFILayer(nn.Module):
653
+ """
654
+ AIFI (Attention-based Intra-scale Feature Interaction) layer used in the RT-DETR hybrid encoder.
655
+ """
656
+
657
+ def __init__(self, config: DFineConfig):
658
+ super().__init__()
659
+ self.config = config
660
+ self.encoder_hidden_dim = config.encoder_hidden_dim
661
+ self.eval_size = config.eval_size
662
+
663
+ self.position_embedding = DFineSinePositionEmbedding(
664
+ embed_dim=self.encoder_hidden_dim,
665
+ temperature=config.positional_encoding_temperature,
666
+ )
667
+ self.layers = nn.ModuleList([DFineEncoderLayer(config) for _ in range(config.encoder_layers)])
668
+
669
+ def forward(
670
+ self,
671
+ hidden_states: torch.Tensor,
672
+ **kwargs: Unpack[TransformersKwargs],
673
+ ) -> torch.Tensor:
674
+ """
675
+ Args:
676
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
677
+ Feature map to process.
678
+ """
679
+ batch_size = hidden_states.shape[0]
680
+ height, width = hidden_states.shape[2:]
681
+
682
+ hidden_states = hidden_states.flatten(2).permute(0, 2, 1)
683
+
684
+ if self.training or self.eval_size is None:
685
+ pos_embed = self.position_embedding(
686
+ width=width,
687
+ height=height,
688
+ device=hidden_states.device,
689
+ dtype=hidden_states.dtype,
690
+ )
691
+ else:
692
+ pos_embed = None
693
+
694
+ for layer in self.layers:
695
+ hidden_states = layer(
696
+ hidden_states,
697
+ attention_mask=None,
698
+ spatial_position_embeddings=pos_embed,
699
+ **kwargs,
700
+ )
701
+
702
+ hidden_states = (
703
+ hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous()
704
+ )
705
+
706
+ return hidden_states
707
+
708
+
709
+ class DFineIntegral(nn.Module):
710
+ """
711
+ A static layer that calculates integral results from a distribution.
712
+
713
+ This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`,
714
+ where Pr(n) is the softmax probability vector representing the discrete
715
+ distribution, and W(n) is the non-uniform Weighting Function.
716
+
717
+ Args:
718
+ max_num_bins (int): Max number of the discrete bins. Default is 32.
719
+ It can be adjusted based on the dataset or task requirements.
720
+ """
721
+
722
+ def __init__(self, config: DFineConfig):
723
+ super().__init__()
724
+ self.max_num_bins = config.max_num_bins
725
+
726
+ def forward(self, pred_corners: torch.Tensor, project: torch.Tensor) -> torch.Tensor:
727
+ batch_size, num_queries, _ = pred_corners.shape
728
+ pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1)
729
+ pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4)
730
+ pred_corners = pred_corners.reshape(batch_size, num_queries, -1)
731
+ return pred_corners
732
+
733
+
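
The docstring formula `sum{Pr(n) * W(n)}` is simply an expectation of the bin values under the softmaxed distribution. An equivalent stand-alone computation (illustrative; the `project` vector here is a hypothetical stand-in for the model's weighting function):

```python
import torch
import torch.nn.functional as F

max_num_bins = 32
batch_size, num_queries = 2, 300
project = torch.linspace(-0.5, 0.5, max_num_bins + 1)   # hypothetical W(n)

# One distribution over bins per box edge (4 edges per query).
pred_corners = torch.randn(batch_size, num_queries, 4 * (max_num_bins + 1))
prob = F.softmax(pred_corners.reshape(-1, max_num_bins + 1), dim=1)   # Pr(n)
edges = (prob * project).sum(dim=1).reshape(batch_size, num_queries, 4)
print(edges.shape)   # torch.Size([2, 300, 4]) -> one refined value per edge
```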
734
+ class DFineLQE(nn.Module):
735
+ def __init__(self, config: DFineConfig):
736
+ super().__init__()
737
+ self.top_prob_values = config.top_prob_values
738
+ self.max_num_bins = config.max_num_bins
739
+ self.reg_conf = DFineMLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers)
740
+
741
+ def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor:
742
+ batch_size, length, _ = pred_corners.size()
743
+ prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1)
744
+ prob_topk, _ = prob.topk(self.top_prob_values, dim=-1)
745
+ stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
746
+ quality_score = self.reg_conf(stat.reshape(batch_size, length, -1))
747
+ scores = scores + quality_score
748
+ return scores
749
+
750
+
751
+ class DFineDecoderLayer(nn.Module):
752
+ def __init__(self, config: DFineConfig):
753
+ super().__init__()
754
+ self.hidden_size = config.d_model
755
+
756
+ # self-attention
757
+ self.self_attn = DFineSelfAttention(
758
+ config=config,
759
+ hidden_size=self.hidden_size,
760
+ num_attention_heads=config.decoder_attention_heads,
761
+ dropout=config.attention_dropout,
762
+ )
763
+ self.dropout = config.dropout
764
+
765
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
766
+
767
+ # override the encoder attention module with d-fine version
356
768
  self.encoder_attn = DFineMultiscaleDeformableAttention(config=config)
357
- # feedforward neural networks
358
- self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
359
- self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
360
- self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
769
+ self.mlp = DFineMLP(
770
+ self.hidden_size, config.decoder_ffn_dim, self.hidden_size, 2, config.decoder_activation_function
771
+ )
772
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
361
773
  # gate
362
774
  self.gateway = DFineGate(config.d_model)
363
775
 
364
776
  def forward(
365
777
  self,
366
778
  hidden_states: torch.Tensor,
367
- position_embeddings: Optional[torch.Tensor] = None,
779
+ position_embeddings: torch.Tensor | None = None,
368
780
  reference_points=None,
369
781
  spatial_shapes=None,
370
782
  spatial_shapes_list=None,
371
- encoder_hidden_states: Optional[torch.Tensor] = None,
372
- encoder_attention_mask: Optional[torch.Tensor] = None,
373
- output_attentions: Optional[bool] = False,
374
- ) -> tuple[torch.Tensor, Any, Any]:
783
+ encoder_hidden_states: torch.Tensor | None = None,
784
+ encoder_attention_mask: torch.Tensor | None = None,
785
+ **kwargs: Unpack[TransformersKwargs],
786
+ ) -> torch.Tensor:
375
787
  """
376
788
  Args:
377
789
  hidden_states (`torch.FloatTensor`):
378
- Input to the layer of shape `(seq_len, batch, embed_dim)`.
379
- position_embeddings (`torch.FloatTensor`, *optional*):
380
- Position embeddings that are added to the queries and keys in the self-attention layer.
790
+ Input to the layer of shape `(batch, seq_len, hidden_size)`.
791
+ position_embeddings (`torch.FloatTensor`, *optional*):
792
+ Position embeddings for the object query slots. These are added to both queries and keys
793
+ in the self-attention layer (not values).
381
794
  reference_points (`torch.FloatTensor`, *optional*):
382
795
  Reference points.
383
796
  spatial_shapes (`torch.LongTensor`, *optional*):
@@ -385,55 +798,65 @@ class DFineDecoderLayer(nn.Module):
385
798
  level_start_index (`torch.LongTensor`, *optional*):
386
799
  Level start index.
387
800
  encoder_hidden_states (`torch.FloatTensor`):
388
- cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
801
+ cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
389
802
  encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
390
803
  `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
391
804
  values.
392
- output_attentions (`bool`, *optional*):
393
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
394
- returned tensors for more detail.
395
805
  """
806
+ residual = hidden_states
807
+
396
808
  # Self Attention
397
- hidden_states_2, self_attn_weights = self.self_attn(
809
+ hidden_states, _ = self.self_attn(
398
810
  hidden_states=hidden_states,
399
811
  attention_mask=encoder_attention_mask,
400
812
  position_embeddings=position_embeddings,
401
- output_attentions=output_attentions,
813
+ **kwargs,
402
814
  )
403
815
 
404
- hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
405
- hidden_states = hidden_states + hidden_states_2
816
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
817
+ hidden_states = residual + hidden_states
406
818
  hidden_states = self.self_attn_layer_norm(hidden_states)
819
+
407
820
  residual = hidden_states
408
821
 
409
822
  # Cross-Attention
410
- cross_attn_weights = None
411
823
  hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings
412
- hidden_states_2, cross_attn_weights = self.encoder_attn(
824
+ hidden_states, _ = self.encoder_attn(
413
825
  hidden_states=hidden_states,
414
826
  encoder_hidden_states=encoder_hidden_states,
415
827
  reference_points=reference_points,
416
828
  spatial_shapes=spatial_shapes,
417
829
  spatial_shapes_list=spatial_shapes_list,
418
830
  )
419
-
420
- hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
421
- hidden_states = self.gateway(residual, hidden_states_2)
831
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
832
+ hidden_states = self.gateway(residual, hidden_states)
422
833
 
423
834
  # Fully Connected
424
- hidden_states_2 = self.activation_fn(self.fc1(hidden_states))
425
- hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.activation_dropout, training=self.training)
426
- hidden_states_2 = self.fc2(hidden_states_2)
427
- hidden_states_2 = nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training)
428
- hidden_states = hidden_states + hidden_states_2
835
+ residual = hidden_states
836
+ hidden_states = self.mlp(hidden_states)
837
+ hidden_states = residual + hidden_states
429
838
  hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504))
430
839
 
431
- outputs = (hidden_states,)
840
+ return hidden_states
841
+
842
+
843
+ class DFineMLPPredictionHead(nn.Module):
844
+ """
845
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
846
+ height and width of a bounding box w.r.t. an image.
847
+
848
+ """
432
849
 
433
- if output_attentions:
434
- outputs += (self_attn_weights, cross_attn_weights)
850
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
851
+ super().__init__()
852
+ self.num_layers = num_layers
853
+ h = [hidden_dim] * (num_layers - 1)
854
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
435
855
 
436
- return outputs
856
+ def forward(self, x):
857
+ for i, layer in enumerate(self.layers):
858
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
859
+ return x
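
Usage sketch for `DFineMLPPredictionHead` (assuming the class above is in scope): decoder queries go in and raw box values come out; the sigmoid shown here to map them into `[0, 1]` is applied by the calling code, not by the head itself.

```python
import torch

head = DFineMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
queries = torch.randn(1, 300, 256)   # (batch, num_queries, d_model)
boxes = head(queries).sigmoid()      # normalized (cx, cy, w, h)
print(boxes.shape)                   # torch.Size([1, 300, 4])
```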
437
860
 
438
861
 
439
862
  @auto_docstring
@@ -443,6 +866,10 @@ class DFinePreTrainedModel(PreTrainedModel):
443
866
  main_input_name = "pixel_values"
444
867
  input_modalities = ("image",)
445
868
  _no_split_modules = [r"DFineHybridEncoder", r"DFineDecoderLayer"]
869
+ _supports_sdpa = True
870
+ _supports_flash_attn = True
871
+ _supports_attention_backend = True
872
+ _supports_flex_attn = True
446
873
 
447
874
  @torch.no_grad()
448
875
  def _init_weights(self, module):
@@ -520,67 +947,102 @@ class DFinePreTrainedModel(PreTrainedModel):
520
947
  init.xavier_uniform_(module.denoising_class_embed.weight)
521
948
 
522
949
 
523
- class DFineIntegral(nn.Module):
950
+ class DFineHybridEncoder(DFinePreTrainedModel):
524
951
  """
525
- A static layer that calculates integral results from a distribution.
526
-
527
- This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`,
528
- where Pr(n) is the softmax probability vector representing the discrete
529
- distribution, and W(n) is the non-uniform Weighting Function.
952
+ Hybrid encoder consisting of AIFI (Attention-based Intra-scale Feature Interaction) layers,
953
+ a top-down Feature Pyramid Network (FPN) and a bottom-up Path Aggregation Network (PAN).
954
+ More details in the paper: https://huggingface.co/papers/2304.08069
530
955
 
531
956
  Args:
532
- max_num_bins (int): Max number of the discrete bins. Default is 32.
533
- It can be adjusted based on the dataset or task requirements.
957
+ config: DFineConfig
534
958
  """
535
959
 
960
+ _can_record_outputs = {
961
+ "hidden_states": DFineAIFILayer,
962
+ "attentions": DFineSelfAttention,
963
+ }
964
+
536
965
  def __init__(self, config: DFineConfig):
537
- super().__init__()
538
- self.max_num_bins = config.max_num_bins
966
+ super().__init__(config)
967
+ self.config = config
968
+ self.in_channels = config.encoder_in_channels
969
+ self.num_fpn_stages = len(self.in_channels) - 1
970
+ self.feat_strides = config.feat_strides
971
+ self.encoder_hidden_dim = config.encoder_hidden_dim
972
+ self.encode_proj_layers = config.encode_proj_layers
973
+ self.positional_encoding_temperature = config.positional_encoding_temperature
974
+ self.eval_size = config.eval_size
975
+ self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
976
+ self.out_strides = self.feat_strides
539
977
 
540
- def forward(self, pred_corners: torch.Tensor, project: torch.Tensor) -> torch.Tensor:
541
- batch_size, num_queries, _ = pred_corners.shape
542
- pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1)
543
- pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4)
544
- pred_corners = pred_corners.reshape(batch_size, num_queries, -1)
545
- return pred_corners
978
+ # AIFI (Attention-based Intra-scale Feature Interaction) layers
979
+ self.aifi = nn.ModuleList([DFineAIFILayer(config) for _ in range(len(self.encode_proj_layers))])
546
980
 
981
+ # top-down fpn
982
+ self.lateral_convs = nn.ModuleList()
983
+ self.fpn_blocks = nn.ModuleList()
984
+ for _ in range(len(self.in_channels) - 1, 0, -1):
985
+ lateral_layer = DFineConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1)
986
+ self.lateral_convs.append(lateral_layer)
987
+ num_blocks = round(3 * config.depth_mult)
988
+ fpn_layer = DFineRepNCSPELAN4(config, numb_blocks=num_blocks)
989
+ self.fpn_blocks.append(fpn_layer)
547
990
 
548
- @dataclass
549
- @auto_docstring(
550
- custom_intro="""
551
- Base class for outputs of the DFineDecoder. This class adds two attributes to
552
- BaseModelOutputWithCrossAttentions, namely:
553
- - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
554
- - a stacked tensor of intermediate reference points.
555
- """
556
- )
557
- class DFineDecoderOutput(ModelOutput):
558
- r"""
559
- intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
560
- Stacked intermediate hidden states (output of each layer of the decoder).
561
- intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
562
- Stacked intermediate logits (logits of each layer of the decoder).
563
- intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
564
- Stacked intermediate reference points (reference points of each layer of the decoder).
565
- intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
566
- Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
567
- initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
568
- Stacked initial reference points (initial reference points of each layer of the decoder).
569
- cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
570
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
571
- sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
572
- used to compute the weighted average in the cross-attention heads.
573
- """
991
+ # bottom-up pan
992
+ self.downsample_convs = nn.ModuleList()
993
+ self.pan_blocks = nn.ModuleList()
994
+ for _ in range(len(self.in_channels) - 1):
995
+ self.downsample_convs.append(DFineSCDown(config, 3, 2))
996
+ num_blocks = round(3 * config.depth_mult)
997
+ self.pan_blocks.append(DFineRepNCSPELAN4(config, numb_blocks=num_blocks))
998
+
999
+ self.post_init()
1000
+
1001
+ @check_model_inputs(tie_last_hidden_states=False)
1002
+ def forward(
1003
+ self,
1004
+ inputs_embeds=None,
1005
+ **kwargs: Unpack[TransformersKwargs],
1006
+ ) -> BaseModelOutput:
1007
+ r"""
1008
+ Args:
1009
+ inputs_embeds (`list[torch.FloatTensor]` of shape `(batch_size, num_channels, height, width)`):
1010
+ Projected multi-scale feature maps (output of the backbone + input projection layers) that are passed to the encoder.
1011
+ """
1012
+ feature_maps = inputs_embeds
574
1013
 
575
- last_hidden_state: Optional[torch.FloatTensor] = None
576
- intermediate_hidden_states: Optional[torch.FloatTensor] = None
577
- intermediate_logits: Optional[torch.FloatTensor] = None
578
- intermediate_reference_points: Optional[torch.FloatTensor] = None
579
- intermediate_predicted_corners: Optional[torch.FloatTensor] = None
580
- initial_reference_points: Optional[torch.FloatTensor] = None
581
- hidden_states: Optional[tuple[torch.FloatTensor]] = None
582
- attentions: Optional[tuple[torch.FloatTensor]] = None
583
- cross_attentions: Optional[tuple[torch.FloatTensor]] = None
1014
+ # AIFI: Apply transformer encoder to specified feature levels
1015
+ if self.config.encoder_layers > 0:
1016
+ for i, enc_ind in enumerate(self.encode_proj_layers):
1017
+ feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs)
1018
+
1019
+ # top-down FPN
1020
+ fpn_feature_maps = [feature_maps[-1]]
1021
+ for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
1022
+ backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1]
1023
+ top_fpn_feature_map = fpn_feature_maps[-1]
1024
+ # apply lateral block
1025
+ top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
1026
+ fpn_feature_maps[-1] = top_fpn_feature_map
1027
+ # apply fpn block
1028
+ top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
1029
+ fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
1030
+ new_fpn_feature_map = fpn_block(fused_feature_map)
1031
+ fpn_feature_maps.append(new_fpn_feature_map)
1032
+
1033
+ fpn_feature_maps.reverse()
1034
+
1035
+ # bottom-up PAN
1036
+ pan_feature_maps = [fpn_feature_maps[0]]
1037
+ for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
1038
+ top_pan_feature_map = pan_feature_maps[-1]
1039
+ fpn_feature_map = fpn_feature_maps[idx + 1]
1040
+ downsampled_feature_map = downsample_conv(top_pan_feature_map)
1041
+ fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
1042
+ new_pan_feature_map = pan_block(fused_feature_map)
1043
+ pan_feature_maps.append(new_pan_feature_map)
1044
+
1045
+ return BaseModelOutput(last_hidden_state=pan_feature_maps)
584
1046
 
585
1047
 
586
1048
  def inverse_sigmoid(x, eps=1e-5):
@@ -648,6 +1110,12 @@ class DFineDecoder(DFinePreTrainedModel):
648
1110
  to improve bounding box accuracy and robustness.
649
1111
  """
650
1112
 
1113
+ _can_record_outputs = {
1114
+ "hidden_states": DFineDecoderLayer,
1115
+ "attentions": DFineSelfAttention,
1116
+ "cross_attentions": DFineMultiscaleDeformableAttention,
1117
+ }
1118
+
651
1119
  def __init__(self, config: DFineConfig):
652
1120
  super().__init__(config)
653
1121
  self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx
@@ -657,7 +1125,7 @@ class DFineDecoder(DFinePreTrainedModel):
657
1125
  [DFineDecoderLayer(config) for _ in range(config.decoder_layers)]
658
1126
  + [DFineDecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)]
659
1127
  )
660
- self.query_pos_head = DFineMLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2)
1128
+ self.query_pos_head = DFineMLPPredictionHead(4, 2 * config.d_model, config.d_model, num_layers=2)
661
1129
 
662
1130
  # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
663
1131
  self.bbox_embed = None
@@ -675,6 +1143,7 @@ class DFineDecoder(DFinePreTrainedModel):
675
1143
  # Initialize weights and apply final processing
676
1144
  self.post_init()
677
1145
 
1146
+ @check_model_inputs()
678
1147
  def forward(
679
1148
  self,
680
1149
  encoder_hidden_states: torch.Tensor,
@@ -683,12 +1152,9 @@ class DFineDecoder(DFinePreTrainedModel):
683
1152
  spatial_shapes,
684
1153
  level_start_index=None,
685
1154
  spatial_shapes_list=None,
686
- output_hidden_states=None,
687
1155
  encoder_attention_mask=None,
688
1156
  memory_mask=None,
689
- output_attentions=None,
690
- return_dict=None,
691
- **kwargs,
1157
+ **kwargs: Unpack[TransformersKwargs],
692
1158
  ) -> DFineDecoderOutput:
693
1159
  r"""
694
1160
  Args:
@@ -702,39 +1168,17 @@ class DFineDecoder(DFinePreTrainedModel):
702
1168
  in `[0, 1]`:
703
1169
  - 1 for pixels that are real (i.e. **not masked**),
704
1170
  - 0 for pixels that are padding (i.e. **masked**).
705
- position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
706
- Position embeddings that are added to the queries and keys in each self-attention layer.
707
1171
  reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
708
1172
  Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
709
1173
  spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
710
1174
  Spatial shapes of the feature maps.
711
1175
  level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
712
1176
  Indexes for the start of each feature level. In range `[0, sequence_length]`.
713
- valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
714
- Ratio of valid area in each feature level.
715
-
716
- output_attentions (`bool`, *optional*):
717
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
718
- returned tensors for more detail.
719
- output_hidden_states (`bool`, *optional*):
720
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
721
- for more detail.
722
- return_dict (`bool`, *optional*):
723
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
724
1177
  """
725
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
726
- output_hidden_states = (
727
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
728
- )
729
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
730
-
731
1178
  if inputs_embeds is not None:
732
1179
  hidden_states = inputs_embeds
733
1180
 
734
1181
  # decoder layers
735
- all_hidden_states = () if output_hidden_states else None
736
- all_self_attns = () if output_attentions else None
737
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
738
1182
  intermediate = ()
739
1183
  intermediate_reference_points = ()
740
1184
  intermediate_logits = ()
@@ -750,25 +1194,22 @@ class DFineDecoder(DFinePreTrainedModel):
750
1194
  ref_points_input = ref_points_detach.unsqueeze(2)
751
1195
  query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10)
752
1196
 
753
- if output_hidden_states:
754
- all_hidden_states += (hidden_states,)
755
-
756
- output = decoder_layer(
757
- hidden_states=hidden_states,
1197
+ hidden_states = decoder_layer(
1198
+ hidden_states,
758
1199
  position_embeddings=query_pos_embed,
759
1200
  reference_points=ref_points_input,
760
1201
  spatial_shapes=spatial_shapes,
761
1202
  spatial_shapes_list=spatial_shapes_list,
762
1203
  encoder_hidden_states=encoder_hidden_states,
763
1204
  encoder_attention_mask=encoder_attention_mask,
764
- output_attentions=output_attentions,
1205
+ **kwargs,
765
1206
  )
766
1207
 
767
- hidden_states = output[0]
768
-
769
1208
  if i == 0:
770
1209
  # Initial bounding box predictions with inverse sigmoid refinement
771
- new_reference_points = F.sigmoid(self.pre_bbox_head(output[0]) + inverse_sigmoid(ref_points_detach))
1210
+ new_reference_points = F.sigmoid(
1211
+ self.pre_bbox_head(hidden_states) + inverse_sigmoid(ref_points_detach)
1212
+ )
772
1213
  ref_points_initial = new_reference_points.detach()
773
1214
 
774
1215
  # Refine bounding box corners using FDR, integrating previous layer's corrections
@@ -797,12 +1238,6 @@ class DFineDecoder(DFinePreTrainedModel):
797
1238
  initial_reference_points += (ref_points_initial,)
798
1239
  intermediate_predicted_corners += (pred_corners,)
799
1240
 
800
- if output_attentions:
801
- all_self_attns += (output[1],)
802
-
803
- if encoder_hidden_states is not None:
804
- all_cross_attentions += (output[2],)
805
-
806
1241
  # Keep batch_size as first dimension
807
1242
  intermediate = torch.stack(intermediate)
808
1243
  if self.class_embed is not None and self.bbox_embed is not None:
@@ -811,27 +1246,6 @@ class DFineDecoder(DFinePreTrainedModel):
811
1246
  initial_reference_points = torch.stack(initial_reference_points, dim=1)
812
1247
  intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
813
1248
 
814
- # add hidden states from the last decoder layer
815
- if output_hidden_states:
816
- all_hidden_states += (hidden_states,)
817
-
818
- if not return_dict:
819
- return tuple(
820
- v
821
- for v in [
822
- hidden_states,
823
- intermediate,
824
- intermediate_logits,
825
- intermediate_reference_points,
826
- intermediate_predicted_corners,
827
- initial_reference_points,
828
- all_hidden_states,
829
- all_self_attns,
830
- all_cross_attentions,
831
- ]
832
- if v is not None
833
- )
834
-
835
1249
  return DFineDecoderOutput(
836
1250
  last_hidden_state=hidden_states,
837
1251
  intermediate_hidden_states=intermediate,
@@ -839,51 +1253,9 @@ class DFineDecoder(DFinePreTrainedModel):
839
1253
  intermediate_reference_points=intermediate_reference_points,
840
1254
  intermediate_predicted_corners=intermediate_predicted_corners,
841
1255
  initial_reference_points=initial_reference_points,
842
- hidden_states=all_hidden_states,
843
- attentions=all_self_attns,
844
- cross_attentions=all_cross_attentions,
845
1256
  )
846
1257
 
847
1258
 
848
- class DFineFrozenBatchNorm2d(nn.Module):
849
- """
850
- BatchNorm2d where the batch statistics and the affine parameters are fixed.
851
-
852
- Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
853
- torchvision.models.resnet[18,34,50,101] produce nans.
854
- """
855
-
856
- def __init__(self, n):
857
- super().__init__()
858
- self.register_buffer("weight", torch.ones(n))
859
- self.register_buffer("bias", torch.zeros(n))
860
- self.register_buffer("running_mean", torch.zeros(n))
861
- self.register_buffer("running_var", torch.ones(n))
862
-
863
- def _load_from_state_dict(
864
- self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
865
- ):
866
- num_batches_tracked_key = prefix + "num_batches_tracked"
867
- if num_batches_tracked_key in state_dict:
868
- del state_dict[num_batches_tracked_key]
869
-
870
- super()._load_from_state_dict(
871
- state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
872
- )
873
-
874
- def forward(self, x):
875
- # move reshapes to the beginning
876
- # to make it user-friendly
877
- weight = self.weight.reshape(1, -1, 1, 1)
878
- bias = self.bias.reshape(1, -1, 1, 1)
879
- running_var = self.running_var.reshape(1, -1, 1, 1)
880
- running_mean = self.running_mean.reshape(1, -1, 1, 1)
881
- epsilon = 1e-5
882
- scale = weight * (running_var + epsilon).rsqrt()
883
- bias = bias - running_mean * scale
884
- return x * scale + bias
885
-
886
-
887
1259
  @dataclass
888
1260
  @auto_docstring(
889
1261
  custom_intro="""
@@ -922,24 +1294,24 @@ class DFineModelOutput(ModelOutput):
922
1294
  Extra dictionary for the denoising related values.
923
1295
  """
924
1296
 
925
- last_hidden_state: Optional[torch.FloatTensor] = None
926
- intermediate_hidden_states: Optional[torch.FloatTensor] = None
927
- intermediate_logits: Optional[torch.FloatTensor] = None
928
- intermediate_reference_points: Optional[torch.FloatTensor] = None
929
- intermediate_predicted_corners: Optional[torch.FloatTensor] = None
930
- initial_reference_points: Optional[torch.FloatTensor] = None
931
- decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
932
- decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
933
- cross_attentions: Optional[tuple[torch.FloatTensor]] = None
934
- encoder_last_hidden_state: Optional[torch.FloatTensor] = None
935
- encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
936
- encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
937
- init_reference_points: Optional[torch.FloatTensor] = None
938
- enc_topk_logits: Optional[torch.FloatTensor] = None
939
- enc_topk_bboxes: Optional[torch.FloatTensor] = None
940
- enc_outputs_class: Optional[torch.FloatTensor] = None
941
- enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
942
- denoising_meta_values: Optional[dict] = None
1297
+ last_hidden_state: torch.FloatTensor | None = None
1298
+ intermediate_hidden_states: torch.FloatTensor | None = None
1299
+ intermediate_logits: torch.FloatTensor | None = None
1300
+ intermediate_reference_points: torch.FloatTensor | None = None
1301
+ intermediate_predicted_corners: torch.FloatTensor | None = None
1302
+ initial_reference_points: torch.FloatTensor | None = None
1303
+ decoder_hidden_states: tuple[torch.FloatTensor] | None = None
1304
+ decoder_attentions: tuple[torch.FloatTensor] | None = None
1305
+ cross_attentions: tuple[torch.FloatTensor] | None = None
1306
+ encoder_last_hidden_state: torch.FloatTensor | None = None
1307
+ encoder_hidden_states: tuple[torch.FloatTensor] | None = None
1308
+ encoder_attentions: tuple[torch.FloatTensor] | None = None
1309
+ init_reference_points: torch.FloatTensor | None = None
1310
+ enc_topk_logits: torch.FloatTensor | None = None
1311
+ enc_topk_bboxes: torch.FloatTensor | None = None
1312
+ enc_outputs_class: torch.FloatTensor | None = None
1313
+ enc_outputs_coord_logits: torch.FloatTensor | None = None
1314
+ denoising_meta_values: dict | None = None
943
1315
 
944
1316
 
945
1317
  def replace_batch_norm(model):
@@ -1135,8 +1507,8 @@ class DFineModel(DFinePreTrainedModel):
1135
1507
  intermediate_channel_sizes = self.backbone.intermediate_channel_sizes
1136
1508
  num_backbone_outs = len(config.decoder_in_channels)
1137
1509
  encoder_input_proj_list = []
1138
- for _ in range(num_backbone_outs):
1139
- in_channels = intermediate_channel_sizes[_]
1510
+ for i in range(num_backbone_outs):
1511
+ in_channels = intermediate_channel_sizes[i]
1140
1512
  encoder_input_proj_list.append(
1141
1513
  nn.Sequential(
1142
1514
  nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
@@ -1162,15 +1534,15 @@ class DFineModel(DFinePreTrainedModel):
1162
1534
  nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
1163
1535
  )
1164
1536
  self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
1165
- self.enc_bbox_head = DFineMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3)
1537
+ self.enc_bbox_head = DFineMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)
1166
1538
 
1167
1539
  # init encoder output anchors and valid_mask
1168
1540
  if config.anchor_image_size:
1169
1541
  self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)
1170
1542
  num_backbone_outs = len(config.decoder_in_channels)
1171
1543
  decoder_input_proj_list = []
1172
- for _ in range(num_backbone_outs):
1173
- in_channels = config.decoder_in_channels[_]
1544
+ for i in range(num_backbone_outs):
1545
+ in_channels = config.decoder_in_channels[i]
1174
1546
  decoder_input_proj_list.append(
1175
1547
  nn.Sequential(
1176
1548
  nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
@@ -1244,26 +1616,20 @@ class DFineModel(DFinePreTrainedModel):
  return anchors, valid_mask

  @auto_docstring
+ @can_return_tuple
  def forward(
  self,
  pixel_values: torch.FloatTensor,
- pixel_mask: Optional[torch.LongTensor] = None,
- encoder_outputs: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[list[dict]] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[tuple[torch.FloatTensor], DFineModelOutput]:
+ pixel_mask: torch.LongTensor | None = None,
+ encoder_outputs: torch.FloatTensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: list[dict] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.FloatTensor] | DFineModelOutput:
  r"""
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
  can choose to directly pass a flattened representation of an image.
- decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
- Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
- embedded representation.
  labels (`list[Dict]` of len `(batch_size,)`, *optional*):
  Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
  following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
@@ -1291,53 +1657,46 @@ class DFineModel(DFinePreTrainedModel):
  >>> list(last_hidden_states.shape)
  [1, 300, 256]
  ```"""
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- batch_size, num_channels, height, width = pixel_values.shape
- device = pixel_values.device
-
- if pixel_mask is None:
- pixel_mask = torch.ones(((batch_size, height, width)), device=device)
-
- features = self.backbone(pixel_values, pixel_mask)
-
- proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+ if pixel_values is None and inputs_embeds is None:
+ raise ValueError("You have to specify either pixel_values or inputs_embeds")
+
+ if inputs_embeds is None:
+ batch_size, num_channels, height, width = pixel_values.shape
+ device = pixel_values.device
+ if pixel_mask is None:
+ pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+ features = self.backbone(pixel_values, pixel_mask)
+ proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
+ else:
+ batch_size = inputs_embeds.shape[0]
+ device = inputs_embeds.device
+ proj_feats = inputs_embeds

  if encoder_outputs is None:
  encoder_outputs = self.encoder(
  proj_feats,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ **kwargs,
  )
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+ elif not isinstance(encoder_outputs, BaseModelOutput):
  encoder_outputs = BaseModelOutput(
  last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if output_hidden_states else None,
- attentions=encoder_outputs[2]
- if len(encoder_outputs) > 2
- else encoder_outputs[1]
- if output_attentions
- else None,
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  )

  # Equivalent to def _get_encoder_input
  # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/DFine_pytorch/src/zoo/DFine/DFine_decoder.py#L412
  sources = []
- for level, source in enumerate(encoder_outputs[0]):
+ for level, source in enumerate(encoder_outputs.last_hidden_state):
  sources.append(self.decoder_input_proj[level](source))

  # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
  if self.config.num_feature_levels > len(sources):
  _len_sources = len(sources)
- sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1])
+ sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state)[-1])
  for i in range(_len_sources + 1, self.config.num_feature_levels):
- sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1]))
+ sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1]))

  # Prepare encoder inputs (by flattening)
  source_flatten = []
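With the new signature, `output_attentions`, `output_hidden_states`, and `return_dict` are no longer named parameters; they travel through `**kwargs` typed as `Unpack[TransformersKwargs]`, and `inputs_embeds` can replace `pixel_values`. A rough usage sketch under that assumption (the checkpoint id is a placeholder, not a real Hub repo):

```python
import torch
from transformers import DFineModel

model = DFineModel.from_pretrained("path/to/dfine-checkpoint")  # placeholder id
pixel_values = torch.zeros(1, 3, 640, 640)

# The usual output flags are now forwarded through **kwargs rather than named arguments.
outputs = model(pixel_values, output_attentions=True, output_hidden_states=True)
print(outputs.last_hidden_state.shape)  # e.g. [1, 300, 256], as in the docstring above
```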
@@ -1429,22 +1788,9 @@ class DFineModel(DFinePreTrainedModel):
  spatial_shapes=spatial_shapes,
  spatial_shapes_list=spatial_shapes_list,
  level_start_index=level_start_index,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ **kwargs,
  )

- if not return_dict:
- enc_outputs = tuple(
- value
- for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits]
- if value is not None
- )
- dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values])
- tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs
-
- return tuple_outputs
-
  return DFineModelOutput(
  last_hidden_state=decoder_outputs.last_hidden_state,
  intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
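The hand-rolled tuple assembly that used to follow here is gone; the `@can_return_tuple` decorator now handles the conversion. A short sketch of the expected caller-side behaviour, reusing `model` and `pixel_values` from the sketch above (assumption: the decorator converts the dataclass to a tuple when one is requested):

```python
outputs = model(pixel_values)                      # DFineModelOutput dataclass
as_tuple = model(pixel_values, return_dict=False)  # same fields, as a plain tuple
assert as_tuple[0].shape == outputs.last_hidden_state.shape
```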
@@ -1520,29 +1866,29 @@ class DFineObjectDetectionOutput(ModelOutput):
  Extra dictionary for the denoising related values
  """

- loss: Optional[torch.FloatTensor] = None
- loss_dict: Optional[dict] = None
- logits: Optional[torch.FloatTensor] = None
- pred_boxes: Optional[torch.FloatTensor] = None
- auxiliary_outputs: Optional[list[dict]] = None
- last_hidden_state: Optional[torch.FloatTensor] = None
- intermediate_hidden_states: Optional[torch.FloatTensor] = None
- intermediate_logits: Optional[torch.FloatTensor] = None
- intermediate_reference_points: Optional[torch.FloatTensor] = None
- intermediate_predicted_corners: Optional[torch.FloatTensor] = None
- initial_reference_points: Optional[torch.FloatTensor] = None
- decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
- decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
- cross_attentions: Optional[tuple[torch.FloatTensor]] = None
- encoder_last_hidden_state: Optional[torch.FloatTensor] = None
- encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
- encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
- init_reference_points: Optional[tuple[torch.FloatTensor]] = None
- enc_topk_logits: Optional[torch.FloatTensor] = None
- enc_topk_bboxes: Optional[torch.FloatTensor] = None
- enc_outputs_class: Optional[torch.FloatTensor] = None
- enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
- denoising_meta_values: Optional[dict] = None
+ loss: torch.FloatTensor | None = None
+ loss_dict: dict | None = None
+ logits: torch.FloatTensor | None = None
+ pred_boxes: torch.FloatTensor | None = None
+ auxiliary_outputs: list[dict] | None = None
+ last_hidden_state: torch.FloatTensor | None = None
+ intermediate_hidden_states: torch.FloatTensor | None = None
+ intermediate_logits: torch.FloatTensor | None = None
+ intermediate_reference_points: torch.FloatTensor | None = None
+ intermediate_predicted_corners: torch.FloatTensor | None = None
+ initial_reference_points: torch.FloatTensor | None = None
+ decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+ decoder_attentions: tuple[torch.FloatTensor] | None = None
+ cross_attentions: tuple[torch.FloatTensor] | None = None
+ encoder_last_hidden_state: torch.FloatTensor | None = None
+ encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+ encoder_attentions: tuple[torch.FloatTensor] | None = None
+ init_reference_points: tuple[torch.FloatTensor] | None = None
+ enc_topk_logits: torch.FloatTensor | None = None
+ enc_topk_bboxes: torch.FloatTensor | None = None
+ enc_outputs_class: torch.FloatTensor | None = None
+ enc_outputs_coord_logits: torch.FloatTensor | None = None
+ denoising_meta_values: dict | None = None


  @auto_docstring(
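The output dataclass fields switch from `typing.Optional[...]` to PEP 604 unions. The two spellings are equivalent; the union syntax simply requires Python 3.10+:

```python
from typing import Optional

import torch

old_style: Optional[torch.FloatTensor] = None  # pre-5.1 annotation
new_style: torch.FloatTensor | None = None     # equivalent PEP 604 spelling (Python 3.10+)
```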
@@ -1556,10 +1902,10 @@ class DFineForObjectDetection(DFinePreTrainedModel):
  # We can't initialize the model on meta device as some weights are modified during the initialization
  _no_split_modules = None
  _tied_weights_keys = {
- r"bbox_embed.(?![0])\d+": "bbox_embed.0",
- r"class_embed.(?![0])\d+": "class_embed.0",
- "model.decoder.class_embed": "class_embed",
- "model.decoder.bbox_embed": "bbox_embed",
+ r"bbox_embed.(?![0])\d+": r"bbox_embed.0",
+ r"class_embed.(?![0])\d+": r"^class_embed.0",
+ "class_embed": "model.decoder.class_embed",
+ "bbox_embed": "model.decoder.bbox_embed",
  }

  def __init__(self, config: DFineConfig):
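The tied-weights map now uses raw strings on both sides and reverses the decoder aliases, so `class_embed`/`bbox_embed` point at `model.decoder.*`. The regex keys tie every per-layer head except index 0 back to the first one; an illustrative check (not library code):

```python
import re

# The lookahead (?![0]) excludes index 0, so bbox_embed.1, bbox_embed.2, ... resolve to
# the same target as bbox_embed.0, which keeps its own parameters.
pattern = re.compile(r"bbox_embed.(?![0])\d+")
print(bool(pattern.match("bbox_embed.3")))  # True  -> tied to bbox_embed.0
print(bool(pattern.match("bbox_embed.0")))  # False -> not rewritten
```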
@@ -1591,19 +1937,16 @@ class DFineForObjectDetection(DFinePreTrainedModel):
  return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]

  @auto_docstring
+ @can_return_tuple
  def forward(
  self,
  pixel_values: torch.FloatTensor,
- pixel_mask: Optional[torch.LongTensor] = None,
- encoder_outputs: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[list[dict]] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[tuple[torch.FloatTensor], DFineObjectDetectionOutput]:
+ pixel_mask: torch.LongTensor | None = None,
+ encoder_outputs: torch.FloatTensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: list[dict] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.FloatTensor] | DFineObjectDetectionOutput:
  r"""
  Example:

@@ -1649,40 +1992,29 @@ class DFineForObjectDetection(DFinePreTrainedModel):
  Detected sofa with confidence 0.918 at location [0.59, 1.88, 640.25, 474.74]
  ```
  """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
  outputs = self.model(
  pixel_values,
  pixel_mask=pixel_mask,
  encoder_outputs=encoder_outputs,
  inputs_embeds=inputs_embeds,
- decoder_inputs_embeds=decoder_inputs_embeds,
  labels=labels,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ **kwargs,
  )

- denoising_meta_values = (
- outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None
- )
+ denoising_meta_values = outputs.denoising_meta_values if self.training else None

- outputs_class = outputs.intermediate_logits if return_dict else outputs[2]
- outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3]
- predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4]
- initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5]
+ outputs_class = outputs.intermediate_logits
+ outputs_coord = outputs.intermediate_reference_points
+ predicted_corners = outputs.intermediate_predicted_corners
+ initial_reference_points = outputs.initial_reference_points

  logits = outputs_class[:, -1]
  pred_boxes = outputs_coord[:, -1]

  loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
  if labels is not None:
- enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
- enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
+ enc_topk_logits = outputs.enc_topk_logits
+ enc_topk_bboxes = outputs.enc_topk_bboxes
  loss, loss_dict, auxiliary_outputs = self.loss_function(
  logits,
  labels,
@@ -1699,13 +2031,6 @@ class DFineForObjectDetection(DFinePreTrainedModel):
  **kwargs,
  )

- if not return_dict:
- if auxiliary_outputs is not None:
- output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs
- else:
- output = (logits, pred_boxes) + outputs
- return ((loss, loss_dict) + output) if loss is not None else output
-
  return DFineObjectDetectionOutput(
  loss=loss,
  loss_dict=loss_dict,
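For training, `labels` is a list with one dict per image containing `class_labels` and `boxes`, as described in the docstring earlier in this file. A minimal sketch, assuming the usual DETR-style normalized `(cx, cy, w, h)` box format and a placeholder checkpoint id:

```python
import torch
from transformers import DFineForObjectDetection

model = DFineForObjectDetection.from_pretrained("path/to/dfine-checkpoint")  # placeholder
pixel_values = torch.zeros(1, 3, 640, 640)
labels = [
    {
        "class_labels": torch.tensor([17]),             # one object of class 17
        "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.3]]),  # normalized (cx, cy, w, h)
    }
]

outputs = model(pixel_values, labels=labels)
print(outputs.loss, outputs.logits.shape, outputs.pred_boxes.shape)
```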
@@ -1733,470 +2058,4 @@ class DFineForObjectDetection(DFinePreTrainedModel):
  )


- # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
- class DFineMLPPredictionHead(nn.Module):
- """
- Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
- height and width of a bounding box w.r.t. an image.
-
- Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
- Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/DFine_paddle/ppdet/modeling/transformers/utils.py#L453
-
- """
-
- def __init__(self, config, input_dim, d_model, output_dim, num_layers):
- super().__init__()
- self.num_layers = num_layers
- h = [d_model] * (num_layers - 1)
- self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
-
- def forward(self, x):
- for i, layer in enumerate(self.layers):
- x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
- return x
-
-
- class DFineMLP(nn.Module):
- def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"):
- super().__init__()
- self.num_layers = num_layers
- hidden_dims = [hidden_dim] * (num_layers - 1)
- input_dims = [input_dim] + hidden_dims
- output_dims = hidden_dims + [output_dim]
- self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims))
- self.act = ACT2CLS[act]()
-
- def forward(self, stat_features: torch.Tensor) -> torch.Tensor:
- for i, layer in enumerate(self.layers):
- stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features)
- return stat_features
-
-
- class DFineLQE(nn.Module):
- def __init__(self, config: DFineConfig):
- super().__init__()
- self.top_prob_values = config.top_prob_values
- self.max_num_bins = config.max_num_bins
- self.reg_conf = DFineMLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers)
-
- def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor:
- batch_size, length, _ = pred_corners.size()
- prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1)
- prob_topk, _ = prob.topk(self.top_prob_values, dim=-1)
- stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1)
- quality_score = self.reg_conf(stat.reshape(batch_size, length, -1))
- scores = scores + quality_score
- return scores
-
-
- class DFineConvNormLayer(nn.Module):
- def __init__(
- self,
- config: DFineConfig,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: int,
- groups: int = 1,
- padding: Optional[int] = None,
- activation: Optional[str] = None,
- ):
- super().__init__()
- self.conv = nn.Conv2d(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- groups=groups,
- padding=(kernel_size - 1) // 2 if padding is None else padding,
- bias=False,
- )
- self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
- self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
-
- def forward(self, hidden_state):
- hidden_state = self.conv(hidden_state)
- hidden_state = self.norm(hidden_state)
- hidden_state = self.activation(hidden_state)
- return hidden_state
-
-
- class DFineRepVggBlock(nn.Module):
- """
- RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
- """
-
- def __init__(self, config: DFineConfig, in_channels: int, out_channels: int):
- super().__init__()
-
- activation = config.activation_function
- hidden_channels = in_channels
- self.conv1 = DFineConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1)
- self.conv2 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0)
- self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
-
- def forward(self, x):
- y = self.conv1(x) + self.conv2(x)
- return self.activation(y)
-
-
- class DFineCSPRepLayer(nn.Module):
- """
- Cross Stage Partial (CSP) network layer with RepVGG blocks.
- """
-
- def __init__(
- self, config: DFineConfig, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0
- ):
- super().__init__()
- activation = config.activation_function
-
- hidden_channels = int(out_channels * expansion)
- self.conv1 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
- self.conv2 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
- self.bottlenecks = nn.ModuleList(
- [DFineRepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)]
- )
- if hidden_channels != out_channels:
- self.conv3 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
- else:
- self.conv3 = nn.Identity()
-
- def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
- hidden_state_1 = self.conv1(hidden_state)
- for bottleneck in self.bottlenecks:
- hidden_state_1 = bottleneck(hidden_state_1)
- hidden_state_2 = self.conv2(hidden_state)
- hidden_state_3 = self.conv3(hidden_state_1 + hidden_state_2)
- return hidden_state_3
-
-
- class DFineRepNCSPELAN4(nn.Module):
- def __init__(self, config: DFineConfig, act: str = "silu", numb_blocks: int = 3):
- super().__init__()
- conv1_dim = config.encoder_hidden_dim * 2
- conv2_dim = config.encoder_hidden_dim
- conv3_dim = config.encoder_hidden_dim * 2
- conv4_dim = round(config.hidden_expansion * config.encoder_hidden_dim // 2)
- self.conv_dim = conv3_dim // 2
- self.conv1 = DFineConvNormLayer(config, conv1_dim, conv3_dim, 1, 1, activation=act)
- self.csp_rep1 = DFineCSPRepLayer(config, conv3_dim // 2, conv4_dim, num_blocks=numb_blocks)
- self.conv2 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
- self.csp_rep2 = DFineCSPRepLayer(config, conv4_dim, conv4_dim, num_blocks=numb_blocks)
- self.conv3 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act)
- self.conv4 = DFineConvNormLayer(config, conv3_dim + (2 * conv4_dim), conv2_dim, 1, 1, activation=act)
-
- def forward(self, input_features: torch.Tensor) -> torch.Tensor:
- # Split initial features into two branches after first convolution
- split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1))
-
- # Process branches sequentially
- branch1 = self.csp_rep1(split_features[-1])
- branch1 = self.conv2(branch1)
- branch2 = self.csp_rep2(branch1)
- branch2 = self.conv3(branch2)
-
- split_features.extend([branch1, branch2])
- merged_features = torch.cat(split_features, 1)
- merged_features = self.conv4(merged_features)
- return merged_features
-
-
- class DFineSCDown(nn.Module):
- def __init__(self, config: DFineConfig, kernel_size: int, stride: int):
- super().__init__()
- self.conv1 = DFineConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1)
- self.conv2 = DFineConvNormLayer(
- config,
- config.encoder_hidden_dim,
- config.encoder_hidden_dim,
- kernel_size,
- stride,
- config.encoder_hidden_dim,
- )
-
- def forward(self, input_features: torch.Tensor) -> torch.Tensor:
- input_features = self.conv1(input_features)
- input_features = self.conv2(input_features)
- return input_features
-
-
- class DFineEncoderLayer(nn.Module):
- def __init__(self, config: DFineConfig):
- super().__init__()
- self.normalize_before = config.normalize_before
-
- # self-attention
- self.self_attn = DFineMultiheadAttention(
- embed_dim=config.encoder_hidden_dim,
- num_heads=config.num_attention_heads,
- dropout=config.dropout,
- )
- self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
- self.dropout = config.dropout
- self.activation_fn = ACT2FN[config.encoder_activation_function]
- self.activation_dropout = config.activation_dropout
- self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim)
- self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim)
- self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- position_embeddings: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- **kwargs,
- ):
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`): attention mask of size
- `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
- values.
- position_embeddings (`torch.FloatTensor`, *optional*):
- Object queries (also called content embeddings), to be added to the hidden states.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- """
- residual = hidden_states
- if self.normalize_before:
- hidden_states = self.self_attn_layer_norm(hidden_states)
-
- hidden_states, attn_weights = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_embeddings=position_embeddings,
- output_attentions=output_attentions,
- )
-
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
- if not self.normalize_before:
- hidden_states = self.self_attn_layer_norm(hidden_states)
-
- if self.normalize_before:
- hidden_states = self.final_layer_norm(hidden_states)
- residual = hidden_states
-
- hidden_states = self.activation_fn(self.fc1(hidden_states))
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-
- hidden_states = self.fc2(hidden_states)
-
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
- hidden_states = residual + hidden_states
- if not self.normalize_before:
- hidden_states = self.final_layer_norm(hidden_states)
-
- if self.training:
- if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
- clamp_value = torch.finfo(hidden_states.dtype).max - 1000
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (attn_weights,)
-
- return outputs
-
-
- class DFineEncoder(nn.Module):
- def __init__(self, config: DFineConfig):
- super().__init__()
-
- self.layers = nn.ModuleList([DFineEncoderLayer(config) for _ in range(config.encoder_layers)])
-
- def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor:
- hidden_states = src
- for layer in self.layers:
- hidden_states = layer(
- hidden_states,
- attention_mask=src_mask,
- position_embeddings=pos_embed,
- output_attentions=output_attentions,
- )
- return hidden_states
-
-
- class DFineHybridEncoder(nn.Module):
- """
- Decoder consisting of a projection layer, a set of `DFineEncoder`, a top-down Feature Pyramid Network
- (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://huggingface.co/papers/2304.08069
-
- Args:
- config: DFineConfig
- """
-
- def __init__(self, config: DFineConfig):
- super().__init__()
- self.config = config
- self.in_channels = config.encoder_in_channels
- self.num_fpn_stages = len(self.in_channels) - 1
- self.feat_strides = config.feat_strides
- self.encoder_hidden_dim = config.encoder_hidden_dim
- self.encode_proj_layers = config.encode_proj_layers
- self.positional_encoding_temperature = config.positional_encoding_temperature
- self.eval_size = config.eval_size
- self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
- self.out_strides = self.feat_strides
-
- # encoder transformer
- self.encoder = nn.ModuleList([DFineEncoder(config) for _ in range(len(self.encode_proj_layers))])
- # top-down fpn
- self.lateral_convs = nn.ModuleList()
- self.fpn_blocks = nn.ModuleList()
- for _ in range(len(self.in_channels) - 1, 0, -1):
- lateral_layer = DFineConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1)
- self.lateral_convs.append(lateral_layer)
- num_blocks = round(3 * config.depth_mult)
- fpn_layer = DFineRepNCSPELAN4(config, numb_blocks=num_blocks)
- self.fpn_blocks.append(fpn_layer)
-
- # bottom-up pan
- self.downsample_convs = nn.ModuleList()
- self.pan_blocks = nn.ModuleList()
- for _ in range(len(self.in_channels) - 1):
- self.downsample_convs.append(DFineSCDown(config, 3, 2))
- num_blocks = round(3 * config.depth_mult)
- self.pan_blocks.append(DFineRepNCSPELAN4(config, numb_blocks=num_blocks))
-
- @staticmethod
- def build_2d_sincos_position_embedding(
- width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
- ):
- grid_w = torch.arange(torch_int(width), device=device).to(dtype)
- grid_h = torch.arange(torch_int(height), device=device).to(dtype)
- grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
- if embed_dim % 4 != 0:
- raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
- pos_dim = embed_dim // 4
- omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
- omega = 1.0 / (temperature**omega)
-
- out_w = grid_w.flatten()[..., None] @ omega[None]
- out_h = grid_h.flatten()[..., None] @ omega[None]
-
- return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
-
- def forward(
- self,
- inputs_embeds=None,
- attention_mask=None,
- position_embeddings=None,
- spatial_shapes=None,
- level_start_index=None,
- valid_ratios=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- ):
- r"""
- Args:
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
- - 1 for pixel features that are real (i.e. **not masked**),
- - 0 for pixel features that are padding (i.e. **masked**).
- [What are attention masks?](../glossary#attention-mask)
- position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Position embeddings that are added to the queries and keys in each self-attention layer.
- spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
- Spatial shapes of each feature map.
- level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
- Starting index of each feature map.
- valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
- Ratio of valid area in each feature level.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
- """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- hidden_states = inputs_embeds
-
- encoder_states = () if output_hidden_states else None
- all_attentions = () if output_attentions else None
-
- # encoder
- if self.config.encoder_layers > 0:
- for i, enc_ind in enumerate(self.encode_proj_layers):
- if output_hidden_states:
- encoder_states = encoder_states + (hidden_states[enc_ind],)
- height, width = hidden_states[enc_ind].shape[2:]
- # flatten [batch, channel, height, width] to [batch, height*width, channel]
- src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1)
- if self.training or self.eval_size is None:
- pos_embed = self.build_2d_sincos_position_embedding(
- width,
- height,
- self.encoder_hidden_dim,
- self.positional_encoding_temperature,
- device=src_flatten.device,
- dtype=src_flatten.dtype,
- )
- else:
- pos_embed = None
-
- layer_outputs = self.encoder[i](
- src_flatten,
- pos_embed=pos_embed,
- output_attentions=output_attentions,
- )
- hidden_states[enc_ind] = (
- layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
- )
-
- if output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
-
- if output_hidden_states:
- encoder_states = encoder_states + (hidden_states[enc_ind],)
-
- # top-down FPN
- fpn_feature_maps = [hidden_states[-1]]
- for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
- backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
- top_fpn_feature_map = fpn_feature_maps[-1]
- # apply lateral block
- top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
- fpn_feature_maps[-1] = top_fpn_feature_map
- # apply fpn block
- top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
- fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
- new_fpn_feature_map = fpn_block(fused_feature_map)
- fpn_feature_maps.append(new_fpn_feature_map)
-
- fpn_feature_maps.reverse()
-
- # bottom-up PAN
- pan_feature_maps = [fpn_feature_maps[0]]
- for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
- top_pan_feature_map = pan_feature_maps[-1]
- fpn_feature_map = fpn_feature_maps[idx + 1]
- downsampled_feature_map = downsample_conv(top_pan_feature_map)
- fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
- new_pan_feature_map = pan_block(fused_feature_map)
- pan_feature_maps.append(new_pan_feature_map)
-
- if not return_dict:
- return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
- return BaseModelOutput(
- last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
- )
-
-
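For reference, the removed `build_2d_sincos_position_embedding` above evaluates, for each grid coordinate `(x, y)` and frequency index `i` in `[0, d/4)` with temperature `T`:

```latex
\omega_i = \frac{1}{T^{\,i/(d/4)}}, \qquad
\mathrm{PE}(x, y) = \big[\,\sin(y\,\omega),\ \cos(y\,\omega),\ \sin(x\,\omega),\ \cos(x\,\omega)\,\big]
```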
  __all__ = ["DFineModel", "DFinePreTrainedModel", "DFineForObjectDetection"]