transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -39,7 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
39
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
40
40
  from ...processing_utils import Unpack
41
41
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
42
- from ...utils.generic import check_model_inputs
42
+ from ...utils.generic import check_model_inputs, maybe_autocast
43
43
  from .configuration_modernbert_decoder import ModernBertDecoderConfig
44
44
 
45
45
 
@@ -168,7 +168,7 @@ class ModernBertDecoderRotaryEmbedding(nn.Module):
168
168
  position_ids_expanded = position_ids[:, None, :].float()
169
169
 
170
170
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
171
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
171
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
172
172
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
173
173
  emb = torch.cat((freqs, freqs), dim=-1)
174
174
  cos = emb.cos() * attention_scaling
@@ -342,7 +342,7 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
342
342
  attention_mask: Optional[torch.Tensor] = None,
343
343
  past_key_values: Optional[Cache] = None,
344
344
  cache_position: Optional[torch.LongTensor] = None,
345
- **kwargs,
345
+ **kwargs: Unpack[TransformersKwargs],
346
346
  ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
347
347
  residual = hidden_states
348
348
  hidden_states = self.attn_norm(hidden_states)
@@ -477,7 +477,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
477
477
  inputs_embeds: Optional[torch.Tensor] = None,
478
478
  use_cache: Optional[bool] = None,
479
479
  cache_position: Optional[torch.LongTensor] = None,
480
- **kwargs,
480
+ **kwargs: Unpack[TransformersKwargs],
481
481
  ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
482
482
  if (input_ids is None) == (inputs_embeds is None):
483
483
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -489,7 +489,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
489
489
  batch_size, seq_length = inputs_embeds.shape[:2]
490
490
 
491
491
  # Handle past_key_values and cache setup
492
- if use_cache and past_key_values is None and not self.training:
492
+ if use_cache and past_key_values is None:
493
493
  past_key_values = DynamicCache(config=self.config)
494
494
 
495
495
  if cache_position is None:
@@ -527,13 +527,12 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
527
527
  for layer_type in self.config.layer_types:
528
528
  position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
529
529
 
530
- for idx, decoder_layer in enumerate(self.layers):
530
+ for decoder_layer in self.layers:
531
531
  hidden_states = decoder_layer(
532
532
  hidden_states,
533
533
  attention_mask=causal_mask_mapping[decoder_layer.attention_type],
534
534
  position_embeddings=position_embeddings[decoder_layer.attention_type],
535
535
  past_key_values=past_key_values,
536
- use_cache=use_cache,
537
536
  cache_position=cache_position,
538
537
  position_ids=position_ids,
539
538
  **kwargs,
@@ -583,7 +582,7 @@ class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationM
583
582
  labels: Optional[torch.LongTensor] = None,
584
583
  use_cache: Optional[bool] = None,
585
584
  logits_to_keep: Union[int, torch.Tensor] = 0,
586
- **kwargs,
585
+ **kwargs: Unpack[TransformersKwargs],
587
586
  ) -> Union[tuple, CausalLMOutputWithPast]:
588
587
  r"""
589
588
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -686,7 +685,7 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
686
685
  inputs_embeds: Optional[torch.Tensor] = None,
687
686
  labels: Optional[torch.LongTensor] = None,
688
687
  use_cache: Optional[bool] = None,
689
- **kwargs,
688
+ **kwargs: Unpack[TransformersKwargs],
690
689
  ) -> Union[tuple, SequenceClassifierOutputWithPast]:
691
690
  r"""
692
691
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -394,7 +394,7 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
394
394
  attention_mask: Optional[torch.Tensor] = None,
395
395
  past_key_values: Optional[Cache] = None,
396
396
  cache_position: Optional[torch.LongTensor] = None,
397
- **kwargs,
397
+ **kwargs: Unpack[TransformersKwargs],
398
398
  ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
399
399
  residual = hidden_states
400
400
  hidden_states = self.attn_norm(hidden_states)
