transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Optional
16
+ from typing import Optional, Union
17
17
 
18
18
  from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
19
19
  from tokenizers.models import Unigram
@@ -79,13 +79,14 @@ class MBart50Tokenizer(TokenizersBackend):
79
79
 
80
80
  vocab_files_names = VOCAB_FILES_NAMES
81
81
  model_input_names = ["input_ids", "attention_mask"]
82
- slow_tokenizer_class = None
82
+ model = Unigram
83
83
 
84
84
  prefix_tokens: list[int] = []
85
85
  suffix_tokens: list[int] = []
86
86
 
87
87
  def __init__(
88
88
  self,
89
+ vocab: Optional[Union[str, dict, list]] = None,
89
90
  src_lang=None,
90
91
  tgt_lang=None,
91
92
  eos_token="</s>",
@@ -94,21 +95,16 @@ class MBart50Tokenizer(TokenizersBackend):
94
95
  unk_token="<unk>",
95
96
  pad_token="<pad>",
96
97
  mask_token="<mask>",
97
- vocab=None,
98
- merges=None, # Ignored for Unigram
99
- vocab_file=None,
100
98
  **kwargs,
101
99
  ):
102
100
  mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
103
101
 
104
- self.vocab_file = vocab_file
105
-
106
102
  # Do not pass language codes via extra_special_tokens to super().__init__.
107
103
  # We will mark them as special AFTER backend construction to avoid re-adding tokens
108
104
  # when loading from pretrained files.
109
105
 
110
106
  # Always construct a tokenizer_object without referencing external tokenizer files
111
- if vocab is not None:
107
+ if isinstance(vocab, list):
112
108
  # MBart50 uses fairseq vocab alignment matching MBart50Converter:
113
109
  # <s>=0, <pad>=1, </s>=2, <unk>=3, then tokens, lang codes, <mask>
114
110
 
@@ -180,9 +176,9 @@ class MBart50Tokenizer(TokenizersBackend):
180
176
  self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
181
177
 
182
178
  self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
183
-
179
+ additional_special_tokens = kwargs.pop("additional_special_tokens", []) or []
180
+ additional_special_tokens.extend(FAIRSEQ_LANGUAGE_CODES)
184
181
  super().__init__(
185
- tokenizer_object=self._tokenizer,
186
182
  src_lang=src_lang,
187
183
  tgt_lang=tgt_lang,
188
184
  eos_token=eos_token,
@@ -191,6 +187,7 @@ class MBart50Tokenizer(TokenizersBackend):
191
187
  unk_token=unk_token,
192
188
  pad_token=pad_token,
193
189
  mask_token=mask_token,
190
+ additional_special_tokens=additional_special_tokens,
194
191
  **kwargs,
195
192
  )
196
193
 
@@ -608,6 +608,7 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
608
608
  output_hidden_states: Optional[bool] = None,
609
609
  return_dict: Optional[bool] = None,
610
610
  cache_position: Optional[torch.Tensor] = None,
611
+ **kwargs,
611
612
  ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
612
613
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
613
614
  output_hidden_states = (
@@ -735,6 +736,7 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
735
736
  output_attentions: Optional[bool] = None,
736
737
  output_hidden_states: Optional[bool] = None,
737
738
  return_dict: Optional[bool] = None,
739
+ **kwargs,
738
740
  ) -> Union[tuple, MegatronBertForPreTrainingOutput]:
739
741
  r"""
740
742
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -955,6 +957,7 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
955
957
  output_attentions: Optional[bool] = None,
956
958
  output_hidden_states: Optional[bool] = None,
957
959
  return_dict: Optional[bool] = None,
960
+ **kwargs,
958
961
  ) -> Union[tuple, MaskedLMOutput]:
959
962
  r"""
960
963
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1140,6 +1143,7 @@ class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel):
1140
1143
  output_attentions: Optional[bool] = None,
1141
1144
  output_hidden_states: Optional[bool] = None,
1142
1145
  return_dict: Optional[bool] = None,
1146
+ **kwargs,
1143
1147
  ) -> Union[tuple, SequenceClassifierOutput]:
1144
1148
  r"""
1145
1149
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1223,6 +1227,7 @@ class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel):
1223
1227
  output_attentions: Optional[bool] = None,
1224
1228
  output_hidden_states: Optional[bool] = None,
1225
1229
  return_dict: Optional[bool] = None,
1230
+ **kwargs,
1226
1231
  ) -> Union[tuple, MultipleChoiceModelOutput]:
1227
1232
  r"""
