transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,7 @@ states before downsampling, which is different from the default Swin Transformer
19
19
  import collections.abc
20
20
  import math
21
21
  from dataclasses import dataclass
22
- from typing import Optional
22
+ from typing import Optional, Union
23
23
 
24
24
  import torch
25
25
  from torch import Tensor, nn
@@ -331,18 +331,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
331
331
  torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
332
332
  )
333
333
 
334
- # get pair-wise relative position index for each token inside the window
335
- coords_h = torch.arange(self.window_size[0])
336
- coords_w = torch.arange(self.window_size[1])
337
- coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
338
- coords_flatten = torch.flatten(coords, 1)
339
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
340
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
341
- relative_coords[:, :, 0] += self.window_size[0] - 1
342
- relative_coords[:, :, 1] += self.window_size[1] - 1
343
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
344
- relative_position_index = relative_coords.sum(-1)
345
- self.register_buffer("relative_position_index", relative_position_index)
334
+ self.register_buffer("relative_position_index", self.create_relative_position_index())
346
335
 
347
336
  self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
348
337
  self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
@@ -401,6 +390,20 @@ class MaskFormerSwinSelfAttention(nn.Module):
401
390
 
402
391
  return outputs
403
392
 
393
+ def create_relative_position_index(self):
394
+ # get pair-wise relative position index for each token inside the window
395
+ coords_h = torch.arange(self.window_size[0])
396
+ coords_w = torch.arange(self.window_size[1])
397
+ coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
398
+ coords_flatten = torch.flatten(coords, 1)
399
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
400
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
401
+ relative_coords[:, :, 0] += self.window_size[0] - 1
402
+ relative_coords[:, :, 1] += self.window_size[1] - 1
403
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
404
+ relative_position_index = relative_coords.sum(-1)
405
+ return relative_position_index
406
+
404
407
 
405
408
  # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
406
409
  class MaskFormerSwinSelfOutput(nn.Module):
@@ -656,7 +659,7 @@ class MaskFormerSwinEncoder(nn.Module):
656
659
  output_attentions=False,
657
660
  output_hidden_states=False,
658
661
  return_dict=True,
659
- ):
662
+ ) -> Union[tuple, MaskFormerSwinBaseModelOutput]:
660
663
  all_hidden_states = () if output_hidden_states else None
661
664
  all_input_dimensions = ()
662
665
  all_self_attentions = () if output_attentions else None
@@ -711,6 +714,7 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
711
714
  init.zeros_(module.position_embeddings)
712
715
  elif isinstance(module, MaskFormerSwinSelfAttention):
713
716
  init.zeros_(module.relative_position_bias_table)
717
+ init.copy_(module.relative_position_index, module.create_relative_position_index())
714
718
 
715
719
 
716
720
  class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
@@ -738,7 +742,8 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
738
742
  output_hidden_states=None,
739
743
  interpolate_pos_encoding=False,
740
744
  return_dict=None,
741
- ):
745
+ **kwargs,
746
+ ) -> Union[tuple, MaskFormerSwinModelOutputWithPooling]:
742
747
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
743
748
  output_hidden_states = (
744
749
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -815,6 +820,7 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
815
820
  output_hidden_states: Optional[bool] = None,
816
821
  output_attentions: Optional[bool] = None,
817
822
  return_dict: Optional[bool] = None,
823
+ **kwargs,
818
824
  ) -> BackboneOutput:
819
825
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
820
826
  output_hidden_states = (
@@ -147,6 +147,7 @@ class MBartConfig(PreTrainedConfig):
147
147
  self.use_cache = use_cache
148
148
  self.num_hidden_layers = encoder_layers
149
149
  self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
150
+
150
151
  super().__init__(
151
152
  pad_token_id=pad_token_id,
152
153
  bos_token_id=bos_token_id,
@@ -22,6 +22,7 @@ import torch
22
22
  from torch import nn
23
23
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
24
 
25
+ from ... import initialization as init
25
26
  from ...activations import ACT2FN
26
27
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
27
28
  from ...generation import GenerationMixin
@@ -478,6 +479,11 @@ class MBartPreTrainedModel(PreTrainedModel):
478
479
  _supports_flex_attn = True
479
480
  _can_compile_fullgraph = True
480
481
 
482
+ def _init_weights(self, module):
483
+ super()._init_weights(module)
484
+ if isinstance(module, MBartForConditionalGeneration):
485
+ init.zeros_(module.final_logits_bias)
486
+
481
487
  @property
482
488
  def dummy_inputs(self):
483
489
  pad_token = self.config.pad_token_id
@@ -540,6 +546,7 @@ class MBartEncoder(MBartPreTrainedModel):
540
546
  output_attentions: Optional[bool] = None,
541
547
  output_hidden_states: Optional[bool] = None,
542
548
  return_dict: Optional[bool] = None,
549
+ **kwargs,
543
550
  ) -> Union[tuple, BaseModelOutput]:
544
551
  r"""
545
552
  Args:
@@ -691,6 +698,7 @@ class MBartDecoder(MBartPreTrainedModel):
691
698
  output_hidden_states: Optional[bool] = None,
692
699
  return_dict: Optional[bool] = None,
693
700
  cache_position: Optional[torch.Tensor] = None,
701
+ **kwargs,
694
702
  ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
695
703
  r"""
696
704
  Args:
@@ -919,6 +927,7 @@ class MBartModel(MBartPreTrainedModel):
919
927
  output_hidden_states: Optional[bool] = None,
920
928
  return_dict: Optional[bool] = None,
921
929
  cache_position: Optional[torch.Tensor] = None,
930
+ **kwargs,
922
931
  ) -> Union[Seq2SeqModelOutput, tuple[torch.FloatTensor]]:
923
932
  r"""
924
933
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1052,6 +1061,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin):
1052
1061
  output_hidden_states: Optional[bool] = None,
1053
1062
  return_dict: Optional[bool] = None,
1054
1063
  cache_position: Optional[torch.Tensor] = None,
1064
+ **kwargs,
1055
1065
  ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
1056
1066
  r"""
1057
1067
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1205,6 +1215,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
1205
1215
  output_hidden_states: Optional[bool] = None,
1206
1216
  return_dict: Optional[bool] = None,
1207
1217
  cache_position: Optional[torch.LongTensor] = None,
1218
+ **kwargs,
1208
1219
  ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
1209
1220
  r"""
1210
1221
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1338,6 +1349,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
1338
1349
  output_hidden_states: Optional[bool] = None,
1339
1350
  return_dict: Optional[bool] = None,
1340
1351
  cache_position: Optional[torch.LongTensor] = None,
1352
+ **kwargs,
1341
1353
  ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]:
1342
1354
  r"""
1343
1355
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1436,6 +1448,7 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
1436
1448
  def __init__(self, config):
1437
1449
  super().__init__(config)
1438
1450
  self.decoder = MBartDecoder(config)
1451
+ self.post_init()
1439
1452
 
1440
1453
  def forward(self, *args, **kwargs):
1441
1454
  return self.decoder(*args, **kwargs)
@@ -1480,6 +1493,7 @@ class MBartForCausalLM(MBartPreTrainedModel, GenerationMixin):
1480
1493
  return_dict: Optional[bool] = None,
1481
1494
  cache_position: Optional[torch.LongTensor] = None,
1482
1495
  logits_to_keep: Union[int, torch.Tensor] = 0,
1496
+ **kwargs,
1483
1497
  ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