@@ -525,7 +525,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
525
525
  inputs_embeds: Optional[torch.Tensor] = None,
526
526
  use_cache: Optional[bool] = None,
527
527
  cache_position: Optional[torch.LongTensor] = None,
528
- **kwargs,
528
+ **kwargs: Unpack[TransformersKwargs],
529
529
  ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
530
530
  if (input_ids is None) == (inputs_embeds is None):
531
531
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -537,7 +537,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
537
537
  batch_size, seq_length = inputs_embeds.shape[:2]
538
538
 
539
539
  # Handle past_key_values and cache setup
540
- if use_cache and past_key_values is None and not self.training:
540
+ if use_cache and past_key_values is None:
541
541
  past_key_values = DynamicCache(config=self.config)
542
542
 
543
543
  if cache_position is None:
@@ -575,13 +575,12 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
575
575
  for layer_type in self.config.layer_types:
576
576
  position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
577
577
 
578
- for idx, decoder_layer in enumerate(self.layers):
578
+ for decoder_layer in self.layers:
579
579
  hidden_states = decoder_layer(
580
580
  hidden_states,
581
581
  attention_mask=causal_mask_mapping[decoder_layer.attention_type],
582
582
  position_embeddings=position_embeddings[decoder_layer.attention_type],
583
583
  past_key_values=past_key_values,
584
- use_cache=use_cache,
585
584
  cache_position=cache_position,
586
585
  position_ids=position_ids,
587
586
  **kwargs,
@@ -631,7 +630,7 @@ class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationM
631
630
  labels: Optional[torch.LongTensor] = None,
632
631
  use_cache: Optional[bool] = None,
633
632
  logits_to_keep: Union[int, torch.Tensor] = 0,
634
- **kwargs,
633
+ **kwargs: Unpack[TransformersKwargs],
635
634
  ) -> Union[tuple, CausalLMOutputWithPast]:
636
635
  r"""
637
636
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -734,7 +733,7 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedMode
734
733
  inputs_embeds: Optional[torch.Tensor] = None,
735
734
  labels: Optional[torch.LongTensor] = None,
736
735
  use_cache: Optional[bool] = None,
737
- **kwargs,
736
+ **kwargs: Unpack[TransformersKwargs],
738
737
  ) -> Union[tuple, SequenceClassifierOutputWithPast]:
739
738
  r"""
740
739
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -30,6 +30,7 @@ from transformers.utils.generic import OutputRecorder, check_model_inputs
30
30
  from ...activations import ACT2FN
31
31
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
32
32
  from ...generation import GenerationMixin
33
+ from ...integrations import use_kernelized_func
33
34
  from ...masking_utils import create_causal_mask
34
35
  from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
35
36
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -45,6 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
45
46
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
46
47
  from ...processing_utils import Unpack
47
48
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
49
+ from ...utils.generic import maybe_autocast
48
50
  from .configuration_moonshine import MoonshineConfig
49
51
 
50
52
 
@@ -137,7 +139,7 @@ class MoonshineRotaryEmbedding(nn.Module):
137
139
  position_ids_expanded = position_ids[:, None, :].float()
138
140
 
139
141
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
140
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
142
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
141
143
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
142
144
  emb = torch.cat((freqs, freqs), dim=-1)
143
145
  cos = emb.cos() * self.attention_scaling
@@ -233,6 +235,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
233
235
  return q_embed, k_embed
234
236
 
235
237
 
238
+ @use_kernelized_func(apply_rotary_pos_emb)
236
239
  class MoonshineAttention(nn.Module):
237
240
  """Multi-headed attention from 'Attention Is All You Need' paper"""
238
241
 
@@ -264,7 +267,6 @@ class MoonshineAttention(nn.Module):
264
267
  config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
265
268
  )
266
269
  self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
267
- self.rotary_fn = apply_rotary_pos_emb
268
270
 
269
271
  # Pad head dimension to the next specified multiple.
270
272
  if self.config.pad_head_dim_to_multiple_of is not None:
@@ -34,6 +34,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast,
34
34
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
35
35
  from ...modeling_utils import PreTrainedModel
36
36
  from ...utils import auto_docstring, is_torch_flex_attn_available, logging
37
+ from ...utils.generic import maybe_autocast
37
38
  from ..auto.modeling_auto import AutoModel
38
39
  from .configuration_moshi import MoshiConfig, MoshiDepthConfig
39
40
 
@@ -327,7 +328,7 @@ class MoshiRotaryEmbedding(nn.Module):
327
328
  position_ids_expanded = position_ids[:, None, :].float()
328
329
 
329
330
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
330
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
331
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
331
332
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
332
333
  emb = torch.cat((freqs, freqs), dim=-1)
333
334
  cos = emb.cos() * self.attention_scaling
@@ -882,6 +883,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
882
883
  position_ids: Optional[torch.LongTensor] = None,
883
884
  labels: Optional[torch.LongTensor] = None,
884
885
  cache_position: Optional[torch.LongTensor] = None,
886
+ **kwargs,
885
887
  ) -> Union[tuple, BaseModelOutputWithPast]:
886
888
  """
887
889
  Args:
@@ -957,7 +959,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
957
959
  )
958
960
  use_cache = False
959
961
 
960
- if use_cache and past_key_values is None and not self.training:
962
+ if use_cache and past_key_values is None:
961
963
  past_key_values = DynamicCache(config=self.config)
962
964
 
963
965
  past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
@@ -1228,6 +1230,7 @@ class MoshiModel(MoshiPreTrainedModel):
1228
1230
  output_hidden_states: Optional[bool] = None,
1229
1231
  return_dict: Optional[bool] = None,
1230
1232
  cache_position: Optional[torch.LongTensor] = None,
1233
+ **kwargs,
1231
1234
  ) -> Union[tuple, BaseModelOutputWithPast]:
1232
1235
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1233
1236
  output_hidden_states = (
@@ -488,6 +488,7 @@ class MPNetForMaskedLM(MPNetPreTrainedModel):
488
488
  output_attentions: Optional[bool] = None,
489
489
  output_hidden_states: Optional[bool] = None,
490
490
  return_dict: Optional[bool] = None,
491
+ **kwargs,
491
492
  ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
492
493
  r"""
493
494
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -577,6 +578,7 @@ class MPNetForSequenceClassification(MPNetPreTrainedModel):
577
578
  output_attentions: Optional[bool] = None,
578
579
  output_hidden_states: Optional[bool] = None,
579
580
  return_dict: Optional[bool] = None,
581
+ **kwargs,
580
582
  ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
581
583
  r"""
582
584
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -656,6 +658,7 @@ class MPNetForMultipleChoice(MPNetPreTrainedModel):
656
658
  output_attentions: Optional[bool] = None,
657
659
  output_hidden_states: Optional[bool] = None,
658
660
  return_dict: Optional[bool] = None,
661
+ **kwargs,
659
662
  ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
660
663
  r"""
661
664
  input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -748,6 +751,7 @@ class MPNetForTokenClassification(MPNetPreTrainedModel):
748
751
  output_attentions: Optional[bool] = None,
749
752
  output_hidden_states: Optional[bool] = None,
750
753
  return_dict: Optional[bool] = None,
754
+ **kwargs,
751
755
  ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