1228
1233
  input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1326,6 +1331,7 @@ class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
1326
1331
  output_attentions: Optional[bool] = None,
1327
1332
  output_hidden_states: Optional[bool] = None,
1328
1333
  return_dict: Optional[bool] = None,
1334
+ **kwargs,
1329
1335
  ) -> Union[tuple, TokenClassifierOutput]:
1330
1336
  r"""
1331
1337
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1391,6 +1397,7 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
1391
1397
  output_attentions: Optional[bool] = None,
1392
1398
  output_hidden_states: Optional[bool] = None,
1393
1399
  return_dict: Optional[bool] = None,
1400
+ **kwargs,
1394
1401
  ) -> Union[tuple, QuestionAnsweringModelOutput]:
1395
1402
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1396
1403
 
@@ -322,6 +322,7 @@ class MgpstrModel(MgpstrPreTrainedModel):
322
322
  output_attentions: Optional[bool] = None,
323
323
  output_hidden_states: Optional[bool] = None,
324
324
  return_dict: Optional[bool] = None,
325
+ **kwargs,
325
326
  ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
326
327
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
327
328
  output_hidden_states = (
@@ -385,6 +386,7 @@ class MgpstrForSceneTextRecognition(MgpstrPreTrainedModel):
385
386
  output_a3_attentions: Optional[bool] = None,
386
387
  output_hidden_states: Optional[bool] = None,
387
388
  return_dict: Optional[bool] = None,
389
+ **kwargs,
388
390
  ) -> Union[tuple[torch.FloatTensor], MgpstrModelOutput]:
389
391
  r"""
390
392
  output_a3_attentions (`bool`, *optional*):
@@ -32,6 +32,7 @@ from ...modeling_outputs import BaseModelOutputWithPast
32
32
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
33
33
  from ...modeling_utils import PreTrainedModel
34
34
  from ...utils import ModelOutput, auto_docstring, logging
35
+ from ...utils.generic import maybe_autocast
35
36
  from .configuration_mimi import MimiConfig
36
37
 
37
38
 
@@ -559,7 +560,7 @@ class MimiRotaryEmbedding(nn.Module):
559
560
  position_ids_expanded = position_ids[:, None, :].float()
560
561
 
561
562
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
562
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
563
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
563
564
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
564
565
  emb = torch.cat((freqs, freqs), dim=-1)
565
566
  cos = emb.cos() * self.attention_scaling
@@ -1685,6 +1686,7 @@ class MimiModel(MimiPreTrainedModel):
1685
1686
  encoder_past_key_values: Optional[Cache] = None,
1686
1687
  decoder_past_key_values: Optional[Cache] = None,
1687
1688
  return_dict: Optional[bool] = None,
1689
+ **kwargs,
1688
1690
  ) -> Union[tuple[torch.Tensor, torch.Tensor], MimiOutput]:
1689
1691
  r"""
1690
1692
  input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
@@ -31,7 +31,7 @@ from ... import initialization as init
31
31
  from ...activations import ACT2FN
32
32
  from ...cache_utils import Cache, DynamicCache
33
33
  from ...generation import GenerationMixin
34
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
34
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
35
35
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
36
36
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
37
37
  from ...modeling_layers import (
@@ -45,7 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
45
45
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
46
46
  from ...processing_utils import Unpack
47
47
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
48
- from ...utils.generic import OutputRecorder, check_model_inputs
48
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
49
49
  from .configuration_minimax import MiniMaxConfig
50
50
 
51
51
 
@@ -310,7 +310,7 @@ class MiniMaxRotaryEmbedding(nn.Module):
310
310
  position_ids_expanded = position_ids[:, None, :].float()
311
311
 
312
312
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
313
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
313
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
314
314
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
315
315
  emb = torch.cat((freqs, freqs), dim=-1)
316
316
  cos = emb.cos() * self.attention_scaling
@@ -392,6 +392,7 @@ def eager_attention_forward(
392
392
  return attn_output, attn_weights
393
393
 
394
394
 
395
+ @use_kernelized_func(apply_rotary_pos_emb)
395
396
  class MiniMaxAttention(nn.Module):
396
397
  """Multi-headed attention from 'Attention Is All You Need' paper"""
397
398
 
@@ -408,7 +409,6 @@ class MiniMaxAttention(nn.Module):
408
409
  self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
409
410
  self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
410
411
  self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
411
- self.rotary_fn = apply_rotary_pos_emb
412
412
 
413
413
  def forward(
414
414
  self,
@@ -13,7 +13,7 @@ from torch import nn
13
13
  from ...activations import ACT2FN
14
14
  from ...cache_utils import Cache, DynamicCache
15
15
  from ...generation import GenerationMixin
16
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
16
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
17
17
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
18
18
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
19
19
  from ...modeling_layers import (
@@ -27,7 +27,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
27
27
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
28
28
  from ...processing_utils import Unpack
29
29
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
30
- from ...utils.generic import check_model_inputs
30
+ from ...utils.generic import check_model_inputs, maybe_autocast
31
31
  from .configuration_ministral import MinistralConfig
32
32
 
33
33
 
@@ -120,6 +120,7 @@ def eager_attention_forward(
120
120
  return attn_output, attn_weights
121
121
 
122
122
 
123
+ @use_kernelized_func(apply_rotary_pos_emb)
123
124
  class MinistralAttention(nn.Module):
124
125
  """Multi-headed attention from 'Attention Is All You Need' paper"""
125
126
 
@@ -138,7 +139,6 @@ class MinistralAttention(nn.Module):
138
139
  self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
139
140
  self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
140
141
  self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
141
- self.rotary_fn = apply_rotary_pos_emb
142
142
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
143
143
 
144
144
  def forward(
@@ -328,7 +328,7 @@ class MinistralRotaryEmbedding(nn.Module):
328
328
  position_ids_expanded = position_ids[:, None, :].float()
329
329
 
330
330
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
331
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
331
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
332
332
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
333
333
  emb = torch.cat((freqs, freqs), dim=-1)
334
334
  cos = emb.cos() * self.attention_scaling
@@ -193,7 +193,7 @@ class Ministral3Config(PreTrainedConfig):
193
193
  bos_token_id=bos_token_id,
194
194
  eos_token_id=eos_token_id,
195
195
  tie_word_embeddings=tie_word_embeddings,
196
- ignore_keys_at_rope_validation={"llama_4_scaling_beta"},
196
+ ignore_keys_at_rope_validation={"llama_4_scaling_beta", "max_position_embeddings"},
197
197
  **kwargs,
198
198
  )
199
199
 
@@ -15,7 +15,7 @@ from transformers.utils.generic import check_model_inputs
15
15
  from ...activations import ACT2FN
16
16
  from ...cache_utils import Cache, DynamicCache
17
17
  from ...generation import GenerationMixin
18
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
18
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
19
19
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
20
20
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
21
21
  from ...modeling_layers import (
@@ -29,6 +29,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
29
29
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
30
30
  from ...processing_utils import Unpack
31
31
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
32
+ from ...utils.generic import maybe_autocast
32
33
  from .configuration_ministral3 import Ministral3Config
33
34
 
34
35
 
@@ -110,6 +111,7 @@ def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_positi
110
111
  return scaling.unsqueeze(-1)
111
112
 
112
113
 
114
+ @use_kernelized_func(apply_rotary_pos_emb)
113
115
  class Ministral3Attention(nn.Module):
114
116
  """Multi-headed attention from 'Attention Is All You Need' paper"""
115
117
 
@@ -126,7 +128,6 @@ class Ministral3Attention(nn.Module):
126
128
  self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
127
129
  self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
128
130
  self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
129
- self.rotary_fn = apply_rotary_pos_emb
130
131
 
131
132
  def forward(
132
133
  self,
@@ -333,7 +334,7 @@ class Ministral3RotaryEmbedding(nn.Module):
333
334
  position_ids_expanded = position_ids[:, None, :].float()
334
335
 
335
336
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
336
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
337
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
337
338
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
338
339
  emb = torch.cat((freqs, freqs), dim=-1)
339
340
  cos = emb.cos() * self.attention_scaling
@@ -15,7 +15,7 @@ from transformers.utils.generic import check_model_inputs
15
15
  from ...activations import ACT2FN
16
16
  from ...cache_utils import Cache, DynamicCache
17
17
  from ...generation import GenerationMixin
18
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
18
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
19
19
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
20
20
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
21
21
  from ...modeling_layers import (
@@ -29,6 +29,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
29
29
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
30
30
  from ...processing_utils import Unpack
31
31
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
32
+ from ...utils.generic import maybe_autocast
32
33
  from .configuration_mistral import MistralConfig
33
34
 
34
35
 
@@ -121,6 +122,7 @@ def eager_attention_forward(
121
122
  return attn_output, attn_weights
122
123
 
123
124
 
125
+ @use_kernelized_func(apply_rotary_pos_emb)
124
126
  class MistralAttention(nn.Module):
125
127
  """Multi-headed attention from 'Attention Is All You Need' paper"""
126
128
 
@@ -137,7 +139,6 @@ class MistralAttention(nn.Module):
137
139
  self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
138
140
  self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
139
141
  self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
140
- self.rotary_fn = apply_rotary_pos_emb
141
142
 
142
143
  def forward(
143
144
  self,
@@ -323,7 +324,7 @@ class MistralRotaryEmbedding(nn.Module):
323
324
  position_ids_expanded = position_ids[:, None, :].float()
324
325
 
325
326
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
326
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
327
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
327
328
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
328
329
  emb = torch.cat((freqs, freqs), dim=-1)
329
330
  cos = emb.cos() * self.attention_scaling
@@ -37,7 +37,7 @@ from ... import initialization as init
37
37
  from ...activations import ACT2FN
38
38
  from ...cache_utils import Cache, DynamicCache
39
39
  from ...generation import GenerationMixin
40
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
40
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
41
41
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
42
42
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
43
43
  from ...modeling_layers import (
@@ -51,7 +51,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
51
51
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
52
52
  from ...processing_utils import Unpack
53
53
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
54
- from ...utils.generic import OutputRecorder
54
+ from ...utils.generic import OutputRecorder, maybe_autocast
55
55
  from .configuration_mixtral import MixtralConfig
56
56
 
57
57
 
@@ -208,7 +208,7 @@ class MixtralRotaryEmbedding(nn.Module):
208
208
  position_ids_expanded = position_ids[:, None, :].float()
209
209
 
210
210
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
211
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
211
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
212
212
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
213
213
  emb = torch.cat((freqs, freqs), dim=-1)
214
214
  cos = emb.cos() * self.attention_scaling
@@ -290,6 +290,7 @@ def eager_attention_forward(
290
290
  return attn_output, attn_weights
291
291
 
292
292
 
293
+ @use_kernelized_func(apply_rotary_pos_emb)
293
294
  class MixtralAttention(nn.Module):
294
295
  """Multi-headed attention from 'Attention Is All You Need' paper"""
295
296
 
@@ -306,7 +307,6 @@ class MixtralAttention(nn.Module):
306
307
  self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
307
308
  self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
308
309
  self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
309
- self.rotary_fn = apply_rotary_pos_emb
310
310
 
311
311
  def forward(
312
312
  self,
@@ -37,7 +37,7 @@ from ...modeling_rope_utils import (
37
37
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
38
38
  from ...processing_utils import Unpack
39
39
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
40
- from ...utils.generic import OutputRecorder, check_model_inputs
40
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
41
41
  from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig
42
42
 
43
43
 
@@ -781,7 +781,7 @@ class MllamaRotaryEmbedding(nn.Module):
781
781
  position_ids_expanded = position_ids[:, None, :].float()
782
782
 
783
783
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
784
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
784
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
785
785
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
786
786
  emb = torch.cat((freqs, freqs), dim=-1)
787
787
  cos = emb.cos() * self.attention_scaling
@@ -234,8 +234,8 @@ class MLukeTokenizer(TokenizersBackend):
234
234
  entity_pad_token="[PAD]",
235
235
  entity_mask_token="[MASK]",
236
236
  entity_mask2_token="[MASK2]",
237
- vocab: Optional[list] = None,
238
- entity_vocab: Optional[dict] = None,
237
+ vocab: Optional[Union[str, dict, list]] = None,
238
+ entity_vocab: Optional[Union[str, dict, list]] = None,
239
239
  **kwargs,
240
240
  ) -> None:
241
241
  # Mask token behave like a normal word, i.e. include the space before it
@@ -263,10 +263,13 @@ class MLukeTokenizer(TokenizersBackend):
263
263
  entity_vocab = kwargs.pop("entity_vocab")
264
264
 
265
265
  # Build vocab from data (list of (token, score) tuples)
266
- if vocab is not None:
266
+ if isinstance(vocab, list):
267
267
  # vocab is list of (token, score) tuples from SentencePieceExtractor
268
268
  self._vocab = [(token, float(score)) for token, score in vocab]
269
269
  self._vocab_size = len(self._vocab)
270
+ elif vocab is not None:
271
+ self._vocab = vocab
272
+ self._vocab_size = 0
270
273
  else:
271
274
  # Create minimal vocab with <unk> to satisfy Unigram requirements
272
275
  self._vocab = [("<unk>", 0.0)]
@@ -365,10 +368,7 @@ class MLukeTokenizer(TokenizersBackend):
365
368
 
366
369
  kwargs["extra_special_tokens"] = extra_tokens
367
370
 
368
- tokenizer_object = self._tokenizer
369
-
370
371
  super().__init__(
371
- tokenizer_object=tokenizer_object,
372
372
  bos_token=bos_token,
373
373
  eos_token=eos_token,
374
374
  unk_token=unk_token,
@@ -1180,6 +1180,7 @@ class MMGroundingDinoEncoder(MMGroundingDinoPreTrainedModel):
1180
1180
  output_attentions=None,
1181
1181
  output_hidden_states=None,
1182
1182
  return_dict=None,
1183
+ **kwargs,
1183
1184
  ):
1184
1185
  r"""
