transformers 5.0.0rc1-py3-none-any.whl → 5.0.0rc2-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/ctrl/modeling_ctrl.py
@@ -22,6 +22,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
@@ -187,6 +188,13 @@ class CTRLPreTrainedModel(PreTrainedModel):
     config: CTRLConfig
     base_model_prefix = "transformer"
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, CTRLModel):
+            init.copy_(
+                module.pos_encoding, positional_encoding(module.config.n_positions, module.d_model_size, torch.float)
+            )
+
 
 @auto_docstring
 class CTRLModel(CTRLPreTrainedModel):
@@ -196,7 +204,9 @@ class CTRLModel(CTRLPreTrainedModel):
         self.d_model_size = config.n_embd
         self.num_layers = config.n_layer
 
-        self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
+        self.register_buffer(
+            "pos_encoding", positional_encoding(config.n_positions, self.d_model_size, torch.float), persistent=False
+        )
 
         self.w = nn.Embedding(config.vocab_size, config.n_embd)
 
@@ -470,7 +480,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin):
             attentions=transformer_outputs.attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, use_cache=None, is_first_iteration=False, **kwargs
+    ):
         # Overwritten -- inputs_embeds not working properly
 
         # only last tokens for inputs_ids if past is defined in kwargs
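Note on the `pos_encoding` change: a buffer registered with `persistent=False` still follows `.to()`/dtype casts with the module but is excluded from `state_dict()`, so the checkpoint layout is unchanged; and because such buffers are never loaded from a checkpoint, the new `_init_weights` override re-materializes the table with `init.copy_`. A minimal sketch of the buffer behaviour (`sinusoid_table` here is a stand-in helper, not the CTRL `positional_encoding`):

    import math
    import torch
    from torch import nn

    def sinusoid_table(n_pos: int, dim: int) -> torch.Tensor:
        # Stand-in sinusoidal table, for illustration only.
        pos = torch.arange(n_pos, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim))
        table = torch.zeros(n_pos, dim)
        table[:, 0::2] = torch.sin(pos * freqs)
        table[:, 1::2] = torch.cos(pos * freqs)
        return table

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            # persistent=False: moves with .to(device)/dtype casts like any buffer,
            # but is excluded from state_dict(), so checkpoints neither gain nor miss a key.
            self.register_buffer("pos_encoding", sinusoid_table(8, 4), persistent=False)

    m = Toy()
    assert "pos_encoding" not in m.state_dict()   # not serialized
    m.to(torch.float64)
    assert m.pos_encoding.dtype == torch.float64  # still follows the module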
transformers/models/cvt/modeling_cvt.py
@@ -497,9 +497,13 @@ class CvtPreTrainedModel(PreTrainedModel):
             init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
-        elif isinstance(module, nn.LayerNorm):
+        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, CvtStage):
             if self.config.cls_token[module.stage]:
                 init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
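Why the extra branch: `running_mean`, `running_var` and `num_batches_tracked` are buffers rather than parameters, so an init pass that only touches `.weight`/`.bias` leaves them stale. A sketch of the equivalent reset on a stock `nn.BatchNorm2d`, assuming the `init` helpers act on buffers like their `torch.nn.init` counterparts:

    import torch
    from torch import nn

    bn = nn.BatchNorm2d(16)
    with torch.no_grad():
        nn.init.ones_(bn.weight)         # affine parameters
        nn.init.zeros_(bn.bias)
        bn.running_mean.zero_()          # buffers: what init.zeros_(module.running_mean) resets
        bn.running_var.fill_(1.0)        # ... init.ones_(module.running_var)
        bn.num_batches_tracked.zero_()   # ... init.zeros_(module.num_batches_tracked)
    # bn.reset_running_stats() achieves the same buffer reset in stock PyTorch.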
transformers/models/cwm/modeling_cwm.py
@@ -58,7 +58,7 @@ class CwmRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
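The `.clone()` is the functional part of this one-liner: previously `original_inv_freq` was a plain attribute aliasing the `inv_freq` buffer, so rescaling `inv_freq` in place (as dynamic RoPE scaling does) silently mutated the saved "original" as well; the clone decouples them, and registering it as a non-persistent buffer lets it follow device moves. A minimal sketch of the aliasing hazard:

    import torch

    inv_freq = torch.tensor([1.0, 0.5, 0.25])
    alias = inv_freq             # old: attribute shares storage with the buffer
    snapshot = inv_freq.clone()  # new: independent copy

    inv_freq.mul_(2.0)           # in-place rescale, e.g. a dynamic RoPE frequency update
    print(alias)                 # tensor([2.0000, 1.0000, 0.5000]) -- the "original" drifted
    print(snapshot)              # tensor([1.0000, 0.5000, 0.2500]) -- preserved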
transformers/models/d_fine/configuration_d_fine.py
@@ -47,7 +47,7 @@ class DFineConfig(PreTrainedConfig):
             The epsilon used by the layer normalization layers.
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the batch normalization layers.
-        backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -288,8 +288,7 @@
             )
             backbone_model_type = "hgnet_v2"
             config_class = CONFIG_MAPPING[backbone_model_type]
-            # this will map it to RTDetrResNetConfig
-            # note: we can instead create HGNetV2Config
+            # this will map it to HGNetV2Config
             # and we would need to create HGNetV2Backbone
             backbone_config = config_class(
                 num_channels=3,
@@ -395,8 +394,8 @@
             raise ValueError(
                 f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
             )
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 __all__ = ["DFineConfig"]
transformers/models/d_fine/modeling_d_fine.py
@@ -483,6 +483,9 @@ class DFinePreTrainedModel(PreTrainedModel):
             init.constant_(module.attention_weights.weight, 0.0)
             init.constant_(module.attention_weights.bias, 0.0)
 
+            num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)]
+            init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32))
+
         if isinstance(module, DFineModel):
             prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
             bias = float(-math.log((1 - prior_prob) / prior_prob))
@@ -493,6 +496,10 @@
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
         if isinstance(module, DFineGate):
             bias = float(-math.log((1 - 0.5) / 0.5))
@@ -838,6 +845,45 @@ class DFineDecoder(DFinePreTrainedModel):
         )
 
 
+class DFineFrozenBatchNorm2d(nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
+    torchvision.models.resnet[18,34,50,101] produce nans.
+    """
+
+    def __init__(self, n):
+        super().__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        num_batches_tracked_key = prefix + "num_batches_tracked"
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def forward(self, x):
+        # move reshapes to the beginning
+        # to make it user-friendly
+        weight = self.weight.reshape(1, -1, 1, 1)
+        bias = self.bias.reshape(1, -1, 1, 1)
+        running_var = self.running_var.reshape(1, -1, 1, 1)
+        running_mean = self.running_mean.reshape(1, -1, 1, 1)
+        epsilon = 1e-5
+        scale = weight * (running_var + epsilon).rsqrt()
+        bias = bias - running_mean * scale
+        return x * scale + bias
+
+
 @dataclass
 @auto_docstring(
     custom_intro="""
@@ -896,45 +942,6 @@ class DFineModelOutput(ModelOutput):
     denoising_meta_values: Optional[dict] = None
 
 
-class DFineFrozenBatchNorm2d(nn.Module):
-    """
-    BatchNorm2d where the batch statistics and the affine parameters are fixed.
-
-    Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than
-    torchvision.models.resnet[18,34,50,101] produce nans.
-    """
-
-    def __init__(self, n):
-        super().__init__()
-        self.register_buffer("weight", torch.ones(n))
-        self.register_buffer("bias", torch.zeros(n))
-        self.register_buffer("running_mean", torch.zeros(n))
-        self.register_buffer("running_var", torch.ones(n))
-
-    def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-    ):
-        num_batches_tracked_key = prefix + "num_batches_tracked"
-        if num_batches_tracked_key in state_dict:
-            del state_dict[num_batches_tracked_key]
-
-        super()._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-        )
-
-    def forward(self, x):
-        # move reshapes to the beginning
-        # to make it user-friendly
-        weight = self.weight.reshape(1, -1, 1, 1)
-        bias = self.bias.reshape(1, -1, 1, 1)
-        running_var = self.running_var.reshape(1, -1, 1, 1)
-        running_mean = self.running_mean.reshape(1, -1, 1, 1)
-        epsilon = 1e-5
-        scale = weight * (running_var + epsilon).rsqrt()
-        bias = bias - running_mean * scale
-        return x * scale + bias
-
-
 def replace_batch_norm(model):
     r"""
     Recursively replace all `torch.nn.BatchNorm2d` with `DFineFrozenBatchNorm2d`.
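`DFineFrozenBatchNorm2d` is only moved above the output dataclass here, not changed. For reference, its forward pass folds frozen batch-norm statistics into a single affine map, x * scale + bias with scale = weight / sqrt(running_var + eps). A quick numerical check of that algebra against a stock eval-mode `nn.BatchNorm2d` (illustrative, not part of the diff):

    import torch
    from torch import nn

    bn = nn.BatchNorm2d(3).eval()  # default stats: mean=0, var=1, weight=1, bias=0
    x = torch.randn(2, 3, 4, 4)

    with torch.no_grad():
        eps = 1e-5
        scale = bn.weight.reshape(1, -1, 1, 1) * (bn.running_var.reshape(1, -1, 1, 1) + eps).rsqrt()
        shift = bn.bias.reshape(1, -1, 1, 1) - bn.running_mean.reshape(1, -1, 1, 1) * scale
        torch.testing.assert_close(x * scale + shift, bn(x))  # same result, no batch stats used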
transformers/models/d_fine/modular_d_fine.py
@@ -33,6 +33,7 @@ from ..rt_detr.modeling_rt_detr import (
     RTDetrDecoderOutput,
     RTDetrEncoder,
     RTDetrForObjectDetection,
+    RTDetrFrozenBatchNorm2d,
     RTDetrHybridEncoder,
     RTDetrMLPPredictionHead,
     RTDetrModel,
@@ -66,7 +67,7 @@ class DFineConfig(PreTrainedConfig):
             The epsilon used by the layer normalization layers.
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the batch normalization layers.
-        backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `HGNetV2Config()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -307,8 +308,7 @@
             )
             backbone_model_type = "hgnet_v2"
             config_class = CONFIG_MAPPING[backbone_model_type]
-            # this will map it to RTDetrResNetConfig
-            # note: we can instead create HGNetV2Config
+            # this will map it to HGNetV2Config
             # and we would need to create HGNetV2Backbone
             backbone_config = config_class(
                 num_channels=3,
@@ -414,8 +414,8 @@
             raise ValueError(
                 f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
             )
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 class DFineMultiscaleDeformableAttention(nn.Module):
@@ -628,6 +628,9 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel):
             init.constant_(module.attention_weights.weight, 0.0)
             init.constant_(module.attention_weights.bias, 0.0)
 
+            num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)]
+            init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32))
+
         if isinstance(module, DFineModel):
             prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
             bias = float(-math.log((1 - prior_prob) / prior_prob))
@@ -638,6 +641,10 @@
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
         if isinstance(module, DFineGate):
             bias = float(-math.log((1 - 0.5) / 0.5))
@@ -851,6 +858,10 @@ class DFineDecoder(RTDetrDecoder):
         )
 
 
+class DFineFrozenBatchNorm2d(RTDetrFrozenBatchNorm2d):
+    pass
+
+
 class DFineModel(RTDetrModel):
     def __init__(self, config: DFineConfig):
         super().__init__(config)
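A worked example of the `num_points_scale` value copied in above (the `num_points_list` here is illustrative, not taken from a real config):

    num_points_list = [4, 4, 2]  # hypothetical per-level sampling-point counts
    num_points_scale = [1 / n for n in num_points_list for _ in range(n)]
    # -> [0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5]
    # One entry per sampled point, and each level's entries sum to 1,
    # so every feature level carries the same total weight.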
transformers/models/dab_detr/configuration_dab_detr.py
@@ -37,7 +37,7 @@ class DabDetrConfig(PreTrainedConfig):
         use_timm_backbone (`bool`, *optional*, defaults to `True`):
             Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
             API.
-        backbone_config (`PreTrainedConfig` or `dict`, *optional*):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `ResNetConfig()`):
             The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
             case it will default to `ResNetConfig()`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
@@ -255,8 +255,8 @@
         self.temperature_height = temperature_height
         self.sine_position_embedding_scale = sine_position_embedding_scale
         self.initializer_bias_prior_prob = initializer_bias_prior_prob
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True  # weights have to be tied for this model
 
 
 __all__ = ["DabDetrConfig"]
@@ -826,7 +826,7 @@ class DabDetrPreTrainedModel(PreTrainedModel):
826
826
  init.zeros_(module.q_linear.bias)
827
827
  init.xavier_uniform_(module.k_linear.weight, gain=xavier_std)
828
828
  init.xavier_uniform_(module.q_linear.weight, gain=xavier_std)
829
- if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
829
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
830
830
  init.normal_(module.weight, mean=0.0, std=std)
831
831
  if module.bias is not None:
832
832
  init.zeros_(module.bias)
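Dropping `nn.BatchNorm2d` from this branch is sound: PyTorch already initializes affine batch-norm parameters to the identity transform, and overwriting them with `N(0, std)` samples would start every normalization layer near zero gain. Quick check:

```python
import torch.nn as nn

bn = nn.BatchNorm2d(4)
# Default affine init is weight=1, bias=0 (the identity transform),
# which the removed normal_() call would have clobbered.
assert bn.weight.data.eq(1).all() and bn.bias.data.eq(0).all()
```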
@@ -16,7 +16,7 @@

  import math
  from dataclasses import dataclass
- from typing import Optional
+ from typing import Optional, Union

  import numpy as np
  import torch
@@ -583,7 +583,7 @@ class DacModel(DacPreTrainedModel):
  input_values: torch.Tensor,
  n_quantizers: Optional[int] = None,
  return_dict: Optional[bool] = None,
- ):
+ ) -> Union[tuple, DacEncoderOutput]:
  r"""
  input_values (`torch.Tensor of shape `(batch_size, 1, time_steps)`):
  Input audio data to encode,
@@ -610,7 +610,7 @@
  quantized_representation: Optional[torch.Tensor] = None,
  audio_codes: Optional[torch.Tensor] = None,
  return_dict: Optional[bool] = None,
- ):
+ ) -> Union[tuple, DacDecoderOutput]:
  r"""
  quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
  Quantized continuous representation of input.
@@ -643,7 +643,7 @@
  input_values: torch.Tensor,
  n_quantizers: Optional[int] = None,
  return_dict: Optional[bool] = None,
- ):
+ ) -> Union[tuple, DacOutput]:
  r"""
  input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
  Audio data to encode.
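These annotations document the usual transformers convention: the method returns a model-output object when `return_dict` is truthy and a plain tuple otherwise. A hedged sketch of that pattern (the names below are stand-ins, not the real Dac classes):

```python
from dataclasses import dataclass
from typing import Optional, Union

import torch

@dataclass
class EncoderOutputSketch:  # stand-in for DacEncoderOutput
    audio_codes: Optional[torch.Tensor] = None

def encode_sketch(x: torch.Tensor, return_dict: bool = True) -> Union[tuple, EncoderOutputSketch]:
    codes = x.long()  # placeholder for the real quantization step
    output = EncoderOutputSketch(audio_codes=codes)
    return output if return_dict else (output.audio_codes,)
```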
@@ -26,6 +26,7 @@ import torch
  import torch.nn as nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+ from ... import initialization as init
  from ...activations import ACT2FN, gelu
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  from ...generation import GenerationMixin
@@ -494,6 +495,12 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
  "cross_attentions": Data2VecTextCrossAttention,
  }

+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, Data2VecTextEmbeddings):
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+ init.zeros_(module.token_type_ids)
+

  class Data2VecTextEncoder(nn.Module):
  def __init__(self, config):
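The new `_init_weights` override covers embedding buffers that generic parameter init never touches. What `init.copy_` writes into `position_ids` is just an `arange` broadcast to shape `(1, max_len)`; the length below is illustrative:

```python
import torch

max_len = 6  # illustrative; the real buffer is sized by the config
position_ids = torch.arange(max_len).expand((1, -1))
# tensor([[0, 1, 2, 3, 4, 5]]) -- absolute positions, one row that
# broadcasts across the batch; token_type_ids is simply zeroed.
```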
@@ -20,6 +20,7 @@ import torch
  import torch.nn as nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+ from ... import initialization as init
  from ...generation import GenerationMixin
  from ...modeling_outputs import (
  BaseModelOutputWithPoolingAndCrossAttentions,
@@ -81,6 +82,12 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
  "cross_attentions": Data2VecTextCrossAttention,
  }

+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, Data2VecTextEmbeddings):
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+ init.zeros_(module.token_type_ids)
+

  @auto_docstring
  class Data2VecTextModel(RobertaModel):
@@ -104,7 +104,15 @@ class DbrxFFNConfig(PreTrainedConfig):
  self.moe_loss_weight = moe_loss_weight
  self.moe_normalize_expert_weights = moe_normalize_expert_weights

- for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype", "dtype"]:
+ for k in [
+ "model_type",
+ "attn_implementation",
+ "experts_implementation",
+ "transformers_version",
+ "_commit_hash",
+ "torch_dtype",
+ "dtype",
+ ]:
  if k in kwargs:
  kwargs.pop(k)
  if len(kwargs) != 0:
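The reflowed list now also scrubs `experts_implementation`, keeping the existing pattern: serialization bookkeeping keys are removed before any leftover kwargs are treated as unknown. A sketch of the flow with made-up kwargs:

```python
kwargs = {"experts_implementation": "grouped", "mystery_key": 1}
for k in [
    "model_type",
    "attn_implementation",
    "experts_implementation",
    "transformers_version",
    "_commit_hash",
    "torch_dtype",
    "dtype",
]:
    kwargs.pop(k, None)
# The real config raises on leftovers after the scrub; here we just show
# what would trip the `if len(kwargs) != 0:` check above.
print(kwargs)  # {'mystery_key': 1}
```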
@@ -58,7 +58,7 @@ class DbrxRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
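Promoting `original_inv_freq` from a plain attribute to a cloned, non-persistent buffer matters for two reasons: buffers follow `module.to(device)` while stray tensor attributes do not, and `clone()` detaches the backup from later in-place edits to `inv_freq`. A minimal sketch:

```python
import torch
import torch.nn as nn

class RotarySketch(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # clone() gives the backup its own storage; persistent=False keeps
        # both out of the state_dict, matching the original buffer.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = RotarySketch()
m.inv_freq.mul_(2.0)  # e.g. a dynamic RoPE rescaling
assert not torch.equal(m.inv_freq, m.original_inv_freq)  # backup survives
```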
@@ -624,6 +624,8 @@ class DebertaPreTrainedModel(PreTrainedModel):
  init.zeros_(module.v_bias)
  elif isinstance(module, (LegacyDebertaLMPredictionHead, DebertaLMPredictionHead)):
  init.zeros_(module.bias)
+ elif isinstance(module, DebertaEmbeddings):
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


  @auto_docstring
@@ -700,6 +700,8 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
  super()._init_weights(module)
  if isinstance(module, (LegacyDebertaV2LMPredictionHead, DebertaV2LMPredictionHead)):
  init.zeros_(module.bias)
+ elif isinstance(module, DebertaV2Embeddings):
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


  @auto_docstring
@@ -94,7 +94,6 @@ class DecisionTransformerGPT2Attention(nn.Module):
  ),
  persistent=False,
  )
- self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

  self.embed_dim = config.hidden_size
  self.num_heads = config.num_attention_heads
@@ -367,12 +366,8 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
  config: DecisionTransformerConfig
  base_model_prefix = "transformer"
  supports_gradient_checkpointing = True
-
  _can_compile_fullgraph = False

- def __init__(self, *inputs, **kwargs):
- super().__init__(*inputs, **kwargs)
-
  @torch.no_grad()
  def _init_weights(self, module):
  """Initialize the weights."""
@@ -389,6 +384,14 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
  if "c_proj" in name and "weight" in name:
  # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  init.normal_(p, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
+ elif isinstance(module, DecisionTransformerGPT2Attention):
+ max_positions = module.config.max_position_embeddings
+ init.copy_(
+ module.bias,
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+ 1, 1, max_positions, max_positions
+ ),
+ )


  class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
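Moving the causal `bias` buffer into `_init_weights` means it is rebuilt whenever weights are (re)initialized; the buffer itself is just a lower-triangular boolean mask. For a toy `max_positions` (value illustrative):

```python
import torch

max_positions = 4  # illustrative
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
bias = bias.view(1, 1, max_positions, max_positions)
# bias[0, 0]:
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
```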
@@ -30,18 +30,19 @@ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub
+ from ...integrations import use_experts_implementation, use_kernel_forward_from_hub
  from ...masking_utils import create_causal_mask
  from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
  from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_deepseek_v2 import DeepseekV2Config


+ @use_experts_implementation
  class DeepseekV2Experts(nn.Module):
  """Collection of expert weights stored as 3D tensors."""

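`use_experts_implementation` is new in `transformers.integrations` (see the `integrations/moe.py` addition in the file list); its internals are not shown in this diff. As a rough sketch of the decorator pattern it suggests, recording a module's experts forward at class-definition time so an alternative kernel can be selected later might look like this (all names below are hypothetical):

```python
import torch.nn as nn

_EXPERTS_FORWARDS = {}  # hypothetical registry of experts kernels

def use_experts_implementation_sketch(cls: type[nn.Module]) -> type[nn.Module]:
    # Record the naive forward so a faster kernel (e.g. a grouped matmul)
    # could be swapped in later without subclassing the module.
    _EXPERTS_FORWARDS.setdefault(cls.__name__, cls.forward)
    return cls

@use_experts_implementation_sketch
class ExpertsSketch(nn.Module):
    def forward(self, hidden_states):
        return hidden_states
```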
@@ -184,7 +185,7 @@ class DeepseekV2RotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
@@ -453,7 +454,9 @@ class DeepseekV2PreTrainedModel(PreTrainedModel):
  _supports_flash_attn = True
  _supports_sdpa = True
  _supports_flex_attn = True
- _can_compile_fullgraph = False
+ _can_compile_fullgraph = (
+ is_grouped_mm_available()
+ ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
  _supports_attention_backend = True
  _can_record_outputs = {
  "hidden_states": DeepseekV2DecoderLayer,
@@ -24,7 +24,7 @@ from ... import initialization as init
  from ...cache_utils import Cache
  from ...modeling_rope_utils import RopeParameters, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...utils import logging
+ from ...utils import is_grouped_mm_available, logging
  from ...utils.generic import maybe_autocast
  from ..llama.configuration_llama import LlamaConfig
  from ..llama.modeling_llama import (
@@ -437,7 +437,9 @@ class DeepseekV2DecoderLayer(LlamaDecoderLayer):


  class DeepseekV2PreTrainedModel(LlamaPreTrainedModel):
- _can_compile_fullgraph = False
+ _can_compile_fullgraph = (
+ is_grouped_mm_available()
+ ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile

  @torch.no_grad()
  def _init_weights(self, module):
@@ -16,7 +16,7 @@ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+ from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernel_func_from_hub
  from ...masking_utils import create_causal_mask
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import (
@@ -28,7 +28,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
  from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_deepseek_v3 import DeepseekV3Config

@@ -71,7 +71,7 @@ class DeepseekV3RotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
@@ -150,6 +150,7 @@ class DeepseekV3TopkRouter(nn.Module):
  return router_logits


+ @use_experts_implementation
  class DeepseekV3NaiveMoe(nn.Module):
  """Collection of expert weights stored as 3D tensors."""

@@ -157,7 +158,7 @@
  super().__init__()
  self.num_experts = config.num_local_experts
  self.hidden_dim = config.hidden_size
- self.intermediate_dim = config.intermediate_size
+ self.intermediate_dim = config.moe_intermediate_size
  self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
  self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
  self.act_fn = ACT2FN[config.hidden_act]
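The one-word change from `intermediate_size` to `moe_intermediate_size` fixes a real shape bug: DeepSeek-style configs use a much smaller per-expert width than the dense MLP width, and the 3D expert tensors must be sized by the former to match checkpoints. With toy numbers (not real config values):

```python
import torch

hidden_size = 16
intermediate_size = 64       # dense MLP width (the wrong one to use here)
moe_intermediate_size = 8    # per-expert width
num_experts = 4

gate_up_proj = torch.empty(num_experts, 2 * moe_intermediate_size, hidden_size)
down_proj = torch.empty(num_experts, hidden_size, moe_intermediate_size)
# Sized with intermediate_size instead, both tensors would be 8x larger
# and would fail to load checkpoints shaped for the MoE width.
```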
@@ -542,7 +543,9 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
  _supports_flash_attn = True
  _supports_sdpa = True
  _supports_flex_attn = True
- _can_compile_fullgraph = False
+ _can_compile_fullgraph = (
+ is_grouped_mm_available()
+ ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
  _supports_attention_backend = True
  _can_record_outputs = {
  "hidden_states": DeepseekV3DecoderLayer,
@@ -555,6 +558,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
  super()._init_weights(module)
  if isinstance(module, DeepseekV3TopkRouter):
  init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+ init.zeros_(module.e_score_correction_bias)
  elif isinstance(module, DeepseekV3NaiveMoe):
  init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
  init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
@@ -12,7 +12,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import GenericForSequenceClassification, GenericForTokenClassification
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
- from ...utils import logging
+ from ...utils import is_grouped_mm_available, logging
  from ..llama.modeling_llama import (
  LlamaDecoderLayer,
  LlamaForCausalLM,
@@ -107,6 +107,7 @@ class DeepseekV3NaiveMoe(MixtralExperts):
  def __init__(self, config):
  super().__init__(config)
  self.num_experts = config.num_local_experts
+ self.intermediate_dim = config.moe_intermediate_size


  class DeepseekV3MoE(nn.Module):
@@ -303,7 +304,9 @@ class DeepseekV3DecoderLayer(LlamaDecoderLayer):


  class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
- _can_compile_fullgraph = False
+ _can_compile_fullgraph = (
+ is_grouped_mm_available()
+ ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
  _keep_in_fp32_modules_strict = ["e_score_correction_bias"]

  @torch.no_grad()
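Across these models, `_can_compile_fullgraph` is now computed at import time from `is_grouped_mm_available()`, so full-graph `torch.compile` support is only advertised when the grouped-matmul kernel the experts path relies on actually exists in the installed torch. A hedged stand-in for the check (the real helper lives in `transformers.utils` and may test more than this):

```python
import torch

def grouped_mm_available_sketch() -> bool:
    # Stand-in only: the real is_grouped_mm_available() may also gate on
    # torch version or hardware; here we just probe for the operator.
    return hasattr(torch, "_grouped_mm")

class PreTrainedSketch:
    _can_compile_fullgraph = grouped_mm_available_sketch()
```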
@@ -311,6 +314,7 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
  PreTrainedModel._init_weights(self, module)
  if isinstance(module, DeepseekV3TopkRouter):
  init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+ init.zeros_(module.e_score_correction_bias)
  elif isinstance(module, DeepseekV3NaiveMoe):
  init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
  init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
@@ -171,7 +171,6 @@ class DeepseekVLImageProcessorFast(BaseImageProcessorFast):
  processed_images_grouped[shape] = stacked_images

  processed_images = reorder_images(processed_images_grouped, grouped_images_index)
- processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

  return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
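The removed line pre-stacked the image list whenever `return_tensors` was truthy, duplicating work that `BatchFeature(..., tensor_type=return_tensors)` is already responsible for; this change relies on that conversion. The redundancy, shown with plain torch:

```python
import torch

processed_images = [torch.zeros(3, 4, 4) for _ in range(2)]
# What the deleted line did eagerly:
stacked = torch.stack(processed_images, dim=0)  # shape (2, 3, 4, 4)
# After this change the raw list is handed straight to BatchFeature,
# whose tensor_type handling is assumed (per this diff) to own the
# list-to-tensor conversion, including when return_tensors is None.
```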