752
756
  r"""
753
757
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -831,6 +835,7 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
831
835
  output_attentions: Optional[bool] = None,
832
836
  output_hidden_states: Optional[bool] = None,
833
837
  return_dict: Optional[bool] = None,
838
+ **kwargs,
834
839
  ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
835
840
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
836
841
 
@@ -15,7 +15,7 @@
15
15
  # limitations under the License.
16
16
  """Tokenization classes for MPNet."""
17
17
 
18
- from typing import Optional
18
+ from typing import Optional, Union
19
19
 
20
20
  from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
21
21
  from tokenizers.models import WordPiece
@@ -38,7 +38,7 @@ class MPNetTokenizer(TokenizersBackend):
38
38
  refer to this superclass for more information regarding those methods.
39
39
 
40
40
  Args:
41
- vocab (`dict`, *optional*):
41
+ vocab (`str` or `dict[str, int]`, *optional*):
42
42
  Dictionary mapping tokens to their IDs. If not provided, an empty vocab is initialized.
43
43
  do_lower_case (`bool`, *optional*, defaults to `True`):
44
44
  Whether or not to lowercase the input when tokenizing.
@@ -87,10 +87,11 @@ class MPNetTokenizer(TokenizersBackend):
87
87
 
88
88
  vocab_files_names = VOCAB_FILES_NAMES
89
89
  model_input_names = ["input_ids", "attention_mask"]
90
+ model = WordPiece
90
91
 
91
92
  def __init__(
92
93
  self,
93
- vocab: Optional[dict] = None,
94
+ vocab: Optional[Union[str, dict[str, int]]] = None,
94
95
  do_lower_case=True,
95
96
  bos_token="<s>",
96
97
  eos_token="</s>",
@@ -104,12 +105,7 @@ class MPNetTokenizer(TokenizersBackend):
104
105
  **kwargs,
105
106
  ):
106
107
  # Initialize vocab
107
- if vocab is not None:
108
- self._vocab = (
109
- {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
110
- )
111
- else:
112
- self._vocab = {}
108
+ self._vocab = vocab if vocab is not None else {}
113
109
 
114
110
  # Initialize the tokenizer with WordPiece model
115
111
  self._tokenizer = Tokenizer(WordPiece(self._vocab, unk_token=str(unk_token)))
@@ -142,11 +138,7 @@ class MPNetTokenizer(TokenizersBackend):
142
138
  # Mask token behave like a normal word, i.e. include the space before it
143
139
  mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
144
140
 
145
- # Store for later use
146
- tokenizer_object = self._tokenizer
147
-
148
141
  super().__init__(
149
- tokenizer_object=tokenizer_object,
150
142
  do_lower_case=do_lower_case,
151
143
  bos_token=bos_token,
152
144
  eos_token=eos_token,
@@ -498,6 +498,7 @@ class MptForSequenceClassification(MptPreTrainedModel):
498
498
  output_attentions: Optional[bool] = None,
499
499
  output_hidden_states: Optional[bool] = None,
500
500
  return_dict: Optional[bool] = None,
501
+ **kwargs,
501
502
  ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
502
503
  r"""
503
504
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -700,6 +701,7 @@ class MptForQuestionAnswering(MptPreTrainedModel):
700
701
  output_attentions: Optional[bool] = None,
701
702
  output_hidden_states: Optional[bool] = None,
702
703
  return_dict: Optional[bool] = None,
704
+ **kwargs,
703
705
  ) -> Union[tuple, QuestionAnsweringModelOutput]:
704
706
  r"""
705
707
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -826,6 +826,7 @@ class MraModel(MraPreTrainedModel):
826
826
  inputs_embeds: Optional[torch.Tensor] = None,
827
827
  output_hidden_states: Optional[bool] = None,
828
828
  return_dict: Optional[bool] = None,
829
+ **kwargs,
829
830
  ) -> Union[tuple, BaseModelOutputWithCrossAttentions]:
830
831
  output_hidden_states = (
831
832
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -919,6 +920,7 @@ class MraForMaskedLM(MraPreTrainedModel):
919
920
  labels: Optional[torch.Tensor] = None,
920
921
  output_hidden_states: Optional[bool] = None,
921
922
  return_dict: Optional[bool] = None,
923
+ **kwargs,
922
924
  ) -> Union[tuple, MaskedLMOutput]:
923
925
  r"""
924
926
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1007,6 +1009,7 @@ class MraForSequenceClassification(MraPreTrainedModel):
1007
1009
  labels: Optional[torch.Tensor] = None,
1008
1010
  output_hidden_states: Optional[bool] = None,
