transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py

@@ -31,7 +31,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -39,7 +44,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig
 
@@ -65,92 +70,77 @@ class Qwen3VLMoeTextRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
+@use_experts_implementation
 class Qwen3VLMoeTextExperts(nn.Module):
+    """Collection of expert weights stored as 3D tensors."""
+
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.num_experts
-        self.intermediate_size = config.moe_intermediate_size
-        self.hidden_size = config.hidden_size
-        self.expert_dim = self.intermediate_size
-        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
-        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+        self.hidden_dim = config.hidden_size
+        self.intermediate_dim = config.moe_intermediate_size
+        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(
-        self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
-        """
-        When training it is more efficient to just loop over the experts and compute the output for each expert
-        as otherwise the memory would explode.
+        final_hidden_states = torch.zeros_like(hidden_states)
+        with torch.no_grad():
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
+            expert_mask = expert_mask.permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+        for expert_idx in expert_hit:
+            expert_idx = expert_idx[0]
+            if expert_idx == self.num_experts:
+                continue
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+            current_state = hidden_states[token_idx]
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
+            current_hidden_states = self.act_fn(gate) * up
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
+
+        return final_hidden_states
 
-        For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
 
-        Args:
-            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
-            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
-            router_indices (torch.Tensor): (batch_size * token_num, top_k)
-        Returns:
-            torch.Tensor
-        """
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
-        if self.training:
-            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-            with torch.no_grad():
-                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
-                expert_mask = expert_mask.permute(2, 1, 0)
-                # we sum on the top_k and on the sequence length to get which experts
-                # are hit this time around
-                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-            for expert_idx in expert_hit[:]:
-                with torch.no_grad():
-                    _, token_idx = torch.where(expert_mask[expert_idx[0]])
-                current_state = hidden_states[token_idx]
-                gate_up = current_state @ self.gate_up_proj[expert_idx]
-                gate, up = gate_up.chunk(2, dim=-1)
-                gated_output = up * self.act_fn(gate)
-                out = gated_output @ self.down_proj[expert_idx]
-                weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
-                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-            next_states = next_states.view(batch_size, -1, self.hidden_size)
-        else:
-            hidden_states = hidden_states.repeat(self.num_experts, 1)
-            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
-            gate_up = torch.bmm(hidden_states, self.gate_up_proj)
-            gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
-            next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
-            next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
-            next_states = (
-                next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
-            )
-            next_states = next_states.sum(dim=0)
-        return next_states
+class Qwen3VLMoeTextTopKRouter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_experts
+        self.hidden_dim = config.hidden_size
+        self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts)
+        router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
+        router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
+        router_top_value = router_top_value.to(router_logits.dtype)
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices
 
 
 class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: Qwen3VLMoeTextConfig):
         super().__init__()
-        self.hidden_size = config.hidden_size
-        self.num_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
-        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
         self.experts = Qwen3VLMoeTextExperts(config)
+        self.gate = Qwen3VLMoeTextTopKRouter(config)
 
-        # since all the models use norm_topk_prob, we don't need to have a extra check for it
-        # self.norm_topk_prob = config.norm_topk_prob
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        router_logits = self.gate(hidden_states)
-        routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
-        routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
-        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
-        routing_weights = routing_weights.to(router_logits.dtype)
-        router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
-        hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
-        routed_out = self.experts(hidden_states, router_weights, router_indices)
-        return routed_out
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
+        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
+        final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
+        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
 
 
 def rotate_half(x):
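
The hunk above replaces the old dual-path expert dispatch (a mask-and-loop path for training, a bmm-over-all-experts path for inference) with a single mask-and-scatter loop, and moves top-k selection into a dedicated Qwen3VLMoeTextTopKRouter that returns already-renormalized weights. A minimal standalone sketch of that routing scheme (hypothetical sizes; an identity map stands in for the expert MLP):

import torch
import torch.nn.functional as F

num_experts, top_k, hidden = 8, 2, 16
tokens = torch.randn(5, hidden)                  # 5 flattened tokens
router_weight = torch.randn(num_experts, hidden)

# Router: softmax over all experts, then keep and renormalize the top-k.
probs = F.softmax(tokens @ router_weight.T, dim=-1, dtype=torch.float)
top_w, top_i = torch.topk(probs, top_k, dim=-1)  # (tokens, top_k)
top_w = top_w / top_w.sum(dim=-1, keepdim=True)

# Experts: visit only the experts that were actually selected and scatter
# each expert's weighted output back into place with index_add_.
out = torch.zeros_like(tokens)
expert_mask = F.one_hot(top_i, num_classes=num_experts).permute(2, 1, 0)  # (experts, top_k, tokens)
for e in expert_mask.sum(dim=(-1, -2)).nonzero().squeeze(-1):
    k_pos, tok = torch.where(expert_mask[e])
    expert_out = tokens[tok]                     # identity stand-in for the expert MLP
    out.index_add_(0, tok, expert_out * top_w[tok, k_pos, None])

The single path also explains the signature change: the experts layer now takes (hidden_states, top_k_index, top_k_weights) directly instead of a dense (tokens, num_experts) weight matrix built with scatter_.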
@@ -368,27 +358,6 @@ class Qwen3VLMoeTextDecoderLayer(GradientCheckpointingLayer):
         return hidden_states
 
 
-class Qwen3VLMoeTextTopKRouter(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.top_k = config.num_experts_per_tok
-        self.num_experts = config.num_experts
-        self.norm_topk_prob = config.norm_topk_prob
-        self.hidden_dim = config.hidden_size
-        self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts)
-        router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
-        if self.norm_topk_prob:
-            router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
-        router_top_value = router_top_value.to(router_logits.dtype)
-        router_scores = router_top_value
-        return router_logits, router_scores, router_indices
-
-
 @auto_docstring
 class Qwen3VLMoePreTrainedModel(PreTrainedModel):
     config: Qwen3VLMoeConfig
@@ -399,7 +368,9 @@ class Qwen3VLMoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
    _supports_flex_attn = True
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(Qwen3VLMoeTextTopKRouter, layer_name="mlp.gate", index=0),
@@ -418,6 +389,27 @@ class Qwen3VLMoePreTrainedModel(PreTrainedModel):
         if isinstance(module, Qwen3VLMoeTextExperts):
             init.normal_(module.gate_up_proj, mean=0.0, std=std)
             init.normal_(module.down_proj, mean=0.0, std=std)
+        elif isinstance(module, Qwen3VLMoeTextTopKRouter):
+            init.normal_(module.weight, mean=0.0, std=std)
+        elif isinstance(module, Qwen3VLMoeVisionRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
+
+class Qwen3VLMoeVisionRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs
 
 
 class Qwen3VLMoeVisionMLP(nn.Module):
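
Qwen3VLMoeVisionRotaryEmbedding now stores dim and theta as attributes, which is what lets the new _init_weights branch above recompute the non-persistent inv_freq buffer during weight initialization. The frequency table itself is the standard RoPE construction, inv_freq[i] = 1 / theta^(2i / dim), outer-multiplied by positions; a toy check with a hypothetical dim:

import torch

dim, theta = 8, 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))  # (dim // 2,)
freqs = torch.outer(torch.arange(4, dtype=torch.float), inv_freq)               # (seqlen, dim // 2)
assert freqs.shape == (4, dim // 2)
assert torch.all(freqs[0] == 0)  # position 0 gets zero phase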
@@ -453,20 +445,6 @@ class Qwen3VLMoeVisionPatchEmbed(nn.Module):
  return hidden_states


- class Qwen3VLMoeVisionRotaryEmbedding(nn.Module):
- inv_freq: torch.Tensor # fix linting for `register_buffer`
-
- def __init__(self, dim: int, theta: float = 10000.0) -> None:
- super().__init__()
- inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- def forward(self, seqlen: int) -> torch.Tensor:
- seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
- freqs = torch.outer(seq, self.inv_freq)
- return freqs
-
-
  class Qwen3VLMoeVisionPatchMerger(nn.Module):
  def __init__(self, config: Qwen3VLMoeVisionConfig, use_postshuffle_norm=False) -> None:
  super().__init__()
@@ -534,8 +512,8 @@ class Qwen3VLMoeVisionAttention(nn.Module):
  if self.config._attn_implementation != "eager":
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

- if self.config._attn_implementation == "flash_attention_2":
- # Flash Attention 2: Use cu_seqlens for variable length attention
+ if "flash" in self.config._attn_implementation:
+ # Flash Attention: Use cu_seqlens for variable length attention
  max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
  attn_output, _ = attention_interface(
  self,
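`cu_seqlens` is the cumulative-offset layout used by variable-length ("varlen") flash attention: several sequences are packed into one tensor and the boundaries are prefix sums of their lengths. A toy trace of the `max_seqlen` line above (the three lengths are made up):

```python
import torch

cu_seqlens = torch.tensor([0, 5, 9, 16])     # three packed sequences of lengths 5, 4, 7
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([5, 4, 7])
max_seqlen = seq_lens.max()                  # tensor(7), fed to the varlen kernel
```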
@@ -646,6 +624,8 @@ class Qwen3VLMoeVisionModel(Qwen3VLMoePreTrainedModel):

  self.gradient_checkpointing = False

+ self.post_init()
+
  def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
  merge_size = self.spatial_merge_size

@@ -815,7 +795,7 @@ class Qwen3VLMoeTextRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  self.mrope_section = config.rope_parameters.get("mrope_section", [24, 20, 20])

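Storing `original_inv_freq` with `register_buffer` rather than as a plain attribute matters because only parameters and buffers follow `Module.to(...)` device moves and dtype casts, and `persistent=False` keeps the tensor out of the checkpoint. A minimal demonstration (the `Demo` module is illustrative):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        t = torch.ones(4)
        self.register_buffer("buf", t.clone(), persistent=False)
        self.attr = t.clone()  # plain attribute: invisible to .to() and state_dict

m = Demo().to(torch.float64)
print(m.buf.dtype, m.attr.dtype)  # torch.float64 torch.float32
print("buf" in m.state_dict())    # False, because persistent=False
```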
@@ -1635,6 +1615,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMi
  pixel_values_videos=None,
  image_grid_thw=None,
  video_grid_thw=None,
+ is_first_iteration=False,
  **kwargs,
  ):
  # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1651,6 +1632,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMi
  image_grid_thw=image_grid_thw,
  video_grid_thw=video_grid_thw,
  use_cache=use_cache,
+ is_first_iteration=is_first_iteration,
  **kwargs,
  )

@@ -1682,7 +1664,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMi
  text_positions = model_inputs["position_ids"][None, ...]
  model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)

- if cache_position[0] != 0:
+ if not is_first_iteration and use_cache:
  model_inputs["pixel_values"] = None
  model_inputs["pixel_values_videos"] = None

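The rewritten condition makes the caching contract explicit: vision inputs are encoded exactly once, on the prefill step, and every later decoding step reads the result back from the KV cache. Schematically (names follow the hunk above; this is a sketch of the control flow, not the full method):

```python
# prefill: is_first_iteration=True  -> pixel tensors are forwarded and encoded
# decode:  is_first_iteration=False and use_cache=True -> skip re-encoding images
if not is_first_iteration and use_cache:
    model_inputs["pixel_values"] = None
    model_inputs["pixel_values_videos"] = None
```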
@@ -18,9 +18,9 @@ from typing import Optional, Union

  import torch
  import torch.nn as nn
+ import torch.nn.functional as F

  from ... import initialization as init
- from ...activations import ACT2FN
  from ...cache_utils import Cache
  from ...configuration_utils import PreTrainedConfig
  from ...modeling_rope_utils import RopeParameters
@@ -29,8 +29,10 @@ from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, can_return_tuple, logging
  from ..qwen3_moe.modeling_qwen3_moe import (
  Qwen3MoeDecoderLayer,
+ Qwen3MoeExperts,
  Qwen3MoePreTrainedModel,
  Qwen3MoeRMSNorm,
+ Qwen3MoeSparseMoeBlock,
  load_balancing_loss_func,
  )
  from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
@@ -41,6 +43,7 @@ from ..qwen3_vl.modeling_qwen3_vl import (
  Qwen3VLTextAttention,
  Qwen3VLTextModel,
  Qwen3VLVisionModel,
+ Qwen3VLVisionRotaryEmbedding,
  )


@@ -257,92 +260,31 @@ class Qwen3VLMoeTextRMSNorm(Qwen3MoeRMSNorm):
  pass


- class Qwen3VLMoeTextExperts(nn.Module):
+ class Qwen3VLMoeTextExperts(Qwen3MoeExperts):
+ pass
+
+
+ class Qwen3VLMoeTextTopKRouter(nn.Module):
  def __init__(self, config):
  super().__init__()
+ self.top_k = config.num_experts_per_tok
  self.num_experts = config.num_experts
- self.intermediate_size = config.moe_intermediate_size
- self.hidden_size = config.hidden_size
- self.expert_dim = self.intermediate_size
- self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
- self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
- self.act_fn = ACT2FN[config.hidden_act]
+ self.hidden_dim = config.hidden_size
+ self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))

- def forward(
- self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
- ) -> torch.Tensor:
- """
- When training it is more efficient to just loop over the experts and compute the output for each expert
- as otherwise the memory would explode.
-
- For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
-
- Args:
- hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
- routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
- router_indices (torch.Tensor): (batch_size * token_num, top_k)
- Returns:
- torch.Tensor
- """
- batch_size = hidden_states.shape[0]
- hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
- if self.training:
- next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
- with torch.no_grad():
- expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
- expert_mask = expert_mask.permute(2, 1, 0)
- # we sum on the top_k and on the sequence length to get which experts
- # are hit this time around
- expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
- for expert_idx in expert_hit[:]:
- with torch.no_grad():
- _, token_idx = torch.where(expert_mask[expert_idx[0]])
- current_state = hidden_states[token_idx]
- gate_up = current_state @ self.gate_up_proj[expert_idx]
- gate, up = gate_up.chunk(2, dim=-1)
- gated_output = up * self.act_fn(gate)
- out = gated_output @ self.down_proj[expert_idx]
- weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
- next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
- next_states = next_states.view(batch_size, -1, self.hidden_size)
- else:
- hidden_states = hidden_states.repeat(self.num_experts, 1)
- hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
- gate_up = torch.bmm(hidden_states, self.gate_up_proj)
- gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors
- next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
- next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
- next_states = (
- next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
- )
- next_states = next_states.sum(dim=0)
- return next_states
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+ router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
+ router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
+ router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
+ router_top_value = router_top_value.to(router_logits.dtype)
+ router_scores = router_top_value
+ return router_logits, router_scores, router_indices


- class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.num_experts = config.num_experts
- self.top_k = config.num_experts_per_tok
- self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
- self.experts = Qwen3VLMoeTextExperts(config)
-
- # since all the models use norm_topk_prob, we don't need to have a extra check for it
- # self.norm_topk_prob = config.norm_topk_prob
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- batch_size = hidden_states.shape[0]
- hidden_states = hidden_states.reshape(-1, self.hidden_size)
- router_logits = self.gate(hidden_states)
- routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
- routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
- routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
- routing_weights = routing_weights.to(router_logits.dtype)
- router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
- hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
- routed_out = self.experts(hidden_states, router_weights, router_indices)
- return routed_out
+ class Qwen3VLMoeTextSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+ pass


  class Qwen3VLMoeTextAttention(Qwen3VLTextAttention):
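The new `Qwen3VLMoeTextTopKRouter.forward` runs a softmax over all experts, keeps the top-k, then renormalizes the kept weights so each token's routing weights sum to 1 again. A toy trace with assumed shapes (4 tokens, 3 experts, top-2):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = torch.randn(4, 8)   # (num_tokens, hidden_dim)
weight = torch.randn(3, 8)   # (num_experts, hidden_dim)

logits = F.linear(hidden, weight)                      # (4, 3)
probs = F.softmax(logits, dim=-1, dtype=torch.float)   # full-softmax routing probs
top_val, top_idx = torch.topk(probs, k=2, dim=-1)      # keep 2 of 3 experts per token
top_val = top_val / top_val.sum(dim=-1, keepdim=True)  # rows sum to 1 over the kept experts
```

Note that the modular rewrite normalizes unconditionally, whereas the removed modeling-file version gated this step on `norm_topk_prob`.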
@@ -368,6 +310,15 @@ class Qwen3VLMoePreTrainedModel(Qwen3MoePreTrainedModel):
  if isinstance(module, Qwen3VLMoeTextExperts):
  init.normal_(module.gate_up_proj, mean=0.0, std=std)
  init.normal_(module.down_proj, mean=0.0, std=std)
+ elif isinstance(module, Qwen3VLMoeTextTopKRouter):
+ init.normal_(module.weight, mean=0.0, std=std)
+ elif isinstance(module, Qwen3VLMoeVisionRotaryEmbedding):
+ inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+ init.copy_(module.inv_freq, inv_freq)
+
+
+ class Qwen3VLMoeVisionRotaryEmbedding(Qwen3VLVisionRotaryEmbedding):
+ pass


  class Qwen3VLMoeVisionModel(Qwen3VLVisionModel):
@@ -70,9 +70,6 @@ RAG_CONFIG_DOC = r"""
  `context_attention_mask` are returned. See returned tensors for more detail.
  use_cache (`bool`, *optional*, defaults to `True`):
  Whether or not the model should return the last key/values attentions (not used by all models).
- forced_eos_token_id (`int`, *optional*):
- The id of the token to force as the last generated token when `max_length` is reached. Usually set to
- `eos_token_id`.
  """


@@ -109,7 +106,6 @@ class RagConfig(PreTrainedConfig):
  do_marginalize=False,
  output_retrieved=False,
  use_cache=True,
- forced_eos_token_id=None,
  dataset_revision=None,
  **kwargs,
  ):
@@ -118,7 +114,6 @@ class RagConfig(PreTrainedConfig):
  pad_token_id=pad_token_id,
  eos_token_id=eos_token_id,
  decoder_start_token_id=decoder_start_token_id,
- forced_eos_token_id=forced_eos_token_id,
  is_encoder_decoder=is_encoder_decoder,
  prefix=prefix,
  vocab_size=vocab_size,
@@ -166,9 +161,6 @@ class RagConfig(PreTrainedConfig):

  self.use_cache = use_cache

- if forced_eos_token_id is None:
- self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)
-
  @classmethod
  def from_question_encoder_generator_configs(
  cls, question_encoder_config: PreTrainedConfig, generator_config: PreTrainedConfig, **kwargs
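With `forced_eos_token_id` dropped from `RagConfig`, forcing an EOS when `max_length` is hit remains available through the generation config, which carries a field of the same name. A sketch (the token id is illustrative):

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig(max_length=32, forced_eos_token_id=2)
# then: model.generate(..., generation_config=gen_cfg)
```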
@@ -422,6 +422,8 @@ class RagModel(RagPreTrainedModel):
  self.ctx_encoder = None
  self.context_encoder_training = False

+ self.post_init()
+
  @auto_docstring
  def forward(
  self,
@@ -690,6 +692,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
  # instantiate model
  self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)

+ self.post_init()
+
  def set_retriever(self, retriever: RagRetriever):
  self.rag.retriever = retriever

@@ -1126,6 +1130,8 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
  # instantiate model
  self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)

+ self.post_init()
+
  def set_retriever(self, retriever: RagRetriever):
  self.rag.retriever = retriever

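The added `self.post_init()` calls follow the `PreTrainedModel` contract: the hook runs weight initialization and final processing, and concrete models are expected to invoke it as the last statement of `__init__`. The usual shape of the pattern, with an invented model for illustration (import paths follow this release's naming):

```python
import torch.nn as nn
from transformers import PreTrainedModel, PreTrainedConfig

class TinyModel(PreTrainedModel):
    config: PreTrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(4, 4)
        self.post_init()  # initializes weights and applies final processing
```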
@@ -1404,7 +1410,6 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
  prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
  logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
  stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
- use_model_defaults: Optional[bool] = None,
  **kwargs,
  ) -> torch.LongTensor:
  """
@@ -1463,11 +1468,6 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
  Custom stopping criteria that complement the default stopping criteria built from arguments and a
  model's config. If a stopping criterion is passed that is already created with the arguments or a
  model's config an error is thrown.
- use_model_defaults (`bool`, *optional*):
- When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
- generation configuration (`model.generation_config`), as opposed to the global defaults
- (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
- `True`.
  kwargs (`dict[str, Any]`, *optional*):
  Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
  forwarded to the `forward` function of the model.
@@ -1479,9 +1479,7 @@ class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
  """
  # Handle `generation_config` and kwargs that might update it
  generation_mode_kwargs = self._extract_generation_mode_kwargs(None, kwargs, False, None, None)
- generation_config, model_kwargs = self._prepare_generation_config(
- generation_config, use_model_defaults, **kwargs
- )
+ generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
  generation_mode = generation_config.get_generation_mode()
  if generation_mode not in [
  GenerationMode.SAMPLE,
@@ -80,7 +80,7 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  # Ignore copy
@@ -611,10 +611,11 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
  # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
  if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  init.zeros_(module.weight[module.padding_idx])
-
  # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  elif isinstance(module, RecurrentGemmaRMSNorm):
  init.zeros_(module.weight)
+ elif isinstance(module, RecurrentGemmaModel):
+ init.constant_(module.normalizer, module.config.hidden_size**0.5)

  def _setup_cache(self, config, batch, device, dtype):
  layers = getattr(self, "model", self).layers
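`RecurrentGemmaModel.normalizer` is the familiar embedding scale, `sqrt(hidden_size)`, now set explicitly at init time. For a hypothetical `hidden_size` of 2560:

```python
hidden_size = 2560               # illustrative, not the actual config value
normalizer = hidden_size ** 0.5  # ≈ 50.6; embeddings are multiplied by this
```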
@@ -1851,6 +1851,14 @@ class ReformerPreTrainedModel(PreTrainedModel):
  if isinstance(module, AxialPositionEmbeddings):
  for weight in module.weights:
  init.normal_(weight, std=self.config.axial_norm_std)
+ elif isinstance(module, LSHSelfAttention):
+ init.constant_(module.self_mask_value_float16, -1e3)
+ init.constant_(module.self_mask_value_float32, -1e5)
+ init.constant_(module.mask_value_float16, -1e4)
+ init.constant_(module.mask_value_float32, -1e9)
+ elif isinstance(module, LocalSelfAttention):
+ init.constant_(module.mask_value_float16, -1e4)
+ init.constant_(module.mask_value_float32, -1e9)


  @dataclass
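The dtype-split mask constants are deliberate: float16 saturates at ±65504, so a float32-scale additive mask like `-1e9` overflows to `-inf` (a classic source of NaNs once it reaches softmax), while `-1e4` is still representable. A quick check:

```python
import torch

print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)
print(torch.tensor(-1e4, dtype=torch.float16))  # tensor(-10000., dtype=torch.float16)
```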
@@ -2239,7 +2247,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel, GenerationMixin):
  )

  def prepare_inputs_for_generation(
- self, input_ids, past_key_values=None, use_cache=None, num_hashes=None, **kwargs
+ self, input_ids, past_key_values=None, use_cache=None, num_hashes=None, is_first_iteration=False, **kwargs
  ):
  # Overwritten -- different expected inputs/outputs

@@ -278,6 +278,10 @@ class RegNetPreTrainedModel(PreTrainedModel):
  elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
  init.constant_(module.weight, 1)
  init.constant_(module.bias, 0)
+ if getattr(module, "running_mean", None) is not None:
+ init.zeros_(module.running_mean)
+ init.ones_(module.running_var)
+ init.zeros_(module.num_batches_tracked)


  @auto_docstring
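The `getattr(module, "running_mean", None)` guard is needed because `nn.GroupNorm`, and `nn.BatchNorm2d` built with `track_running_stats=False`, carry no running statistics. The same reset expressed against plain PyTorch (`init` in the hunk is transformers' own initialization module; `torch.nn.init` behaves the same way on buffers):

```python
import torch.nn as nn
import torch.nn.init as init

bn = nn.BatchNorm2d(8)
if getattr(bn, "running_mean", None) is not None:
    init.zeros_(bn.running_mean)    # buffers are plain tensors, so init ops apply
    init.ones_(bn.running_var)
    bn.num_batches_tracked.zero_()  # integer step counter back to 0
```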
@@ -21,6 +21,7 @@ import torch
  from torch import nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  from ...generation import GenerationMixin
@@ -488,6 +489,11 @@ class RemBertPreTrainedModel(PreTrainedModel):
  base_model_prefix = "rembert"
  supports_gradient_checkpointing = True

+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, RemBertEmbeddings):
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+

  @auto_docstring(
  custom_intro="""
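The tensor copied into `position_ids` is the stock absolute-position layout: an `arange` over the maximum sequence length, broadcast to a leading batch dimension of 1. For example (512 is an assumed maximum, not RemBERT's configured value):

```python
import torch

max_positions = 512
position_ids = torch.arange(max_positions).expand((1, -1))
print(position_ids.shape)  # torch.Size([1, 512])
```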
@@ -702,7 +708,7 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
  attentions=outputs.attentions,
  )

- def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, is_first_iteration=False, **model_kwargs):
  input_shape = input_ids.shape
  effective_batch_size = input_shape[0]