1484
1498
  r"""
1485
1499
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Optional
16
+ from typing import Optional, Union
17
17
 
18
18
  from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
19
19
  from tokenizers.models import Unigram
@@ -58,13 +58,14 @@ class MBartTokenizer(TokenizersBackend):
58
58
 
59
59
  vocab_files_names = VOCAB_FILES_NAMES
60
60
  model_input_names = ["input_ids", "attention_mask"]
61
- slow_tokenizer_class = None
61
+ model = Unigram
62
62
 
63
63
  prefix_tokens: list[int] = []
64
64
  suffix_tokens: list[int] = []
65
65
 
66
66
  def __init__(
67
67
  self,
68
+ vocab: Optional[Union[str, dict, list]] = None,
68
69
  bos_token="<s>",
69
70
  eos_token="</s>",
70
71
  sep_token="</s>",
@@ -75,9 +76,6 @@ class MBartTokenizer(TokenizersBackend):
75
76
  src_lang=None,
76
77
  tgt_lang=None,
77
78
  additional_special_tokens=None,
78
- vocab=None,
79
- merges=None, # Ignored for Unigram
80
- vocab_file=None,
81
79
  **kwargs,
82
80
  ):
83
81
  mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
@@ -88,56 +86,20 @@ class MBartTokenizer(TokenizersBackend):
88
86
  [t for t in additional_special_tokens if t not in _additional_special_tokens]
89
87
  )
90
88
 
91
- # MBart uses fairseq vocab alignment: <s>=0, <pad>=1, </s>=2, <unk>=3, then SPM pieces[3:], lang codes, <mask>
92
- if vocab is not None:
93
- # Handle different vocab formats (dict, list of tokens, or list of tuples)
94
- # SentencePieceExtractor returns list[tuple[str, float]] which is the expected format
95
- if isinstance(vocab, dict):
96
- vocab = [(token, 0.0) for token in vocab.keys()]
97
- elif isinstance(vocab, list) and len(vocab) > 0:
98
- if not isinstance(vocab[0], tuple):
99
- vocab = [(token, 0.0) for token in vocab]
100
- else:
101
- # Ensure tuples are (str, float) format
102
- vocab = [(str(item[0]), float(item[1])) for item in vocab]
103
-
104
- # Reorder to fairseq: <s>, <pad>, </s>, <unk>, ... (rest of vocab from SPM[3:])
105
- vocab_list = []
106
- vocab_list.append((str(bos_token), 0.0))
107
- vocab_list.append((str(pad_token), 0.0))
108
- vocab_list.append((str(eos_token), 0.0))
109
- vocab_list.append((str(unk_token), 0.0))
110
-
111
- # Add the rest of the SentencePiece vocab (skipping first 3: <unk>, <s>, </s>)
112
- vocab_list.extend(vocab[4:])
113
-
114
- # Add language codes
115
- for lang_code in FAIRSEQ_LANGUAGE_CODES:
116
- vocab_list.append((str(lang_code), 0.0))
117
-
118
- # Add mask token
119
- vocab_list.append((str(mask_token), 0.0))
120
-
121
- self._vocab_scores = vocab_list
122
- else:
123
- self._vocab_scores = [
89
+ if vocab is None:
90
+ vocab = [
124
91
  (str(bos_token), 0.0),
125
92
  (str(pad_token), 0.0),
126
93
  (str(eos_token), 0.0),
127
94
  (str(unk_token), 0.0),
128
- ("▁", -2.0),
129
95
  ]
96
+ vocab += [("▁", -2.0)]
130
97
  for lang_code in FAIRSEQ_LANGUAGE_CODES:
131
- self._vocab_scores.append((lang_code, 0.0))
132
- self._vocab_scores.append((str(mask_token), 0.0))
133
-
134
- self._tokenizer = Tokenizer(
135
- Unigram(
136
- self._vocab_scores,
137
- unk_id=3,
138
- byte_fallback=False,
139
- )
140
- )
98
+ vocab.append((lang_code, 0.0))
99
+ vocab.append((str(mask_token), 0.0))
100
+
101
+ self._vocab = vocab
102
+ self._tokenizer = Tokenizer(Unigram(self._vocab, unk_id=3, byte_fallback=False))
141
103
 
142
104
  self._tokenizer.normalizer = None
143
105
 
@@ -150,10 +112,7 @@ class MBartTokenizer(TokenizersBackend):
150
112
 
151
113
  self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
152
114
 
153
- tokenizer_object = self._tokenizer
154
-
155
115
  super().__init__(
156
- tokenizer_object=tokenizer_object,
157
116
  bos_token=bos_token,
158
117
  eos_token=eos_token,
159
118
  sep_token=sep_token,
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Optional
16
+ from typing import Optional, Union
17
17
 
18
18
  from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
19
19
  from tokenizers.models import Unigram
@@ -79,13 +79,14 @@ class MBart50Tokenizer(TokenizersBackend):
79
79
 
80
80
  vocab_files_names = VOCAB_FILES_NAMES
81
81
  model_input_names = ["input_ids", "attention_mask"]
82
- slow_tokenizer_class = None
82
+ model = Unigram
83
83
 
84
84
  prefix_tokens: list[int] = []
85
85
  suffix_tokens: list[int] = []
86
86
 
87
87
  def __init__(
88
88
  self,
89
+ vocab: Optional[Union[str, dict, list]] = None,
89
90
  src_lang=None,
90
91
  tgt_lang=None,
91
92
  eos_token="</s>",
@@ -94,21 +95,16 @@ class MBart50Tokenizer(TokenizersBackend):
94
95
  unk_token="<unk>",
95
96
  pad_token="<pad>",
96
97
  mask_token="<mask>",
97
- vocab=None,
98
- merges=None, # Ignored for Unigram
99
- vocab_file=None,
100
98
  **kwargs,
101
99
  ):
102
100
  mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
103
101
 
104
- self.vocab_file = vocab_file
105
-
106
102
  # Do not pass language codes via extra_special_tokens to super().__init__.
107
103
  # We will mark them as special AFTER backend construction to avoid re-adding tokens
108
104
  # when loading from pretrained files.
109
105
 
110
106
  # Always construct a tokenizer_object without referencing external tokenizer files
111
- if vocab is not None:
107
+ if isinstance(vocab, list):
112
108
  # MBart50 uses fairseq vocab alignment matching MBart50Converter:
113
109
  # <s>=0, <pad>=1, </s>=2, <unk>=3, then tokens, lang codes, <mask>
114
110
 
@@ -180,9 +176,9 @@ class MBart50Tokenizer(TokenizersBackend):
180
176
  self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
181
177
 
182
178
  self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
183
-
179
+ additional_special_tokens = kwargs.pop("additional_special_tokens", []) or []
180
+ additional_special_tokens.extend(FAIRSEQ_LANGUAGE_CODES)
184
181
  super().__init__(
185
- tokenizer_object=self._tokenizer,
186
182
  src_lang=src_lang,
187
183
  tgt_lang=tgt_lang,
188
184
  eos_token=eos_token,
@@ -191,6 +187,7 @@ class MBart50Tokenizer(TokenizersBackend):
191
187
  unk_token=unk_token,
192
188
  pad_token=pad_token,
193
189
  mask_token=mask_token,
190
+ additional_special_tokens=additional_special_tokens,
194
191
  **kwargs,
195
192
  )
196
193
 
@@ -528,6 +528,8 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
528
528
  super()._init_weights(module)
529
529
  if isinstance(module, MegatronBertLMPredictionHead):
530
530
  init.zeros_(module.bias)
531
+ elif isinstance(module, MegatronBertEmbeddings):
532
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
531
533
 
532
534
 
533
535
  @dataclass
@@ -608,6 +610,7 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
608
610
  output_hidden_states: Optional[bool] = None,
609
611
  return_dict: Optional[bool] = None,
610
612
  cache_position: Optional[torch.Tensor] = None,
613
+ **kwargs,
611
614
  ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
612
615
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
613
616
  output_hidden_states = (
@@ -735,6 +738,7 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
735
738
  output_attentions: Optional[bool] = None,
736
739
  output_hidden_states: Optional[bool] = None,
737
740
  return_dict: Optional[bool] = None,
741
+ **kwargs,
738
742
  ) -> Union[tuple, MegatronBertForPreTrainingOutput]:
739
743
  r"""