1009
1011
  return_dict: Optional[bool] = None,
1012
+ **kwargs,
1010
1013
  ) -> Union[tuple, SequenceClassifierOutput]:
1011
1014
  r"""
1012
1015
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1086,6 +1089,7 @@ class MraForMultipleChoice(MraPreTrainedModel):
1086
1089
  labels: Optional[torch.Tensor] = None,
1087
1090
  output_hidden_states: Optional[bool] = None,
1088
1091
  return_dict: Optional[bool] = None,
1092
+ **kwargs,
1089
1093
  ) -> Union[tuple, MultipleChoiceModelOutput]:
1090
1094
  r"""
1091
1095
  input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1189,6 +1193,7 @@ class MraForTokenClassification(MraPreTrainedModel):
1189
1193
  labels: Optional[torch.Tensor] = None,
1190
1194
  output_hidden_states: Optional[bool] = None,
1191
1195
  return_dict: Optional[bool] = None,
1196
+ **kwargs,
1192
1197
  ) -> Union[tuple, TokenClassifierOutput]:
1193
1198
  r"""
1194
1199
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1263,6 +1268,7 @@ class MraForQuestionAnswering(MraPreTrainedModel):
1263
1268
  end_positions: Optional[torch.Tensor] = None,
1264
1269
  output_hidden_states: Optional[bool] = None,
1265
1270
  return_dict: Optional[bool] = None,
1271
+ **kwargs,
1266
1272
  ) -> Union[tuple, QuestionAnsweringModelOutput]:
1267
1273
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1268
1274
 
@@ -671,6 +671,7 @@ class MT5Stack(MT5PreTrainedModel):
671
671
  output_hidden_states=None,
672
672
  return_dict=None,
673
673
  cache_position=None,
674
+ **kwargs,
674
675
  ):
675
676
  use_cache = use_cache if use_cache is not None else self.config.use_cache
676
677
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -898,6 +899,7 @@ class MT5Model(MT5PreTrainedModel):
898
899
  output_hidden_states: Optional[bool] = None,
899
900
  return_dict: Optional[bool] = None,
900
901
  cache_position: Optional[torch.LongTensor] = None,
902
+ **kwargs,
901
903
  ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
902
904
  r"""
903
905
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1081,6 +1083,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin):
1081
1083
  output_hidden_states: Optional[bool] = None,
1082
1084
  return_dict: Optional[bool] = None,
1083
1085
  cache_position: Optional[torch.LongTensor] = None,
1086
+ **kwargs,
1084
1087
  ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1085
1088
  r"""
1086
1089
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1268,6 +1271,7 @@ class MT5EncoderModel(MT5PreTrainedModel):
1268
1271
  output_attentions: Optional[bool] = None,
1269
1272
  output_hidden_states: Optional[bool] = None,
1270
1273
  return_dict: Optional[bool] = None,
1274
+ **kwargs,
1271
1275
  ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
1272
1276
  r"""
1273
1277
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1340,6 +1344,7 @@ class MT5ForSequenceClassification(MT5PreTrainedModel):
1340
1344
  output_attentions: Optional[bool] = None,
1341
1345
  output_hidden_states: Optional[bool] = None,
1342
1346
  return_dict: Optional[bool] = None,
1347
+ **kwargs,
1343
1348
  ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
1344
1349
  r"""
1345
1350
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1480,6 +1485,7 @@ class MT5ForTokenClassification(MT5PreTrainedModel):
1480
1485
  output_attentions: Optional[bool] = None,
1481
1486
  output_hidden_states: Optional[bool] = None,
1482
1487
  return_dict: Optional[bool] = None,
1488
+ **kwargs,
1483
1489
  ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
1484
1490
  r"""
1485
1491
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1587,6 +1593,7 @@ class MT5ForQuestionAnswering(MT5PreTrainedModel):
1587
1593
  output_attentions: Optional[bool] = None,
1588
1594
  output_hidden_states: Optional[bool] = None,
