transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/qwen2_moe/modeling_qwen2_moe.py

@@ -35,7 +35,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import (
     GenericForQuestionAnswering,
@@ -47,7 +52,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_qwen2_moe import Qwen2MoeConfig

@@ -90,7 +95,7 @@ class Qwen2MoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
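
Note: the recurring `original_inv_freq` change in this release swaps a plain tensor attribute for a non-persistent buffer. A buffer follows `.to()` / `.cuda()` device and dtype moves with the rest of the module, while `persistent=False` keeps it out of `state_dict`, so existing checkpoints are unaffected. A minimal standalone sketch of the difference (toy module, not the actual class):

    import torch
    from torch import nn

    class RopeBufferDemo(nn.Module):
        def __init__(self):
            super().__init__()
            inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2, dtype=torch.float) / 8))
            # plain attribute: ignored by .to()/.half() and absent from state_dict
            self.plain_inv_freq = inv_freq
            # non-persistent buffer: moves with the module, still not serialized
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    m = RopeBufferDemo()
    print("original_inv_freq" in m.state_dict())  # False: checkpoints unchanged
    m.half()
    print(m.original_inv_freq.dtype)  # torch.float16; plain_inv_freq stays float32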
@@ -292,6 +297,7 @@ class Qwen2MoeAttention(nn.Module):
         return attn_output, attn_weights


+@use_experts_implementation
 class Qwen2MoeExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -432,7 +438,9 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(Qwen2MoeTopKRouter, index=0),
transformers/models/qwen2_vl/image_processing_qwen2_vl.py

@@ -159,8 +159,9 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
-            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+        if size is not None:
+            if "shortest_edge" not in size or "longest_edge" not in size:
+                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
         else:
             size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
         # backward compatibility: override size with min_pixels and max_pixels if they are provided
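
Note: this is a behavior fix, not a style change. In the old code the `else` was paired with the combined condition, so a valid user-supplied `size` fell through to the `else` branch and was silently overwritten with the defaults; nesting the key check restores the intent. A toy re-creation of both shapes of the logic:

    def old_size(size):
        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
        else:  # bug: also taken when size is present and valid
            size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
        return size

    def new_size(size):
        if size is not None:
            if "shortest_edge" not in size or "longest_edge" not in size:
                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
        else:
            size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
        return size

    custom = {"shortest_edge": 1024, "longest_edge": 4096}
    print(old_size(custom))  # defaults: the custom value was discarded
    print(new_size(custom))  # {'shortest_edge': 1024, 'longest_edge': 4096}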
transformers/models/qwen2_vl/modeling_qwen2_vl.py

@@ -28,6 +28,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import LayerNorm

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
@@ -125,7 +126,7 @@ class Qwen2VLRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -246,6 +247,8 @@ class VisionRotaryEmbedding(nn.Module):

     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
+        self.dim = dim
+        self.theta = theta
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)

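
Note: storing `dim` and `theta` on the module lets the new `_init_weights` override further down rebuild `inv_freq` from scratch instead of relying on the value captured at construction time. For reference, the frequency formula evaluated standalone with a toy dimension:

    import torch

    dim, theta = 8, 10000.0
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    print(inv_freq)  # tensor([1.0000, 0.1000, 0.0100, 0.0010]), one per dim pair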
@@ -384,8 +387,8 @@ class VisionAttention(nn.Module):
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

-            if self.config._attn_implementation == "flash_attention_2":
-                # Flash Attention 2: Use cu_seqlens for variable length attention
+            if "flash" in self.config._attn_implementation:
+                # Flash Attention: Use cu_seqlens for variable length attention
                 max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
                 attn_output, _ = attention_interface(
                     self,
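
Note: the exact-match test only caught the literal "flash_attention_2" backend name; the substring test routes every flash-style backend through the `cu_seqlens` variable-length path. Illustrative check (the backend names below are hypothetical examples, not an exhaustive list):

    impls = ["flash_attention_2", "flash_attention_3", "sdpa", "eager"]
    print([i for i in impls if "flash" in i])
    # ['flash_attention_2', 'flash_attention_3']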
@@ -665,6 +668,12 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = True
     _supports_attention_backend = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, VisionRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+

 @auto_docstring
 class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
@@ -693,6 +702,8 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         )
         self.gradient_checkpointing = False

+        self.post_init()
+
     def get_dtype(self) -> torch.dtype:
         return self.blocks[0].mlp.fc2.weight.dtype

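
Note: the two hunks work together: `post_init()` triggers the standard PreTrainedModel weight-initialization pass at the end of `__init__`, and the `_init_weights` override recomputes the rotary buffer during that pass, which is why `VisionRotaryEmbedding` now records `dim` and `theta`. A toy mirror of the recompute-on-init pattern (stand-in module, not the real classes):

    import torch
    from torch import nn

    class ToyRotary(nn.Module):
        def __init__(self, dim: int, theta: float = 10000.0):
            super().__init__()
            self.dim, self.theta = dim, theta
            self.register_buffer("inv_freq", torch.empty(dim // 2), persistent=False)

    def reinit(module: nn.Module):
        if isinstance(module, ToyRotary):
            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
            module.inv_freq.copy_(inv_freq)  # in place: keeps device and dtype

    model = ToyRotary(8)
    model.apply(reinit)    # like post_init() walking the module tree
    print(model.inv_freq)  # tensor([1.0000, 0.1000, 0.0100, 0.0010])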
@@ -1416,6 +1427,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         pixel_values_videos=None,
         image_grid_thw=None,
         video_grid_thw=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1432,6 +1444,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             image_grid_thw=image_grid_thw,
             video_grid_thw=video_grid_thw,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

@@ -1463,7 +1476,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             text_positions = model_inputs["position_ids"][None, ...]
             model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)

-        if model_inputs["cache_position"][0] != 0:
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None

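
Note: pixel inputs only need to be encoded on the first forward pass of a cached generation loop; after that, the visual context lives in the KV cache. The old `cache_position[0] != 0` heuristic infers "first iteration" from the cache state, which is fragile when the cache position does not start at zero; the explicit `is_first_iteration` flag makes the intent unambiguous. Toy re-creation of the gate:

    def gate_pixels(model_inputs, is_first_iteration, use_cache):
        # mirrors the replaced condition above
        if not is_first_iteration and use_cache:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
        return model_inputs

    step0 = gate_pixels({"pixel_values": "img", "pixel_values_videos": None}, True, True)
    step1 = gate_pixels({"pixel_values": "img", "pixel_values_videos": None}, False, True)
    print(step0["pixel_values"], step1["pixel_values"])  # img None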
transformers/models/qwen3/modeling_qwen3.py

@@ -100,7 +100,7 @@ class Qwen3RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
transformers/models/qwen3_moe/modeling_qwen3_moe.py

@@ -30,7 +30,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -43,7 +48,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_qwen3_moe import Qwen3MoeConfig

@@ -212,6 +217,7 @@ class Qwen3MoeMLP(nn.Module):
         return down_proj


+@use_experts_implementation
 class Qwen3MoeExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -365,7 +371,9 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.gate", index=0),
@@ -401,7 +409,7 @@ class Qwen3MoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
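
Note: `@use_experts_implementation` is applied to each `*Experts` class in this release; by analogy with the kernel decorators imported alongside it, it allows the experts' forward to be swapped for an alternative implementation at runtime. A hedged sketch of that dispatcher-decorator shape (illustrative only; not the code of `transformers.integrations.moe`):

    import torch
    from torch import nn

    EXPERT_IMPLS = {}  # name -> forward-compatible callable

    def use_experts_impl(cls):
        # wrap the class's default forward in a dispatcher that can defer
        # to a registered alternative (e.g. a grouped-GEMM kernel)
        default_forward = cls.forward

        def forward(self, *args, **kwargs):
            impl = EXPERT_IMPLS.get(getattr(self, "experts_impl", None))
            return (impl or default_forward)(self, *args, **kwargs)

        cls.forward = forward
        return cls

    @use_experts_impl
    class ToyExperts(nn.Module):
        def forward(self, x):
            return x  # placeholder expert computation

    EXPERT_IMPLS["double"] = lambda self, x: 2 * x
    e = ToyExperts()
    print(e(torch.ones(1)))  # tensor([1.]) -- default path
    e.experts_impl = "double"
    print(e(torch.ones(1)))  # tensor([2.]) -- dispatched path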
transformers/models/qwen3_next/modeling_qwen3_next.py

@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
-from ...integrations import use_kernelized_func
+from ...integrations import use_experts_implementation, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -45,10 +45,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
-from ...utils.import_utils import (
-    is_causal_conv1d_available,
-    is_flash_linear_attention_available,
-)
+from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
 from .configuration_qwen3_next import Qwen3NextConfig


@@ -192,7 +189,7 @@ class Qwen3NextRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -822,6 +819,7 @@ class Qwen3NextMLP(nn.Module):
         return down_proj


+@use_experts_implementation
 class Qwen3NextExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -907,6 +907,7 @@ class Qwen3OmniMoeTalkerConfig(PreTrainedConfig):
907
907
  self.audio_start_token_id = audio_start_token_id
908
908
  self.vision_start_token_id = vision_start_token_id
909
909
  self.speaker_id = speaker_id
910
+ self.initializer_range = self.text_config.initializer_range
910
911
  super().__init__(**kwargs)
911
912
 
912
913
 
@@ -997,6 +998,7 @@ class Qwen3OmniMoeCode2WavConfig(PreTrainedConfig):
997
998
  upsampling_ratios=(2, 2),
998
999
  decoder_dim=1536,
999
1000
  attention_dropout=0.0,
1001
+ initializer_range=0.02,
1000
1002
  **kwargs,
1001
1003
  ):
1002
1004
  self.codebook_size = codebook_size
@@ -1016,6 +1018,7 @@ class Qwen3OmniMoeCode2WavConfig(PreTrainedConfig):
1016
1018
  self.upsampling_ratios = upsampling_ratios
1017
1019
  self.decoder_dim = decoder_dim
1018
1020
  self.attention_dropout = attention_dropout
1021
+ self.initializer_range = initializer_range
1019
1022
  self.rope_parameters = rope_parameters
1020
1023
 
1021
1024
  super().__init__(**kwargs)
@@ -1104,6 +1107,7 @@ class Qwen3OmniMoeConfig(PreTrainedConfig):
  self.thinker_config = Qwen3OmniMoeThinkerConfig(**thinker_config)
  self.talker_config = Qwen3OmniMoeTalkerConfig(**talker_config)
  self.code2wav_config = Qwen3OmniMoeCode2WavConfig(**code2wav_config)
+ self.initializer_range = self.thinker_config.initializer_range
  self.enable_audio_output = enable_audio_output
  self.im_start_token_id = im_start_token_id
  self.im_end_token_id = im_end_token_id
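The config hunks above follow one pattern: composite configs now surface initializer_range at the top level, copied from a sub-config (text_config, thinker_config) or accepted directly (Qwen3OmniMoeCode2WavConfig), so generic weight-init code can always read config.initializer_range without inspecting sub-configs. The idea in miniature, with hypothetical classes:

class SubConfig:
    def __init__(self, initializer_range=0.02):
        self.initializer_range = initializer_range

class CompositeConfig:
    def __init__(self, sub_config=None):
        self.sub_config = SubConfig(**(sub_config or {}))
        # Mirror the sub-config's value so shared init code never has to
        # know which sub-model it is initializing.
        self.initializer_range = self.sub_config.initializer_range

cfg = CompositeConfig({"initializer_range": 0.01})
assert cfg.initializer_range == 0.01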
@@ -35,7 +35,12 @@ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+ from ...integrations import (
+ use_experts_implementation,
+ use_kernel_forward_from_hub,
+ use_kernel_func_from_hub,
+ use_kernelized_func,
+ )
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import GradientCheckpointingLayer
@@ -49,7 +54,7 @@ from ...modeling_outputs import (
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
- from ...utils import auto_docstring, can_return_tuple
+ from ...utils import auto_docstring, can_return_tuple, is_grouped_mm_available
  from ...utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs, maybe_autocast
  from .configuration_qwen3_omni_moe import (
  Qwen3OmniMoeAudioEncoderConfig,
@@ -64,6 +69,27 @@ from .configuration_qwen3_omni_moe import (
  )


+ class SinusoidsPositionEmbedding(nn.Module):
+ def __init__(self, length, channels, max_timescale=10000):
+ super().__init__()
+ self.length = length
+ self.channels = channels
+ self.max_timescale = max_timescale
+ if channels % 2 != 0:
+ raise ValueError("SinusoidsPositionEmbedding needs even channels input")
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ self.register_buffer(
+ "positional_embedding",
+ torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
+ persistent=False,
+ )
+
+ def forward(self, seqlen: int):
+ return self.positional_embedding[:seqlen, :]
+
+
  @auto_docstring
  class Qwen3OmniMoePreTrainedModel(PreTrainedModel):
  config: Qwen3OmniMoeConfig
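SinusoidsPositionEmbedding moves above the PreTrainedModel definition (its old location is deleted further down) and now keeps length, channels, and max_timescale as attributes, which the new _init_weights branch in the following hunk uses to rebuild the non-persistent buffer deterministically. A quick shape check of what forward returns, assuming the class above is in scope:

emb = SinusoidsPositionEmbedding(length=1500, channels=8)
table = emb(4)         # embeddings for the first 4 positions
print(table.shape)     # torch.Size([4, 8]): sin half then cos half
print(table[0])        # position 0: the sin terms are 0.0, the cos terms 1.0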
@@ -85,6 +111,19 @@ class Qwen3OmniMoePreTrainedModel(PreTrainedModel):
  init.normal_(module.experts.gate_up_proj, mean=0.0, std=std)
  init.normal_(module.experts.down_proj, mean=0.0, std=std)
  init.normal_(module.gate.weight, mean=0.0, std=std)
+ elif isinstance(module, Qwen3OmniMoeCode2Wav):
+ init.copy_(
+ module.code_offset,
+ torch.arange(module.config.num_quantizers).view(1, -1, 1) * module.config.codebook_size,
+ )
+ elif isinstance(module, SinusoidsPositionEmbedding):
+ log_timescale_increment = np.log(module.max_timescale) / (module.channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(module.channels // 2).float())
+ scaled_time = torch.arange(module.length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ init.copy_(module.positional_embedding, torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1))
+ elif isinstance(module, Qwen3OmniMoeVisionRotaryEmbedding):
+ inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+ init.copy_(module.inv_freq, inv_freq)


  def _get_feat_extract_output_lengths(input_lengths):
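All three new branches re-derive buffers registered with persistent=False. Such buffers never reach a checkpoint, and when the model is built on the meta device during low-memory loading their constructor-time values are lost, so _init_weights is the one place that can restore them. Both facts in a minimal sketch with a hypothetical module:

import torch
from torch import nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("table", torch.arange(4).float(), persistent=False)

m = WithBuffer()
print(m.state_dict())       # empty: non-persistent buffers are not serialized

# Built on the meta device (as during low-memory loading), the buffer holds
# no real values, so loading a checkpoint alone cannot restore it:
with torch.device("meta"):
    m_meta = WithBuffer()
print(m_meta.table.device)  # meta: must be recomputed after materialization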
@@ -620,24 +659,6 @@ class Qwen3OmniMoeAudioEncoderLayer(GradientCheckpointingLayer):
  return outputs


- class SinusoidsPositionEmbedding(nn.Module):
- def __init__(self, length, channels, max_timescale=10000):
- super().__init__()
- if channels % 2 != 0:
- raise ValueError("SinusoidsPositionEmbedding needs even channels input")
- log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
- inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
- scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
- self.register_buffer(
- "positional_embedding",
- torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
- persistent=False,
- )
-
- def forward(self, seqlen: int):
- return self.positional_embedding[:seqlen, :]
-
-
  @auto_docstring(
  custom_intro="""
  Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
@@ -891,8 +912,8 @@ class Qwen3OmniMoeVisionAttention(nn.Module):
  if self.config._attn_implementation != "eager":
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

- if self.config._attn_implementation == "flash_attention_2":
- # Flash Attention 2: Use cu_seqlens for variable length attention
+ if "flash" in self.config._attn_implementation:
+ # Flash Attention: Use cu_seqlens for variable length attention
  max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
  attn_output, _ = attention_interface(
  self,
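The guard widens from exactly flash_attention_2 to any implementation with "flash" in its name (e.g. flash_attention_3). All flash variants consume the same varlen encoding: cu_seqlens holds the cumulative boundaries of the packed sequences, so adjacent differences recover the individual lengths. A worked example of the max_seqlen line:

import torch

lengths = torch.tensor([3, 5, 2])   # three packed sequences
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)])
print(cu_seqlens)                                     # tensor([ 0,  3,  8, 10])
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
print(max_seqlen)                                     # tensor(5)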
@@ -960,6 +981,22 @@ class Qwen3OmniMoeVisionPatchMerger(nn.Module):
  return hidden


+ class Qwen3OmniMoeVisionRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
  class Qwen3OmniMoeVisionMLP(nn.Module):
  def __init__(self, config):
  super().__init__()
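Like SinusoidsPositionEmbedding, the vision rotary embedding is hoisted above the PreTrainedModel (and deleted from its old spot in the next hunk) and now records dim and theta so _init_weights can regenerate inv_freq. Its forward returns one rotation angle per (position, frequency) pair; assuming the class above is in scope:

rope = Qwen3OmniMoeVisionRotaryEmbedding(dim=8)
freqs = rope(3)
print(freqs.shape)   # torch.Size([3, 4]): dim // 2 frequencies per position
print(freqs[0])      # position 0: all angles are 0.0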
@@ -993,20 +1030,6 @@ class Qwen3OmniMoeVisionPatchEmbed(nn.Module):
  return hidden_states


- class Qwen3OmniMoeVisionRotaryEmbedding(nn.Module):
- inv_freq: torch.Tensor # fix linting for `register_buffer`
-
- def __init__(self, dim: int, theta: float = 10000.0) -> None:
- super().__init__()
- inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- def forward(self, seqlen: int) -> torch.Tensor:
- seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
- freqs = torch.outer(seq, self.inv_freq)
- return freqs
-
-
  class Qwen3OmniMoeVisionBlock(GradientCheckpointingLayer):
  def __init__(self, config, attn_implementation: str = "sdpa") -> None:
  super().__init__()
@@ -1073,6 +1096,8 @@ class Qwen3OmniMoeVisionEncoder(Qwen3OmniMoePreTrainedModel):

  self.gradient_checkpointing = False

+ self.post_init()
+
  def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
  merge_size = self.spatial_merge_size

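This encoder and Qwen3OmniMoeCode2WavDecoderBlock (last hunk of this section) now end __init__ with self.post_init(), the standard PreTrainedModel constructor epilogue that, among other bookkeeping, runs weight initialization; without it, constructing these submodels standalone would skip the _init_weights branches added above. Roughly, as a simplified sketch rather than the real hook:

from torch import nn

class TinyPretrained(nn.Module):
    def _init_weights(self, module):
        pass  # deterministic weight/buffer setup would live here

    def post_init(self):
        # Simplified: the real hook also handles gradient-checkpointing
        # compatibility and other setup.
        self.apply(self._init_weights)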
@@ -1246,7 +1271,7 @@ class Qwen3OmniMoeThinkerTextRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  self.mrope_section = config.rope_parameters.get("mrope_section", [24, 20, 20])

@@ -1318,6 +1343,7 @@ class Qwen3OmniMoeThinkerTextRotaryEmbedding(nn.Module):
  return freqs_t


+ @use_experts_implementation
  class Qwen3OmniMoeThinkerTextExperts(nn.Module):
  """
  ModuleList of experts.
@@ -1596,7 +1622,9 @@ class Qwen3OmniMoeThinkerTextPreTrainedModel(PreTrainedModel):
  _supports_flash_attn = True
  _supports_sdpa = True
  _supports_flex_attn = True
- _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = (
+ is_grouped_mm_available()
+ ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
  _supports_attention_backend = True
  _can_record_outputs = {
  "router_logits": OutputRecorder(Qwen3OmniMoeThinkerTextTopKRouter, layer_name="mlp.gate", index=0),
@@ -2248,6 +2276,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
  feature_attention_mask=None,
  use_audio_in_video=False,
  video_second_per_grid=None,
+ is_first_iteration=False,
  **kwargs,
  ):
  model_inputs = super().prepare_inputs_for_generation(
@@ -2266,12 +2295,13 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
  feature_attention_mask=feature_attention_mask,
  use_audio_in_video=use_audio_in_video,
  video_second_per_grid=video_second_per_grid,
+ is_first_iteration=is_first_iteration,
  **kwargs,
  )

  model_inputs["position_ids"] = None

- if cache_position[0] != 0:
+ if not is_first_iteration and use_cache:
  model_inputs["pixel_values"] = None
  model_inputs["pixel_values_videos"] = None
  model_inputs["input_features"] = None
@@ -2477,7 +2507,7 @@ class Qwen3OmniMoeRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

  @staticmethod
  def compute_default_rope_parameters(
@@ -2745,6 +2775,7 @@ class Qwen3OmniMoeTalkerTextMLP(nn.Module):
  return down_proj


+ @use_experts_implementation
  class Qwen3OmniMoeTalkerTextExperts(nn.Module):
  """Collection of expert weights stored as 3D tensors."""

@@ -3020,9 +3051,9 @@ class Qwen3OmniMoeTalkerModel(Qwen3OmniMoePreTrainedModel):

  @auto_docstring
  class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrainedModel, GenerationMixin):
- _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
- _tp_plan = {"lm_head": "colwise_rep"}
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+ _tied_weights_keys = {"codec_head": "model.codec_embedding.weight"}
+ _tp_plan = {"codec_head": "colwise_rep"}
+ _pp_plan = {"codec_head": (["hidden_states"], ["logits"])}
  config_class = Qwen3OmniMoeTalkerConfig
  base_model_prefix = "talker"
  _no_split_modules = ["Qwen3OmniMoeTalkerCodePredictorModelForConditionalGeneration"]
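This corrects a tying map inherited from the text model: the talker emits codec tokens, so its head shares weights with model.codec_embedding rather than the text embed_tokens, and the TP/PP plans follow the head to its new name. Weight tying in miniature, with hypothetical sizes:

import torch
from torch import nn

codec_embedding = nn.Embedding(4096, 1024)      # codec vocab -> hidden
codec_head = nn.Linear(1024, 4096, bias=False)  # hidden -> codec logits
codec_head.weight = codec_embedding.weight      # one tensor serving two roles
assert codec_head.weight.data_ptr() == codec_embedding.weight.data_ptr()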
@@ -3213,18 +3244,31 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrain
  return model_kwargs

  def prepare_inputs_for_generation(
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ is_first_iteration=False,
+ **kwargs,
  ):
  hidden_states = kwargs.pop("hidden_states", None)
  inputs = super().prepare_inputs_for_generation(
- input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
+ input_ids,
+ past_key_values,
+ attention_mask,
+ inputs_embeds,
+ cache_position,
+ is_first_iteration=is_first_iteration,
+ **kwargs,
  )

  # Qwen3-Omni will prepare position ids in forward with deltas
  inputs["position_ids"] = None

  # TODO(raushan, gante): Refactor this part to a utility function
- if cache_position[0] != 0:
+ if not is_first_iteration and kwargs.get("use_cache", True):
  input_ids = input_ids[:, -1:]
  generation_step = kwargs.get("generation_step")
  trailing_text_hidden = kwargs.get("trailing_text_hidden")
@@ -3716,6 +3760,8 @@ class Qwen3OmniMoeCode2WavDecoderBlock(Qwen3OmniMoePreTrainedModel):

  self.block = nn.ModuleList(block)

+ self.post_init()
+
  def forward(self, hidden, **kwargs):
  for block in self.block:
  hidden = block(hidden)