740
744
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -955,6 +959,7 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
955
959
  output_attentions: Optional[bool] = None,
956
960
  output_hidden_states: Optional[bool] = None,
957
961
  return_dict: Optional[bool] = None,
962
+ **kwargs,
958
963
  ) -> Union[tuple, MaskedLMOutput]:
959
964
  r"""
960
965
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1140,6 +1145,7 @@ class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel):
1140
1145
  output_attentions: Optional[bool] = None,
1141
1146
  output_hidden_states: Optional[bool] = None,
1142
1147
  return_dict: Optional[bool] = None,
1148
+ **kwargs,
1143
1149
  ) -> Union[tuple, SequenceClassifierOutput]:
1144
1150
  r"""
1145
1151
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1223,6 +1229,7 @@ class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel):
1223
1229
  output_attentions: Optional[bool] = None,
1224
1230
  output_hidden_states: Optional[bool] = None,
1225
1231
  return_dict: Optional[bool] = None,
1232
+ **kwargs,
1226
1233
  ) -> Union[tuple, MultipleChoiceModelOutput]:
1227
1234
  r"""
1228
1235
  input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1326,6 +1333,7 @@ class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
1326
1333
  output_attentions: Optional[bool] = None,
1327
1334
  output_hidden_states: Optional[bool] = None,
1328
1335
  return_dict: Optional[bool] = None,
1336
+ **kwargs,
1329
1337
  ) -> Union[tuple, TokenClassifierOutput]:
1330
1338
  r"""