1589
1595
  return_dict: Optional[bool] = None,
1596
+ **kwargs,
1590
1597
  ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
1591
1598
  r"""
1592
1599
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -482,6 +482,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
482
482
  output_hidden_states: Optional[bool] = None,
483
483
  return_dict: Optional[bool] = None,
484
484
  cache_position: Optional[torch.Tensor] = None,
485
+ **kwargs,
485
486
  ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
486
487
  r"""
487
488
  input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
@@ -716,6 +717,7 @@ class MusicgenModel(MusicgenPreTrainedModel):
716
717
  output_hidden_states: Optional[bool] = None,
717
718
  return_dict: Optional[bool] = None,
718
719
  cache_position: Optional[torch.Tensor] = None,
720
+ **kwargs,
719
721
  ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
720
722
  r"""
721
723
  input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
@@ -455,6 +455,7 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel):
455
455
  output_hidden_states: Optional[bool] = None,
456
456
  return_dict: Optional[bool] = None,
457
457
  cache_position: Optional[torch.Tensor] = None,
458
+ **kwargs,
458
459
  ) -> Union[tuple, BaseModelOutputWithPast]:
459
460
  r"""
460
461
  input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
@@ -670,6 +671,7 @@ class MusicgenMelodyModel(MusicgenMelodyPreTrainedModel):
670
671
  output_hidden_states: Optional[bool] = None,
671
672
  return_dict: Optional[bool] = None,
672
673
  cache_position: Optional[torch.Tensor] = None,
674
+ **kwargs,
673
675
  ) -> Union[tuple, BaseModelOutputWithPast]:
674
676
  r"""
675
677
  input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
@@ -785,6 +787,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin):
785
787
  return_dict: Optional[bool] = None,
786
788
  labels: Optional[torch.LongTensor] = None,
787
789
  cache_position: Optional[torch.Tensor] = None,
790
+ **kwargs,
788
791
  ) -> Union[tuple, MusicgenMelodyOutputWithPast]:
789
792
  r"""
790
793
  input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
@@ -534,6 +534,7 @@ class MvpEncoder(MvpPreTrainedModel):
534
534
  output_attentions: Optional[bool] = None,
535
535
  output_hidden_states: Optional[bool] = None,
536
536
  return_dict: Optional[bool] = None,
537
+ **kwargs,
537
538
  ) -> Union[tuple, BaseModelOutput]:
538
539
  r"""
539
540
  Args:
@@ -698,6 +699,7 @@ class MvpDecoder(MvpPreTrainedModel):
698
699
  output_hidden_states: Optional[bool] = None,
699
700
  return_dict: Optional[bool] = None,
700
701
  cache_position: Optional[torch.Tensor] = None,
702
+ **kwargs,
701
703
  ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
702
704
  r"""
703
705
  Args:
@@ -917,6 +919,7 @@ class MvpModel(MvpPreTrainedModel):
917
919
  output_hidden_states: Optional[bool] = None,
918
920
  return_dict: Optional[bool] = None,
919
921
  cache_position: Optional[torch.Tensor] = None,
922
+ **kwargs,
920
923
  ) -> Union[tuple, Seq2SeqModelOutput]:
921
924
  r"""
922
925
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1065,6 +1068,7 @@ class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin):
1065
1068
  output_hidden_states: Optional[bool] = None,
1066
1069
  return_dict: Optional[bool] = None,
1067
1070
  cache_position: Optional[torch.Tensor] = None,
1071
+ **kwargs,
1068
1072
  ) -> Union[tuple, Seq2SeqLMOutput]:
1069
1073
  r"""
1070
1074
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1213,6 +1217,7 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
1213
1217
  output_attentions: Optional[bool] = None,
1214
1218
  output_hidden_states: Optional[bool] = None,
1215
1219
  return_dict: Optional[bool] = None,
1220
+ **kwargs,
1216
1221
  ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