1185
1186
  Args:
@@ -1476,6 +1477,7 @@ class MMGroundingDinoDecoder(MMGroundingDinoPreTrainedModel):
1476
1477
  output_attentions=None,
1477
1478
  output_hidden_states=None,
1478
1479
  return_dict=None,
1480
+ **kwargs,
1479
1481
  ):
1480
1482
  r"""
1481
1483
  Args:
@@ -1951,6 +1953,7 @@ class MMGroundingDinoModel(MMGroundingDinoPreTrainedModel):
1951
1953
  output_attentions=None,
1952
1954
  output_hidden_states=None,
1953
1955
  return_dict=None,
1956
+ **kwargs,
1954
1957
  ):
1955
1958
  r"""
1956
1959
  input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
@@ -2431,6 +2434,7 @@ class MMGroundingDinoForObjectDetection(MMGroundingDinoPreTrainedModel):
2431
2434
  output_hidden_states: Optional[bool] = None,
2432
2435
  return_dict: Optional[bool] = None,
2433
2436
  labels: Optional[list[dict[str, Union[torch.LongTensor, torch.FloatTensor]]]] = None,
2437
+ **kwargs,
2434
2438
  ):
2435
2439
  r"""
2436
2440
  input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
@@ -195,6 +195,7 @@ class MobileNetV1Model(MobileNetV1PreTrainedModel):
195
195
  pixel_values: Optional[torch.Tensor] = None,
196
196
  output_hidden_states: Optional[bool] = None,
197
197
  return_dict: Optional[bool] = None,
198
+ **kwargs,
198
199
  ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
199
200
  output_hidden_states = (
200
201
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -260,6 +261,7 @@ class MobileNetV1ForImageClassification(MobileNetV1PreTrainedModel):
260
261
  output_hidden_states: Optional[bool] = None,
261
262
  labels: Optional[torch.Tensor] = None,
262
263
  return_dict: Optional[bool] = None,
264
+ **kwargs,
263
265
  ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
264
266
  r"""