1331
1339
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1391,6 +1399,7 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
1391
1399
  output_attentions: Optional[bool] = None,
1392
1400
  output_hidden_states: Optional[bool] = None,
1393
1401
  return_dict: Optional[bool] = None,
1402
+ **kwargs,
1394
1403
  ) -> Union[tuple, QuestionAnsweringModelOutput]:
1395
1404
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1396
1405
 
@@ -306,11 +306,13 @@ class MetaClip2PreTrainedModel(PreTrainedModel):
306
306
  if isinstance(module, MetaClip2TextEmbeddings):
307
307
  init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
308
308
  init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
309
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
309
310
  elif isinstance(module, MetaClip2VisionEmbeddings):
310
311
  factor = self.config.initializer_factor
311
312
  init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
312
313
  init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
313
314
  init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
315
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
314
316
  elif isinstance(module, MetaClip2Attention):
315
317
  factor = self.config.initializer_factor
316
318
  in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
@@ -225,11 +225,13 @@ class MetaClip2PreTrainedModel(CLIPPreTrainedModel):
225
225
  if isinstance(module, MetaClip2TextEmbeddings):
226
226
  init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
227
227
  init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
228
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
228
229
  elif isinstance(module, MetaClip2VisionEmbeddings):
229
230
  factor = self.config.initializer_factor
230
231
  init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
231
232
  init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
232
233
  init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
234
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
233
235
  elif isinstance(module, MetaClip2Attention):
234
236
  factor = self.config.initializer_factor
235
237
  in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
@@ -322,6 +322,7 @@ class MgpstrModel(MgpstrPreTrainedModel):
322
322
  output_attentions: Optional[bool] = None,
323
323
  output_hidden_states: Optional[bool] = None,
324
324
  return_dict: Optional[bool] = None,
325
+ **kwargs,
325
326
  ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
326
327
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
327
328
  output_hidden_states = (
@@ -385,6 +386,7 @@ class MgpstrForSceneTextRecognition(MgpstrPreTrainedModel):
385
386
  output_a3_attentions: Optional[bool] = None,
386
387
  output_hidden_states: Optional[bool] = None,
387
388
  return_dict: Optional[bool] = None,
389
+ **kwargs,
388
390
  ) -> Union[tuple[torch.FloatTensor], MgpstrModelOutput]:
389
391
  r"""
390
392
  output_a3_attentions (`bool`, *optional*):
@@ -32,6 +32,7 @@ from ...modeling_outputs import BaseModelOutputWithPast
32
32
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
33
33
  from ...modeling_utils import PreTrainedModel
34
34
  from ...utils import ModelOutput, auto_docstring, logging
35
+ from ...utils.generic import maybe_autocast
35
36
  from .configuration_mimi import MimiConfig
36
37
 
37
38
 
@@ -520,7 +521,7 @@ class MimiRotaryEmbedding(nn.Module):
520
521
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
521
522
 
522
523
  self.register_buffer("inv_freq", inv_freq, persistent=False)
523
- self.original_inv_freq = inv_freq
524
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
524
525
 
525
526
  @staticmethod
526
527
  def compute_default_rope_parameters(
@@ -559,7 +560,7 @@ class MimiRotaryEmbedding(nn.Module):
559
560
  position_ids_expanded = position_ids[:, None, :].float()
560
561
 
561
562
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
562
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
563
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
563
564
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
564
565
  emb = torch.cat((freqs, freqs), dim=-1)
565
566
  cos = emb.cos() * self.attention_scaling
@@ -813,8 +814,8 @@ class MimiFlashAttention2(MimiAttention):
813
814
  else torch.get_autocast_gpu_dtype()
814
815
  )
815
816
  # Handle the case where the model is quantized
816
- elif hasattr(self.config, "_pre_quantization_dtype"):
817
- target_dtype = self.config._pre_quantization_dtype
817
+ elif hasattr(self.config, "quantization_config"):
818
+ target_dtype = self.config.dtype
818
819
  else:
819
820
  target_dtype = self.q_proj.weight.dtype
820
821
 
@@ -1379,7 +1380,7 @@ class MimiPreTrainedModel(PreTrainedModel):
1379
1380
  main_input_name = "input_values"
1380
1381
  input_modalities = "audio"
1381
1382
  supports_gradient_checkpointing = True
1382
- _no_split_modules = ["MimiDecoderLayer"]
1383
+ _no_split_modules = ["MimiResidualVectorQuantizer", "MimiTransformerLayer"]
1383
1384
  _skip_keys_device_placement = "past_key_values"
1384
1385
  _supports_flash_attn = True
1385
1386
  _supports_sdpa = True
@@ -1403,6 +1404,27 @@ class MimiPreTrainedModel(PreTrainedModel):
1403
1404
  init.uniform_(module.bias, a=-k, b=k)
1404
1405
  elif isinstance(module, MimiLayerScale):
1405
1406
  init.constant_(module.scale, self.config.layer_scale_initial_scale)
1407
+ elif isinstance(module, MimiConv1d):
1408
+ kernel_size = module.conv.kernel_size[0]
1409
+ stride = module.conv.stride[0]
1410
+ dilation = module.conv.dilation[0]
1411
+ kernel_size = (kernel_size - 1) * dilation + 1
1412
+ init.constant_(module.stride, stride)
1413
+ init.constant_(module.kernel_size, kernel_size)
1414
+ init.constant_(module.padding_total, kernel_size - stride)
1415
+ elif isinstance(module, MimiEuclideanCodebook):
1416
+ init.ones_(module.initialized)
1417
+ init.ones_(module.cluster_usage)
1418
+ init.zeros_(module.embed_sum)
1419
+ elif isinstance(module, MimiRotaryEmbedding):
1420
+ rope_fn = (
1421
+ ROPE_INIT_FUNCTIONS[module.rope_type]
1422
+ if module.rope_type != "default"
1423
+ else module.compute_default_rope_parameters
1424
+ )
1425
+ buffer_value, _ = rope_fn(module.config)
1426
+ init.copy_(module.inv_freq, buffer_value)
1427
+ init.copy_(module.original_inv_freq, buffer_value)
1406
1428
 
1407
1429
 
1408
1430
  @auto_docstring(
@@ -1685,6 +1707,7 @@ class MimiModel(MimiPreTrainedModel):
1685
1707
  encoder_past_key_values: Optional[Cache] = None,
1686
1708
  decoder_past_key_values: Optional[Cache] = None,
1687
1709
  return_dict: Optional[bool] = None,
1710
+ **kwargs,
1688
1711
  ) -> Union[tuple[torch.Tensor, torch.Tensor], MimiOutput]:
1689
1712
  r"""
1690
1713
  input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
@@ -31,7 +31,12 @@ from ... import initialization as init
31
31
  from ...activations import ACT2FN
32
32
  from ...cache_utils import Cache, DynamicCache
33
33
  from ...generation import GenerationMixin
