transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,12 @@ from ... import initialization as init
31
31
  from ...activations import ACT2FN
32
32
  from ...cache_utils import Cache, DynamicCache
33
33
  from ...generation import GenerationMixin
34
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
34
+ from ...integrations import (
35
+ use_experts_implementation,
36
+ use_kernel_forward_from_hub,
37
+ use_kernel_func_from_hub,
38
+ use_kernelized_func,
39
+ )
35
40
  from ...masking_utils import create_causal_mask
36
41
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
37
42
  from ...modeling_layers import GradientCheckpointingLayer
@@ -39,8 +44,8 @@ from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
39
44
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
40
45
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
41
46
  from ...processing_utils import Unpack
42
- from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
43
- from ...utils.generic import OutputRecorder, check_model_inputs
47
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
48
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
44
49
  from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig
45
50
 
46
51
 
@@ -65,92 +70,77 @@ class Qwen3VLMoeTextRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
+@use_experts_implementation
 class Qwen3VLMoeTextExperts(nn.Module):
+    """Collection of expert weights stored as 3D tensors."""
+
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.num_experts
-        self.intermediate_size = config.moe_intermediate_size
-        self.hidden_size = config.hidden_size
-        self.expert_dim = self.intermediate_size
-        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
-        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
+        self.hidden_dim = config.hidden_size
+        self.intermediate_dim = config.moe_intermediate_size
+        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
+        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(
-        self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
-        """
-        When training it is more efficient to just loop over the experts and compute the output for each expert
-        as otherwise the memory would explode.
+        final_hidden_states = torch.zeros_like(hidden_states)
+        with torch.no_grad():
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
+            expert_mask = expert_mask.permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+        for expert_idx in expert_hit:
+            expert_idx = expert_idx[0]
+            if expert_idx == self.num_experts:
+                continue
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+            current_state = hidden_states[token_idx]
+            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
+            current_hidden_states = self.act_fn(gate) * up
+            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
+
+        return final_hidden_states
 
-        For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
 
-        Args:
-            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
-            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
-            router_indices (torch.Tensor): (batch_size * token_num, top_k)
-        Returns:
-            torch.Tensor
-        """
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
-        if self.training:
-            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-            with torch.no_grad():
-                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
-                expert_mask = expert_mask.permute(2, 1, 0)
-                # we sum on the top_k and on the sequence length to get which experts
-                # are hit this time around
-                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-            for expert_idx in expert_hit[:]:
-                with torch.no_grad():
-                    _, token_idx = torch.where(expert_mask[expert_idx[0]])
-                current_state = hidden_states[token_idx]
-                gate_up = current_state @ self.gate_up_proj[expert_idx]
-                gate, up = gate_up.chunk(2, dim=-1)
-                gated_output = up * self.act_fn(gate)
-                out = gated_output @ self.down_proj[expert_idx]
-                weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
-                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-            next_states = next_states.view(batch_size, -1, self.hidden_size)
-        else:
-            hidden_states = hidden_states.repeat(self.num_experts, 1)
-            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
-            gate_up = torch.bmm(hidden_states, self.gate_up_proj)
-            gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
-            next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
-            next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
-            next_states = (
-                next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
-            )
-            next_states = next_states.sum(dim=0)
-        return next_states
+class Qwen3VLMoeTextTopKRouter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_experts
+        self.hidden_dim = config.hidden_size
+        self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts)
+        router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
+        router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
+        router_top_value = router_top_value.to(router_logits.dtype)
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices
 
 
 class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: Qwen3VLMoeTextConfig):
         super().__init__()
-        self.hidden_size = config.hidden_size
-        self.num_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
-        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
         self.experts = Qwen3VLMoeTextExperts(config)
+        self.gate = Qwen3VLMoeTextTopKRouter(config)
 
-        # since all the models use norm_topk_prob, we don't need to have a extra check for it
-        # self.norm_topk_prob = config.norm_topk_prob
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        router_logits = self.gate(hidden_states)
-        routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
-        routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
-        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
-        routing_weights = routing_weights.to(router_logits.dtype)
-        router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
-        hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
-        routed_out = self.experts(hidden_states, router_weights, router_indices)
-        return routed_out
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
+        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
+        final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
+        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
 
 
 def rotate_half(x):
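The rewritten `Qwen3VLMoeTextExperts.forward` replaces the old training/inference split with a single dispatch loop: a one-hot mask over `top_k_index` tells each expert which tokens it serves, and `index_add_` scatters the weighted expert outputs back into place. A minimal self-contained sketch of that routing arithmetic, with toy shapes and an identity stand-in for the expert MLP (names below are illustrative, not the transformers API):

    import torch
    import torch.nn.functional as F

    num_tokens, hidden_dim, num_experts, top_k = 8, 16, 4, 2
    hidden_states = torch.randn(num_tokens, hidden_dim)
    router_logits = torch.randn(num_tokens, num_experts)

    # Top-k routing as in the new router: softmax over all experts, keep the
    # k best per token, renormalize the kept weights to sum to 1.
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
    top_k_weights, top_k_index = torch.topk(probs, top_k, dim=-1)
    top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

    # Dispatch as in the new experts module: expert_mask[e, k, t] == 1 when
    # token t picked expert e as its k-th choice.
    expert_mask = F.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
    final_hidden_states = torch.zeros_like(hidden_states)
    for expert_idx in range(num_experts):
        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
        if token_idx.numel() == 0:
            continue  # this expert received no tokens this step
        expert_out = hidden_states[token_idx]  # identity stand-in for the expert MLP
        weighted = expert_out * top_k_weights[token_idx, top_k_pos, None].to(expert_out.dtype)
        final_hidden_states.index_add_(0, token_idx, weighted)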
@@ -226,6 +216,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3VLMoeTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
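`use_kernelized_func` attaches the pure-PyTorch `apply_rotary_pos_emb` to the attention class at decoration time, which is why the next hunk drops the `self.rotary_fn` assignment from `__init__`. A hedged sketch of the general decorator shape; the attribute name and dispatch logic here are assumptions for illustration, not the transformers implementation:

    def use_kernelized_func(default_func):
        # Assumed pattern: store the default implementation on the class so a
        # faster hub kernel can later be swapped in under the same attribute.
        def decorator(cls):
            cls.rotary_fn = staticmethod(default_func)
            return cls
        return decorator

    def apply_rotary_pos_emb_ref(q, k, cos, sin):
        return q * cos, k * sin  # placeholder math, not the real rotation

    @use_kernelized_func(apply_rotary_pos_emb_ref)
    class Attention:
        pass

    q_rot, k_rot = Attention.rotary_fn(1.0, 1.0, 0.5, 0.5)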
@@ -252,7 +243,6 @@ class Qwen3VLMoeTextAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3VLMoeTextRMSNorm(
             self.head_dim, eps=config.rms_norm_eps
         )  # unlike olmo, only on the head dim!
@@ -368,27 +358,6 @@ class Qwen3VLMoeTextDecoderLayer(GradientCheckpointingLayer):
         return hidden_states
 
 
-class Qwen3VLMoeTextTopKRouter(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.top_k = config.num_experts_per_tok
-        self.num_experts = config.num_experts
-        self.norm_topk_prob = config.norm_topk_prob
-        self.hidden_dim = config.hidden_size
-        self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts)
-        router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
-        if self.norm_topk_prob:
-            router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
-        router_top_value = router_top_value.to(router_logits.dtype)
-        router_scores = router_top_value
-        return router_logits, router_scores, router_indices
-
-
 @auto_docstring
 class Qwen3VLMoePreTrainedModel(PreTrainedModel):
     config: Qwen3VLMoeConfig
@@ -399,7 +368,9 @@ class Qwen3VLMoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(Qwen3VLMoeTextTopKRouter, layer_name="mlp.gate", index=0),
@@ -418,6 +389,27 @@ class Qwen3VLMoePreTrainedModel(PreTrainedModel):
         if isinstance(module, Qwen3VLMoeTextExperts):
             init.normal_(module.gate_up_proj, mean=0.0, std=std)
             init.normal_(module.down_proj, mean=0.0, std=std)
+        elif isinstance(module, Qwen3VLMoeTextTopKRouter):
+            init.normal_(module.weight, mean=0.0, std=std)
+        elif isinstance(module, Qwen3VLMoeVisionRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
+
+class Qwen3VLMoeVisionRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs
 
 
 class Qwen3VLMoeVisionMLP(nn.Module):
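Because `inv_freq` is a non-persistent buffer, `_init_weights` above must be able to rebuild it from `dim` and `theta` alone, which is why the relocated class now stores both. The table itself is the standard rotary inverse-frequency formula:

    import torch

    dim, theta, seqlen = 32, 10000.0, 6

    # inv_freq[i] = theta^(-2i/dim) for i = 0..dim/2-1, as in __init__ and _init_weights.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

    # forward(seqlen): row t holds the rotation angles t * inv_freq for position t.
    seq = torch.arange(seqlen, dtype=inv_freq.dtype)
    freqs = torch.outer(seq, inv_freq)
    print(freqs.shape)  # torch.Size([6, 16])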
@@ -453,20 +445,6 @@ class Qwen3VLMoeVisionPatchEmbed(nn.Module):
         return hidden_states
 
 
-class Qwen3VLMoeVisionRotaryEmbedding(nn.Module):
-    inv_freq: torch.Tensor  # fix linting for `register_buffer`
-
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.outer(seq, self.inv_freq)
-        return freqs
-
-
 class Qwen3VLMoeVisionPatchMerger(nn.Module):
     def __init__(self, config: Qwen3VLMoeVisionConfig, use_postshuffle_norm=False) -> None:
         super().__init__()
@@ -534,8 +512,8 @@ class Qwen3VLMoeVisionAttention(nn.Module):
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-        if self.config._attn_implementation == "flash_attention_2":
-            # Flash Attention 2: Use cu_seqlens for variable length attention
+        if "flash" in self.config._attn_implementation:
+            # Flash Attention: Use cu_seqlens for variable length attention
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
             attn_output, _ = attention_interface(
                 self,
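The check now matches any flash-family backend (for example `flash_attention_2` or `flash_attention_3`) rather than only `flash_attention_2`. The variable-length path feeds cumulative sequence lengths instead of a padded attention mask; for example:

    import torch

    # Three packed sequences of lengths 4, 2 and 6 share one tensor; cu_seqlens
    # marks their boundaries, and max_seqlen is derived exactly as in the diff.
    cu_seqlens = torch.tensor([0, 4, 6, 12])
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
    print(int(max_seqlen))  # 6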
@@ -646,6 +624,8 @@ class Qwen3VLMoeVisionModel(Qwen3VLMoePreTrainedModel):
 
         self.gradient_checkpointing = False
 
+        self.post_init()
+
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         merge_size = self.spatial_merge_size
 
@@ -815,7 +795,7 @@ class Qwen3VLMoeTextRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
         self.mrope_section = config.rope_parameters.get("mrope_section", [24, 20, 20])
 
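Keeping `original_inv_freq` as a registered (but non-persistent) buffer instead of a plain attribute means it follows the module across `.to()`/`.cuda()` moves while staying out of the checkpoint, and the `.clone()` means later in-place updates to `inv_freq` (dynamic RoPE) presumably cannot mutate the pristine copy through aliasing. A small demonstration of the buffer semantics:

    import torch
    import torch.nn as nn

    class RotaryBuffers(nn.Module):
        def __init__(self):
            super().__init__()
            inv_freq = torch.rand(8)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            # Cloned so updates to inv_freq cannot alias into the saved copy.
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    m = RotaryBuffers()
    print("original_inv_freq" in m.state_dict())  # False: non-persistent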
@@ -860,7 +840,7 @@ class Qwen3VLMoeTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
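`maybe_autocast` replaces the bare `torch.autocast(..., enabled=False)`, presumably so device types without autocast support do not error while float32 is still forced where it matters. A hedged sketch of one plausible shape for such a wrapper; the real helper lives in `transformers.utils.generic`, and this is not its implementation:

    import contextlib
    import torch

    def maybe_autocast(device_type: str, **kwargs):
        # Assumed behavior: use torch.autocast on backends known to support
        # it, otherwise degrade to a no-op context instead of raising.
        if device_type in ("cuda", "cpu"):
            return torch.autocast(device_type=device_type, **kwargs)
        return contextlib.nullcontext()

    with maybe_autocast(device_type="cpu", enabled=False):  # force float32 math
        freqs = torch.randn(2, 3) @ torch.randn(3, 4)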
@@ -1358,44 +1338,19 @@ class Qwen3VLMoeModel(Qwen3VLMoePreTrainedModel):
             deepstack_visual_embeds = deepstack_video_embeds
 
         if position_ids is None:
-            attention_mask_tensor = (
-                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
-            )
-            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
-                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
-                # Only apply conversion for floating point tensors (inverted masks)
-                if attention_mask_tensor.dtype.is_floating_point:
-                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
-                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()
-
-            # Calculate RoPE index once per generation in the pre-fill stage only.
-            # When compiling, we can't check tensor values thus we check only input length
-            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
-            # models currently cannot do asssisted decoding
-            prefill_compiled_stage = is_torchdynamo_compiling() and (
-                (input_ids is not None and input_ids.shape[1] != 1)
-                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
-            )
-            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
-                (cache_position is not None and cache_position[0] == 0)
-                or (past_key_values is None or past_key_values.get_seq_length() == 0)
-            )
-            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if self.rope_deltas is None or past_key_values_length == 0:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
                     video_grid_thw,
-                    attention_mask=attention_mask_tensor,
+                    attention_mask=attention_mask,
                 )
                 self.rope_deltas = rope_deltas
             # then use the prev pre-calculated rope-deltas to get the correct position ids
             else:
                 batch_size, seq_length, _ = inputs_embeds.shape
-                delta = (
-                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
-                    if cache_position is not None
-                    else 0
-                )
+                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                 position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 if cache_position is not None:  # otherwise `deltas` is an int `0`
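The prefill test collapses to "the KV cache is empty", and decode steps reuse the `rope_deltas` computed at prefill: each new token's position is simply the cache length plus the per-sequence delta. Worked with toy numbers:

    import torch

    batch_size, seq_length = 2, 1           # one new token per sequence (decode)
    past_key_values_length = 7              # tokens already in the KV cache
    rope_deltas = torch.tensor([[3], [5]])  # per-sequence offsets from prefill

    delta = past_key_values_length + rope_deltas                     # (bs, 1)
    position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
    position_ids = position_ids + delta
    print(position_ids)  # tensor([[10], [12]])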
@@ -1532,7 +1487,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin):
     def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
         return self.model.get_image_features(pixel_values, image_grid_thw)
 
-    @check_model_inputs
+    @can_return_tuple
     def forward(
         self,
         input_ids: torch.LongTensor = None,
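`@can_return_tuple` takes over from `@check_model_inputs` here; its contract is that `forward` always builds a `ModelOutput` and the decorator downgrades it to a plain tuple when the caller asks. A simplified behavioral sketch (`transformers.utils.can_return_tuple` is the real decorator; this is not its implementation):

    import functools

    def can_return_tuple_sketch(forward):
        # Simplified contract: forward returns a dict-like output; the wrapper
        # converts it to a plain tuple when return_dict=False is passed.
        @functools.wraps(forward)
        def wrapper(self, *args, return_dict: bool = True, **kwargs):
            output = forward(self, *args, **kwargs)
            return output if return_dict else tuple(output.values())
        return wrapper

    class TinyModel:
        @can_return_tuple_sketch
        def forward(self):
            return {"logits": [0.1, 0.9], "past_key_values": None}

    print(TinyModel().forward(return_dict=False))  # ([0.1, 0.9], None)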
@@ -1642,6 +1597,8 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin):
             aux_loss=aux_loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )
 
@@ -1658,6 +1615,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin):
         pixel_values_videos=None,
         image_grid_thw=None,
         video_grid_thw=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1674,13 +1632,39 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin):
             image_grid_thw=image_grid_thw,
             video_grid_thw=video_grid_thw,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        # Qwen3VLMoe position_ids are prepareed with rope_deltas in forward
-        model_inputs["position_ids"] = None
-
-        if cache_position[0] != 0:
+        # Qwen3VLMoe position_ids are prepared with rope_deltas
+        if position_ids is None:
+            # Calculate RoPE index once per generation in the pre-fill stage only.
+            # When compiling, we can't check tensor values thus we check only input length
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do asssisted decoding
+            if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
+                vision_positions, rope_deltas = self.model.get_rope_index(
+                    model_inputs.get("input_ids", None),
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    attention_mask=attention_mask,
+                )
+                self.model.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            elif "position_ids" in model_inputs:
+                batch_size, seq_length = model_inputs["position_ids"].shape
+                device = model_inputs["position_ids"].device
+                position_ids = torch.arange(seq_length, device=device)
+                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                delta = cache_position[0] + self.model.rope_deltas
+                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                vision_positions = position_ids + delta.expand_as(position_ids)
+
+            # Concatenate "text + vision" positions into [4, bs, seq-len]
+            text_positions = model_inputs["position_ids"][None, ...]
+            model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
+
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None
 
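The concatenation at the end of the new block stacks one text axis on top of the three vision mRoPE axes (temporal, height, width), so downstream code receives a single `[4, batch, seq]` tensor. Shape check:

    import torch

    batch_size, seq_length = 2, 5
    text_positions = torch.arange(seq_length).expand(1, batch_size, seq_length)
    vision_positions = torch.arange(seq_length).expand(3, batch_size, seq_length)

    # Axis 0 is [text, temporal, height, width] -> shape (4, bs, seq_len)
    position_ids = torch.cat([text_positions, vision_positions], dim=0)
    print(position_ids.shape)  # torch.Size([4, 2, 5])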
transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py
@@ -18,19 +18,21 @@ from typing import Optional, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ... import initialization as init
-from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...configuration_utils import PreTrainedConfig
 from ...modeling_rope_utils import RopeParameters
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, logging
+from ...utils import TransformersKwargs, can_return_tuple, logging
 from ..qwen3_moe.modeling_qwen3_moe import (
     Qwen3MoeDecoderLayer,
+    Qwen3MoeExperts,
     Qwen3MoePreTrainedModel,
     Qwen3MoeRMSNorm,
+    Qwen3MoeSparseMoeBlock,
     load_balancing_loss_func,
 )
 from ..qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
@@ -41,6 +43,7 @@ from ..qwen3_vl.modeling_qwen3_vl import (
     Qwen3VLTextAttention,
     Qwen3VLTextModel,
     Qwen3VLVisionModel,
+    Qwen3VLVisionRotaryEmbedding,
 )
 
 
@@ -257,92 +260,31 @@ class Qwen3VLMoeTextRMSNorm(Qwen3MoeRMSNorm):
     pass
 
 
-class Qwen3VLMoeTextExperts(nn.Module):
+class Qwen3VLMoeTextExperts(Qwen3MoeExperts):
+    pass
+
+
+class Qwen3VLMoeTextTopKRouter(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.top_k = config.num_experts_per_tok
         self.num_experts = config.num_experts
-        self.intermediate_size = config.moe_intermediate_size
-        self.hidden_size = config.hidden_size
-        self.expert_dim = self.intermediate_size
-        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
-        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
-        self.act_fn = ACT2FN[config.hidden_act]
+        self.hidden_dim = config.hidden_size
+        self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
 
-    def forward(
-        self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        When training it is more efficient to just loop over the experts and compute the output for each expert
-        as otherwise the memory would explode.
-
-        For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
-
-        Args:
-            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
-            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
-            router_indices (torch.Tensor): (batch_size * token_num, top_k)
-        Returns:
-            torch.Tensor
-        """
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
-        if self.training:
-            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-            with torch.no_grad():
-                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
-                expert_mask = expert_mask.permute(2, 1, 0)
-                # we sum on the top_k and on the sequence length to get which experts
-                # are hit this time around
-                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-            for expert_idx in expert_hit[:]:
-                with torch.no_grad():
-                    _, token_idx = torch.where(expert_mask[expert_idx[0]])
-                current_state = hidden_states[token_idx]
-                gate_up = current_state @ self.gate_up_proj[expert_idx]
-                gate, up = gate_up.chunk(2, dim=-1)
-                gated_output = up * self.act_fn(gate)
-                out = gated_output @ self.down_proj[expert_idx]
-                weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
-                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-            next_states = next_states.view(batch_size, -1, self.hidden_size)
-        else:
-            hidden_states = hidden_states.repeat(self.num_experts, 1)
-            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
-            gate_up = torch.bmm(hidden_states, self.gate_up_proj)
-            gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
-            next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
-            next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
-            next_states = (
-                next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
-            )
-            next_states = next_states.sum(dim=0)
-        return next_states
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts)
+        router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
+        router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
+        router_top_value = router_top_value.to(router_logits.dtype)
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices
 
 
-class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-        self.num_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
-        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
-        self.experts = Qwen3VLMoeTextExperts(config)
-
-        # since all the models use norm_topk_prob, we don't need to have a extra check for it
-        # self.norm_topk_prob = config.norm_topk_prob
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        router_logits = self.gate(hidden_states)
-        routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
-        routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
-        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
-        routing_weights = routing_weights.to(router_logits.dtype)
-        router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
-        hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
-        routed_out = self.experts(hidden_states, router_weights, router_indices)
-        return routed_out
+class Qwen3VLMoeTextSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+    pass
 
 
 class Qwen3VLMoeTextAttention(Qwen3VLTextAttention):
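In the modular file the heavy lifting moves to shared base classes: `Qwen3VLMoeTextExperts` and `Qwen3VLMoeTextSparseMoeBlock` become `pass` subclasses of the `qwen3_moe` implementations, and the code generator later inlines the inherited bodies into `modeling_qwen3_vl_moe.py`. The pattern in miniature, with stand-in names rather than the real classes:

    import torch
    import torch.nn as nn

    class SharedExperts(nn.Module):
        # Stand-in for a base implementation defined by another model family.
        def __init__(self, num_experts: int, hidden_dim: int):
            super().__init__()
            self.weight = nn.Parameter(torch.zeros(num_experts, hidden_dim))

    class MyModelExperts(SharedExperts):
        # The modular pattern: only the class name changes; a generation step
        # expands this into a full standalone class in the modeling file.
        pass

    print(MyModelExperts(4, 16).weight.shape)  # torch.Size([4, 16])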
@@ -368,6 +310,15 @@ class Qwen3VLMoePreTrainedModel(Qwen3MoePreTrainedModel):
         if isinstance(module, Qwen3VLMoeTextExperts):
             init.normal_(module.gate_up_proj, mean=0.0, std=std)
             init.normal_(module.down_proj, mean=0.0, std=std)
+        elif isinstance(module, Qwen3VLMoeTextTopKRouter):
+            init.normal_(module.weight, mean=0.0, std=std)
+        elif isinstance(module, Qwen3VLMoeVisionRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
+
+class Qwen3VLMoeVisionRotaryEmbedding(Qwen3VLVisionRotaryEmbedding):
+    pass
 
 
 class Qwen3VLMoeVisionModel(Qwen3VLVisionModel):
@@ -387,6 +338,7 @@ class Qwen3VLMoeModel(Qwen3VLModel):
 
 
 class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+    @can_return_tuple
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -496,6 +448,8 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
             aux_loss=aux_loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )
 
transformers/models/rag/configuration_rag.py
@@ -70,9 +70,6 @@ RAG_CONFIG_DOC = r"""
            `context_attention_mask` are returned. See returned tensors for more detail.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
-       forced_eos_token_id (`int`, *optional*):
-           The id of the token to force as the last generated token when `max_length` is reached. Usually set to
-           `eos_token_id`.
 """
 
 
@@ -109,7 +106,6 @@ class RagConfig(PreTrainedConfig):
         do_marginalize=False,
         output_retrieved=False,
         use_cache=True,
-        forced_eos_token_id=None,
         dataset_revision=None,
         **kwargs,
     ):
@@ -118,7 +114,6 @@ class RagConfig(PreTrainedConfig):
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
             decoder_start_token_id=decoder_start_token_id,
-            forced_eos_token_id=forced_eos_token_id,
             is_encoder_decoder=is_encoder_decoder,
             prefix=prefix,
             vocab_size=vocab_size,
@@ -166,9 +161,6 @@ class RagConfig(PreTrainedConfig):
 
         self.use_cache = use_cache
 
-        if forced_eos_token_id is None:
-            self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)
-
     @classmethod
     def from_question_encoder_generator_configs(
         cls, question_encoder_config: PreTrainedConfig, generator_config: PreTrainedConfig, **kwargs
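`forced_eos_token_id` disappears from `RagConfig` entirely; in transformers v5, generation options of this kind are carried by `GenerationConfig` rather than the model config. For example:

    from transformers import GenerationConfig

    # Token id 2 is illustrative; use the generator model's actual eos id.
    generation_config = GenerationConfig(max_length=64, forced_eos_token_id=2)
    print(generation_config.forced_eos_token_id)  # 2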