265
267
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -331,6 +331,7 @@ class MobileNetV2Model(MobileNetV2PreTrainedModel):
331
331
  pixel_values: Optional[torch.Tensor] = None,
332
332
  output_hidden_states: Optional[bool] = None,
333
333
  return_dict: Optional[bool] = None,
334
+ **kwargs,
334
335
  ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
335
336
  output_hidden_states = (
336
337
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -396,6 +397,7 @@ class MobileNetV2ForImageClassification(MobileNetV2PreTrainedModel):
396
397
  output_hidden_states: Optional[bool] = None,
397
398
  labels: Optional[torch.Tensor] = None,
398
399
  return_dict: Optional[bool] = None,
400
+ **kwargs,
399
401
  ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
400
402
  r"""
401
403
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -524,6 +526,7 @@ class MobileNetV2ForSemanticSegmentation(MobileNetV2PreTrainedModel):
524
526
  labels: Optional[torch.Tensor] = None,
525
527
  output_hidden_states: Optional[bool] = None,
526
528
  return_dict: Optional[bool] = None,
529
+ **kwargs,
527
530
  ) -> Union[tuple, SemanticSegmenterOutput]:
528
531
  r"""
529
532
  labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -659,6 +659,7 @@ class MobileViTModel(MobileViTPreTrainedModel):
659
659
  pixel_values: Optional[torch.Tensor] = None,
660
660
  output_hidden_states: Optional[bool] = None,
661
661
  return_dict: Optional[bool] = None,
662
+ **kwargs,
662
663
  ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
663
664
  output_hidden_states = (
664
665
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -725,6 +726,7 @@ class MobileViTForImageClassification(MobileViTPreTrainedModel):
725
726
  output_hidden_states: Optional[bool] = None,
726
727
  labels: Optional[torch.Tensor] = None,
727
728
  return_dict: Optional[bool] = None,
729
+ **kwargs,
728
730
  ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
729
731
  r"""
730
732
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -889,6 +891,7 @@ class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
889
891
  labels: Optional[torch.Tensor] = None,
890
892
  output_hidden_states: Optional[bool] = None,
891
893
  return_dict: Optional[bool] = None,
894
+ **kwargs,
892
895
  ) -> Union[tuple, SemanticSegmenterOutput]:
893
896
  r"""
894
897
  labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -623,6 +623,7 @@ class MobileViTV2Model(MobileViTV2PreTrainedModel):
623
623
  pixel_values: Optional[torch.Tensor] = None,
624
624
  output_hidden_states: Optional[bool] = None,
625
625
  return_dict: Optional[bool] = None,
626
+ **kwargs,
626
627
  ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
627
628
  output_hidden_states = (
628
629
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -691,6 +692,7 @@ class MobileViTV2ForImageClassification(MobileViTV2PreTrainedModel):
691
692
  output_hidden_states: Optional[bool] = None,
692
693
  labels: Optional[torch.Tensor] = None,
693
694
  return_dict: Optional[bool] = None,
695
+ **kwargs,
694
696
  ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
695
697
  r"""
696
698
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -858,6 +860,7 @@ class MobileViTV2ForSemanticSegmentation(MobileViTV2PreTrainedModel):
858
860
  labels: Optional[torch.Tensor] = None,
859
861
  output_hidden_states: Optional[bool] = None,
860
862
  return_dict: Optional[bool] = None,
863
+ **kwargs,
861
864
  ) -> Union[tuple, SemanticSegmenterOutput]:
862
865
  r"""
863
866
  labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -45,6 +45,7 @@ from ...modeling_outputs import (
45
45
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
46
46
  from ...modeling_utils import PreTrainedModel
47
47
  from ...utils import auto_docstring, is_flash_attn_2_available, logging
48
+ from ...utils.generic import maybe_autocast
48
49
  from ...utils.import_utils import is_triton_available
49
50
  from .configuration_modernbert import ModernBertConfig
50
51
 
@@ -316,7 +317,7 @@ class ModernBertRotaryEmbedding(nn.Module):
316
317
  position_ids_expanded = position_ids[:, None, :].float()
317
318
 
318
319
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
319
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
320
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
320
321
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
321
322
  emb = torch.cat((freqs, freqs), dim=-1)
322
323
  cos = emb.cos() * attention_scaling
@@ -852,6 +853,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
852
853
  output_attentions: Optional[bool] = None,
853
854
  output_hidden_states: Optional[bool] = None,
854
855
  return_dict: Optional[bool] = None,
856
+ **kwargs,
855
857
  ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
856
858
  r"""
857
859
  sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1345,6 +1347,7 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
1345
1347
  output_attentions: Optional[bool] = None,
1346
1348
  output_hidden_states: Optional[bool] = None,
1347
1349
  return_dict: Optional[bool] = None,
1350
+ **kwargs,
1348
1351
  ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
1349
1352
  r"""
1350
1353
  sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -975,6 +975,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
975
975
  output_attentions: Optional[bool] = None,
976
976
  output_hidden_states: Optional[bool] = None,
977
977
  return_dict: Optional[bool] = None,
978
+ **kwargs,
978
979
  ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
979
980
  r"""
980
981
  sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1468,6 +1469,7 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
1468
1469
  output_attentions: Optional[bool] = None,
1469
1470
  output_hidden_states: Optional[bool] = None,
1470
1471
  return_dict: Optional[bool] = None,
1472
+ **kwargs,
1471
1473
  ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
1472
1474
  r"""
1473
1475
  sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):