34
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
34
+ from ...integrations import (
35
+ use_experts_implementation,
36
+ use_kernel_forward_from_hub,
37
+ use_kernel_func_from_hub,
38
+ use_kernelized_func,
39
+ )
35
40
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
36
41
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
37
42
  from ...modeling_layers import (
@@ -45,7 +50,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
45
50
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
46
51
  from ...processing_utils import Unpack
47
52
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
48
- from ...utils.generic import OutputRecorder, check_model_inputs
53
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
49
54
  from .configuration_minimax import MiniMaxConfig
50
55
 
51
56
 
@@ -271,7 +276,7 @@ class MiniMaxRotaryEmbedding(nn.Module):
271
276
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
272
277
 
273
278
  self.register_buffer("inv_freq", inv_freq, persistent=False)
274
- self.original_inv_freq = inv_freq
279
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
275
280
 
276
281
  @staticmethod
277
282
  def compute_default_rope_parameters(
@@ -310,7 +315,7 @@ class MiniMaxRotaryEmbedding(nn.Module):
310
315
  position_ids_expanded = position_ids[:, None, :].float()
311
316
 
312
317
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
313
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
318
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
314
319
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
315
320
  emb = torch.cat((freqs, freqs), dim=-1)
316
321
  cos = emb.cos() * self.attention_scaling
@@ -392,6 +397,7 @@ def eager_attention_forward(
392
397
  return attn_output, attn_weights
393
398
 
394
399
 
400
+ @use_kernelized_func(apply_rotary_pos_emb)
395
401
  class MiniMaxAttention(nn.Module):
396
402
  """Multi-headed attention from 'Attention Is All You Need' paper"""
397
403
 
@@ -408,7 +414,6 @@ class MiniMaxAttention(nn.Module):
408
414
  self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
409
415
  self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
410
416
  self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
411
- self.rotary_fn = apply_rotary_pos_emb
412
417
 
413
418
  def forward(
414
419
  self,
@@ -473,6 +478,7 @@ class MiniMaxTopKRouter(nn.Module):
473
478
  return router_logits, router_scores, router_indices
474
479
 
475
480
 
481
+ @use_experts_implementation
476
482
  class MiniMaxExperts(nn.Module):
477
483
  """Collection of expert weights stored as 3D tensors."""
478
484
 
@@ -596,7 +602,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
596
602
  _supports_flash_attn = True
597
603
  _supports_sdpa = True
598
604
  _supports_flex_attn = True
599
- _can_compile_fullgraph = False
605
+ _can_compile_fullgraph = False # uses a non-compilable custom cache class MiniMaxCache
600
606
  _supports_attention_backend = True
601
607
  _can_record_outputs = {
602
608
  "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0),
@@ -613,6 +619,13 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
613
619
  init.normal_(module.down_proj, mean=0.0, std=std)
614
620
  elif isinstance(module, MiniMaxTopKRouter):
615
621
  init.normal_(module.weight, mean=0.0, std=std)
622
+ if isinstance(module, MiniMaxLightningAttention):
623
+ slope_rate = module.get_slope_rate()
624
+ query_decay, key_decay, diagonal_decay = module.decay_factors(slope_rate)
625
+ init.copy_(module.slope_rate, slope_rate)
626
+ init.copy_(module.query_decay, query_decay)
627
+ init.copy_(module.key_decay, key_decay)
628
+ init.copy_(module.diagonal_decay, diagonal_decay)
616
629
 
617
630
 
618
631
  @auto_docstring
@@ -21,6 +21,7 @@ import torch
21
21
  import torch.nn.functional as F
22
22
  from torch import nn
23
23
 
24
+ from ... import initialization as init
24
25
  from ...activations import ACT2FN
25
26
  from ...cache_utils import Cache, DynamicCache
26
27
  from ...configuration_utils import PreTrainedConfig, layer_type_validation
@@ -520,13 +521,23 @@ class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):
520
521
 
521
522
 
522
523
  class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
523
- _can_compile_fullgraph = False
524
+ _can_compile_fullgraph = False # uses a non-compilable custom cache class MiniMaxCache
524
525
  _can_record_outputs = {
525
526
  "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0),
526
527
  "hidden_states": MiniMaxDecoderLayer,
527
528
  "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
528
529
  }
529
530
 
531
+ def _init_weights(self, module):
532
+ super()._init_weights(module)
533
+ if isinstance(module, MiniMaxLightningAttention):
534
+ slope_rate = module.get_slope_rate()
535
+ query_decay, key_decay, diagonal_decay = module.decay_factors(slope_rate)
536
+ init.copy_(module.slope_rate, slope_rate)
537
+ init.copy_(module.query_decay, query_decay)
538
+ init.copy_(module.key_decay, key_decay)
539
+ init.copy_(module.diagonal_decay, diagonal_decay)
540
+
530
541
 
531
542
  class MiniMaxModel(MixtralModel):
532
543
  @check_model_inputs