transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py

@@ -35,7 +35,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -50,7 +50,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
+from ...utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs, maybe_autocast
 from .configuration_qwen3_omni_moe import (
     Qwen3OmniMoeAudioEncoderConfig,
     Qwen3OmniMoeCode2WavConfig,
@@ -716,6 +716,7 @@ class Qwen3OmniMoeAudioEncoder(Qwen3OmniMoePreTrainedModel):
         input_features,
         feature_lens=None,
         aftercnn_lens=None,
+        **kwargs,
     ):
         r"""
         feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
@@ -1290,7 +1291,7 @@ class Qwen3OmniMoeThinkerTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
             emb = torch.cat((freqs, freqs), dim=-1)
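The hunk above swaps `torch.autocast(..., enabled=False)` for a `maybe_autocast` helper imported from `...utils.generic`. As a minimal sketch of what such a wrapper could look like, assuming it simply degrades to a no-op context on device types `torch.autocast` rejects (the actual implementation in `transformers.utils.generic` may differ):

```python
from contextlib import nullcontext

import torch


def maybe_autocast(device_type=None, **kwargs):
    """Hypothetical sketch: torch.autocast when supported, else a no-op context."""
    try:
        return torch.autocast(device_type=device_type, **kwargs)
    except RuntimeError:
        # e.g. device types autocast cannot handle
        return nullcontext()
```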
@@ -1442,6 +1443,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3OmniMoeThinkerTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -1467,7 +1469,6 @@ class Qwen3OmniMoeThinkerTextAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3OmniMoeThinkerTextRMSNorm(
             self.head_dim, eps=config.rms_norm_eps
         )  # unlike olmo, only on the head dim!
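Hunk pairs like the two above recur throughout this file: the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment is replaced by a `@use_kernelized_func(apply_rotary_pos_emb)` class decorator. A rough sketch of the idea, assuming the decorator only records the reference implementation on the class so a hub kernel can be substituted later (the real decorator comes from `transformers.integrations` and its internals are not shown in this diff):

```python
def use_kernelized_func(func):
    """Hypothetical sketch of a class decorator that attaches a reference
    implementation (here: apply_rotary_pos_emb), replacing the per-instance
    assignment removed in the hunk above; the attribute name is assumed."""

    def decorator(cls):
        cls.rotary_fn = staticmethod(func)
        return cls

    return decorator
```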
@@ -2165,11 +2166,8 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
            audio_feature_lengths = None

        if attention_mask is not None and position_ids is None:
-            if (
-                cache_position is None
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
@@ -2184,7 +2182,7 @@
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length = input_ids.shape
-                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                position_ids = torch.arange(seq_length, device=input_ids.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                position_ids = position_ids.add(delta)
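The two hunks above replace `cache_position`-based prefill detection with the cache's own sequence length. A toy illustration of the new check, using `DynamicCache` from `transformers.cache_utils` (imported at the top of this file): an empty cache means prefill, so `get_rope_index` is recomputed; otherwise the cached `rope_deltas` are offset by the cache length.

```python
from transformers.cache_utils import DynamicCache

past_key_values = DynamicCache()
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()

# Prefill: nothing cached yet, so RoPE indices are recomputed from scratch.
assert past_key_values_length == 0
```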
@@ -2323,6 +2321,7 @@ class Qwen3OmniMoeRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3OmniMoeTalkerCodePredictorAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -2349,7 +2348,6 @@
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3OmniMoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Qwen3OmniMoeRMSNorm(
             self.head_dim, eps=config.rms_norm_eps
@@ -2518,7 +2516,7 @@ class Qwen3OmniMoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -3103,12 +3101,9 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrain
        if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
            generation_step = -1
            residual_codes = None
-        if attention_mask is not None:
-            if (
-                cache_position is None
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+        if position_ids is None:
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                position_ids, rope_deltas = self.get_rope_index(
                    talker_input_ids,
@@ -3123,7 +3118,7 @@
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length = input_ids.shape
-                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                position_ids = torch.arange(seq_length, device=input_ids.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                position_ids = position_ids.add(delta)
@@ -3224,7 +3219,10 @@
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
        )
-        # Decode stage
+
+        # Qwen3-Omni will prepare position ids in forward with deltas
+        inputs["position_ids"] = None
+
        # TODO(raushan, gante): Refactor this part to a utility function
        if cache_position[0] != 0:
            input_ids = input_ids[:, -1:]
@@ -3352,6 +3350,7 @@ class Qwen3OmniMoeConvNeXtBlock(nn.Module):
         return hidden_states


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3OmniMoeCode2WavAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -3378,7 +3377,6 @@
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = nn.Identity()
         self.k_norm = nn.Identity()
         self.sliding_window = config.sliding_window
@@ -3718,7 +3716,7 @@ class Qwen3OmniMoeCode2WavDecoderBlock(Qwen3OmniMoePreTrainedModel):

         self.block = nn.ModuleList(block)

-    def forward(self, hidden):
+    def forward(self, hidden, **kwargs):
         for block in self.block:
             hidden = block(hidden)
         return hidden
@@ -3760,7 +3758,7 @@ class Qwen3OmniMoeCode2Wav(Qwen3OmniMoePreTrainedModel):

         self.post_init()

-    def forward(self, codes):
+    def forward(self, codes, **kwargs):
         if codes.shape[1] != self.config.num_quantizers:
             raise ValueError(f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}")
         hidden = self.code_embedding(codes + self.code_offset).mean(1)
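The last two hunks widen `forward` signatures with `**kwargs` so these submodules tolerate extra keyword arguments (for example, kwargs that attention and kernel plumbing now thread uniformly through submodule calls). A minimal illustration of why, with made-up module names:

```python
import torch
import torch.nn as nn


class Strict(nn.Module):
    def forward(self, hidden):
        return hidden


class Tolerant(nn.Module):
    def forward(self, hidden, **kwargs):  # extra kwargs are accepted and ignored
        return hidden


Tolerant()(torch.zeros(1), output_attentions=False)   # fine
# Strict()(torch.zeros(1), output_attentions=False)  # would raise TypeError
```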
transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py

@@ -1205,6 +1205,7 @@ class Qwen3OmniMoeAudioEncoder(Qwen2_5OmniAudioEncoder):
         input_features,
         feature_lens=None,
         aftercnn_lens=None,
+        **kwargs,
     ):
         aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
         chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
@@ -1521,11 +1522,8 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen2_5OmniThinkerForCondition
            audio_feature_lengths = None

        if attention_mask is not None and position_ids is None:
-            if (
-                cache_position is None
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
@@ -1540,7 +1538,7 @@
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length = input_ids.shape
-                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                position_ids = torch.arange(seq_length, device=input_ids.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                position_ids = position_ids.add(delta)
@@ -1961,12 +1959,9 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3MoeForCausalLM):
        if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
            generation_step = -1
            residual_codes = None
-        if attention_mask is not None:
-            if (
-                cache_position is None
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+        if position_ids is None:
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                position_ids, rope_deltas = self.get_rope_index(
                    talker_input_ids,
@@ -1981,7 +1976,7 @@
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length = input_ids.shape
-                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                position_ids = torch.arange(seq_length, device=input_ids.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                position_ids = position_ids.add(delta)
@@ -2044,7 +2039,10 @@
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
        )
-        # Decode stage
+
+        # Qwen3-Omni will prepare position ids in forward with deltas
+        inputs["position_ids"] = None
+
        # TODO(raushan, gante): Refactor this part to a utility function
        if cache_position[0] != 0:
            input_ids = input_ids[:, -1:]
@@ -2339,7 +2337,7 @@ class Qwen3OmniMoeCode2WavDecoderBlock(Qwen3OmniMoePreTrainedModel):

         self.block = nn.ModuleList(block)

-    def forward(self, hidden):
+    def forward(self, hidden, **kwargs):
         for block in self.block:
             hidden = block(hidden)
         return hidden
@@ -2381,7 +2379,7 @@ class Qwen3OmniMoeCode2Wav(Qwen3OmniMoePreTrainedModel):

         self.post_init()

-    def forward(self, codes):
+    def forward(self, codes, **kwargs):
         if codes.shape[1] != self.config.num_quantizers:
             raise ValueError(f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}")
         hidden = self.code_embedding(codes + self.code_offset).mean(1)
transformers/models/qwen3_vl/modeling_qwen3_vl.py

@@ -30,7 +30,7 @@ import torch.nn.functional as F
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -38,8 +38,8 @@ from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
-from ...utils.generic import check_model_inputs
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig


@@ -337,7 +337,7 @@ class Qwen3VLTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
             emb = torch.cat((freqs, freqs), dim=-1)
@@ -413,6 +413,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3VLTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -439,7 +440,6 @@
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Qwen3VLTextRMSNorm(
             self.head_dim, eps=config.rms_norm_eps
@@ -1201,44 +1201,19 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
            deepstack_visual_embeds = deepstack_video_embeds

        if position_ids is None:
-            attention_mask_tensor = (
-                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
-            )
-            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
-                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
-                # Only apply conversion for floating point tensors (inverted masks)
-                if attention_mask_tensor.dtype.is_floating_point:
-                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
-                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()
-
-            # Calculate RoPE index once per generation in the pre-fill stage only.
-            # When compiling, we can't check tensor values thus we check only input length
-            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
-            # models currently cannot do asssisted decoding
-            prefill_compiled_stage = is_torchdynamo_compiling() and (
-                (input_ids is not None and input_ids.shape[1] != 1)
-                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
-            )
-            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
-                (cache_position is not None and cache_position[0] == 0)
-                or (past_key_values is None or past_key_values.get_seq_length() == 0)
-            )
-            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if self.rope_deltas is None or past_key_values_length == 0:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
-                    attention_mask=attention_mask_tensor,
+                    attention_mask=attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = (
-                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
-                    if cache_position is not None
-                    else 0
-                )
+                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
@@ -1322,7 +1297,7 @@ class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
     def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
         return self.model.get_image_features(pixel_values, image_grid_thw)

-    @check_model_inputs
+    @can_return_tuple
     def forward(
         self,
         input_ids: torch.LongTensor = None,
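Here `@check_model_inputs` gives way to `@can_return_tuple`. As a rough mental model only (an assumption, not the actual `transformers.utils` code), such a decorator lets a `ModelOutput`-returning `forward` honor `return_dict=False` by converting the output to a plain tuple:

```python
import functools


def can_return_tuple(forward):
    """Hypothetical sketch: unwrap the ModelOutput when return_dict=False."""

    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=True, **kwargs):
        output = forward(self, *args, **kwargs)
        return output if return_dict else output.to_tuple()

    return wrapper
```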
@@ -1414,6 +1389,8 @@
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )

@@ -1449,8 +1426,33 @@
            **kwargs,
        )

-        # Qwen3VL position_ids are prepareed with rope_deltas in forward
-        model_inputs["position_ids"] = None
+        # Qwen3VL position_ids are prepared with rope_deltas
+        if position_ids is None:
+            # Calculate RoPE index once per generation in the pre-fill stage only.
+            # When compiling, we can't check tensor values thus we check only input length
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do asssisted decoding
+            if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
+                vision_positions, rope_deltas = self.model.get_rope_index(
+                    model_inputs.get("input_ids", None),
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    attention_mask=attention_mask,
+                )
+                self.model.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            elif "position_ids" in model_inputs:
+                batch_size, seq_length = model_inputs["position_ids"].shape
+                device = model_inputs["position_ids"].device
+                position_ids = torch.arange(seq_length, device=device)
+                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                delta = cache_position[0] + self.model.rope_deltas
+                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                vision_positions = position_ids + delta.expand_as(position_ids)
+
+            # Concatenate "text + vision" positions into [4, bs, seq-len]
+            text_positions = model_inputs["position_ids"][None, ...]
+            model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)

        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None
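The block added above stacks one row of text positions on top of three mRoPE vision rows (temporal/height/width). A self-contained sketch of the resulting `[4, batch, seq_len]` layout (shapes illustrative, tensors made up):

```python
import torch

batch_size, seq_length = 2, 5
# model_inputs["position_ids"] analogue: plain text positions, [bs, seq]
text_positions = torch.arange(seq_length).expand(batch_size, -1)
# get_rope_index analogue: temporal/height/width positions, [3, bs, seq]
vision_positions = torch.arange(seq_length).view(1, 1, -1).expand(3, batch_size, -1)

combined = torch.cat([text_positions[None, ...], vision_positions], dim=0)
print(combined.shape)  # torch.Size([4, 2, 5])
```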
transformers/models/qwen3_vl/modular_qwen3_vl.py

@@ -34,8 +34,8 @@ from ...modeling_rope_utils import RopeParameters, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import ProcessingKwargs, Unpack
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import auto_docstring, is_torchdynamo_compiling, logging
-from ...utils.generic import check_model_inputs
+from ...utils import auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ...video_utils import VideoInput
 from ..llama.modeling_llama import LlamaRotaryEmbedding
 from ..qwen2_5_vl.modeling_qwen2_5_vl import (
@@ -389,7 +389,7 @@ class Qwen3VLTextRotaryEmbedding(LlamaRotaryEmbedding):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
             emb = torch.cat((freqs, freqs), dim=-1)
@@ -1033,44 +1033,19 @@ class Qwen3VLModel(Qwen2_5_VLModel):
            deepstack_visual_embeds = deepstack_video_embeds

        if position_ids is None:
-            attention_mask_tensor = (
-                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
-            )
-            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
-                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
-                # Only apply conversion for floating point tensors (inverted masks)
-                if attention_mask_tensor.dtype.is_floating_point:
-                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
-                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()
-
-            # Calculate RoPE index once per generation in the pre-fill stage only.
-            # When compiling, we can't check tensor values thus we check only input length
-            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
-            # models currently cannot do asssisted decoding
-            prefill_compiled_stage = is_torchdynamo_compiling() and (
-                (input_ids is not None and input_ids.shape[1] != 1)
-                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
-            )
-            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
-                (cache_position is not None and cache_position[0] == 0)
-                or (past_key_values is None or past_key_values.get_seq_length() == 0)
-            )
-            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if self.rope_deltas is None or past_key_values_length == 0:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
-                    attention_mask=attention_mask_tensor,
+                    attention_mask=attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = (
-                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
-                    if cache_position is not None
-                    else 0
-                )
+                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
@@ -1105,7 +1080,7 @@ class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
     config: Qwen3VLConfig
     _checkpoint_conversion_mapping = {}

-    @check_model_inputs
+    @can_return_tuple
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1197,6 +1172,8 @@
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )

@@ -1232,8 +1209,33 @@
            **kwargs,
        )

-        # Qwen3VL position_ids are prepareed with rope_deltas in forward
-        model_inputs["position_ids"] = None
+        # Qwen3VL position_ids are prepared with rope_deltas
+        if position_ids is None:
+            # Calculate RoPE index once per generation in the pre-fill stage only.
+            # When compiling, we can't check tensor values thus we check only input length
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do asssisted decoding
+            if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
+                vision_positions, rope_deltas = self.model.get_rope_index(
+                    model_inputs.get("input_ids", None),
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    attention_mask=attention_mask,
+                )
+                self.model.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            elif "position_ids" in model_inputs:
+                batch_size, seq_length = model_inputs["position_ids"].shape
+                device = model_inputs["position_ids"].device
+                position_ids = torch.arange(seq_length, device=device)
+                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                delta = cache_position[0] + self.model.rope_deltas
+                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                vision_positions = position_ids + delta.expand_as(position_ids)
+
+            # Concatenate "text + vision" positions into [4, bs, seq-len]
+            text_positions = model_inputs["position_ids"][None, ...]
+            model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)

        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None