1217
1222
  r"""
1218
1223
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1372,6 +1377,7 @@ class MvpForQuestionAnswering(MvpPreTrainedModel):
1372
1377
  output_attentions: Optional[bool] = None,
1373
1378
  output_hidden_states: Optional[bool] = None,
1374
1379
  return_dict: Optional[bool] = None,
1380
+ **kwargs,
1375
1381
  ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]:
1376
1382
  r"""
1377
1383
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1548,6 +1554,7 @@ class MvpForCausalLM(MvpPreTrainedModel, GenerationMixin):
1548
1554
  return_dict: Optional[bool] = None,
1549
1555
  cache_position: Optional[torch.Tensor] = None,
1550
1556
  logits_to_keep: Union[int, torch.Tensor] = 0,
1557
+ **kwargs,
1551
1558
  ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
1552
1559
  r"""
1553
1560
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -30,7 +30,7 @@ from ... import initialization as init
30
30
  from ...activations import ACT2FN
31
31
  from ...cache_utils import Cache, DynamicCache
32
32
  from ...generation import GenerationMixin
33
- from ...integrations import use_kernel_func_from_hub
33
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
34
34
  from ...masking_utils import create_causal_mask
35
35
  from ...modeling_layers import GradientCheckpointingLayer
36
36
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
38
38
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
39
  from ...processing_utils import Unpack
40
40
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
41
- from ...utils.generic import check_model_inputs
41
+ from ...utils.generic import check_model_inputs, maybe_autocast
42
42
  from .configuration_nanochat import NanoChatConfig
43
43
 
44
44
 
@@ -113,7 +113,7 @@ class NanoChatRotaryEmbedding(nn.Module):
113
113
  position_ids_expanded = position_ids[:, None, :].float()
114
114
 
115
115
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
116
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
116
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
117
117
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
118
118
  emb = torch.cat((freqs, freqs), dim=-1)
119
119
  cos = emb.cos() * self.attention_scaling
@@ -195,6 +195,7 @@ def rotate_half(x):
195
195
  return torch.cat((x2, -x1), dim=-1)
196
196
 
197
197
 
198
+ @use_kernelized_func(apply_rotary_pos_emb)
198
199
  class NanoChatAttention(nn.Module):
199
200
  """Multi-headed attention from 'Attention Is All You Need' paper"""
200
201
 
@@ -220,7 +221,6 @@ class NanoChatAttention(nn.Module):
220
221
  self.o_proj = nn.Linear(
221
222
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
222
223
  )
223
- self.rotary_fn = apply_rotary_pos_emb
224
224
 
225
225
  self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
226
226
  self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
@@ -45,6 +45,7 @@ from ...modeling_rope_utils import (
45
45
  )
46
46
  from ...modeling_utils import PreTrainedModel
47
47
  from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
48
+ from ...utils.generic import maybe_autocast
48
49
  from .configuration_nemotron import NemotronConfig
49
50
 
50
51
 
@@ -87,7 +88,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):
87
88
  args = _cast_if_autocast_enabled(
88
89
  device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps
89
90
  )
90
- with torch.autocast(device_type=input.device.type, enabled=False):
91
+ with maybe_autocast(device_type=input.device.type, enabled=False):
91
92
  return F.layer_norm(*args)
92
93
 
93
94
 
@@ -151,7 +152,7 @@ class NemotronRotaryEmbedding(nn.Module):
151
152
  position_ids_expanded = position_ids[:, None, :].float()
152
153
 
153
154
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
154
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
155
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
155
156
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
156
157
  emb = torch.cat((freqs, freqs), dim=-1)
157
158
  cos = emb.cos() * self.attention_scaling
@@ -657,6 +658,7 @@ class NemotronModel(NemotronPreTrainedModel):
657
658
  output_attentions: Optional[bool] = None,
658
659
  output_hidden_states: Optional[bool] = None,
659
660
  cache_position: Optional[torch.LongTensor] = None,
661
+ **kwargs,
660
662
  ) -> BaseModelOutputWithPast:
661
663
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
662
664
  output_hidden_states = (