transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -14,12 +14,12 @@
14
14
  # limitations under the License.
15
15
  """Tokenization classes for OpenAI GPT."""
16
16
 
17
- from typing import Optional
17
+ from typing import Optional, Union
18
18
 
19
19
  from tokenizers import Tokenizer, decoders, pre_tokenizers
20
20
  from tokenizers.models import BPE
21
21
 
22
- from ...tokenization_utils_tokenizers import TokenizersBackend
22
+ from ...tokenization_utils_tokenizers import AddedToken, TokenizersBackend
23
23
  from ...utils import logging
24
24
 
25
25
 
@@ -84,45 +84,31 @@ class GPT2Tokenizer(TokenizersBackend):
84
84
  add_bos_token (`bool`, *optional*, defaults to `False`):
85
85
  Whether or not to add an initial beginning of sentence token to the input. This allows to treat the leading
86
86
  word just as any other word.
87
- vocab (`dict`, *optional*):
88
- Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
89
- merges (`list`, *optional*):
90
- Custom merges list. If not provided, merges are loaded from merges_file.
87
+ vocab (`str` or `dict[str, int]`, *optional*):
88
+ Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
89
+ merges (`str` or `list[str]`, *optional*):
90
+ Custom merges list. If not provided, merges are loaded from `merges_file`.
91
91
  """
92
92
 
93
93
  vocab_files_names = VOCAB_FILES_NAMES
94
94
  model_input_names = ["input_ids", "attention_mask"]
95
- slow_tokenizer_class = None
95
+ model = BPE
96
96
 
97
97
  def __init__(
98
98
  self,
99
- errors="replace",
100
- unk_token="<|endoftext|>",
101
- bos_token="<|endoftext|>",
102
- eos_token="<|endoftext|>",
103
- pad_token=None,
99
+ vocab: Optional[Union[str, dict[str, int]]] = None,
100
+ merges: Optional[Union[str, list[str]]] = None,
101
+ errors: str = "replace",
102
+ unk_token: Union[AddedToken, str] = "<|endoftext|>",
103
+ bos_token: Union[AddedToken, str] = "<|endoftext|>",
104
+ eos_token: Union[AddedToken, str] = "<|endoftext|>",
105
+ pad_token: Optional[Union[AddedToken, str]] = None,
104
106
  add_prefix_space=False,
105
- add_bos_token=False,
106
- vocab: Optional[dict] = None,
107
- merges: Optional[list] = None,
108
107
  **kwargs,
109
108
  ):
110
- # self.add_bos_token = add_bos_token
111
-
112
109
  self.add_prefix_space = add_prefix_space
113
-
114
- if vocab is not None:
115
- self._vocab = (
116
- {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
117
- )
118
- else:
119
- self._vocab = {}
120
-
121
- if merges is not None:
122
- self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
123
- else:
124
- self._merges = []
125
-
110
+ self._vocab = vocab if vocab is not None else {}
111
+ self._merges = merges or []
126
112
  self._tokenizer = Tokenizer(
127
113
  BPE(
128
114
  vocab=self._vocab,
@@ -133,31 +119,17 @@ class GPT2Tokenizer(TokenizersBackend):
133
119
  fuse_unk=False,
134
120
  )
135
121
  )
136
-
137
122
  self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
138
123
  self._tokenizer.decoder = decoders.ByteLevel()
139
-
140
- tokenizer_object = self._tokenizer
141
-
142
- # Set these before calling super().__init__() so the base class _post_init() can use them
143
- self._add_bos_token = add_bos_token
144
- self._add_eos_token = False
145
-
146
124
  super().__init__(
147
- tokenizer_object=tokenizer_object,
148
125
  errors=errors,
149
126
  unk_token=unk_token,
150
127
  bos_token=bos_token,
151
128
  eos_token=eos_token,
152
129
  pad_token=pad_token,
153
130
  add_prefix_space=add_prefix_space,
154
- add_bos_token=add_bos_token,
155
131
  **kwargs,
156
132
  )
157
133
 
158
- # Call _post_init for tokenizers created directly (not from_pretrained)
159
- # For from_pretrained, this will be called again after loading the tokenizer from file
160
- self._post_init()
161
-
162
134
 
163
135
  __all__ = ["GPT2Tokenizer"]
@@ -826,6 +826,7 @@ class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel):
826
826
  output_attentions: Optional[bool] = None,
827
827
  output_hidden_states: Optional[bool] = None,
828
828
  return_dict: Optional[bool] = None,
829
+ **kwargs,
829
830
  ) -> Union[tuple, TokenClassifierOutput]:
830
831
  r"""
831
832
  input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
@@ -419,6 +419,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
419
419
  output_hidden_states: Optional[bool] = None,
420
420
  return_dict: Optional[bool] = None,
421
421
  cache_position: Optional[torch.LongTensor] = None,
422
+ **kwargs,
422
423
  ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
423
424
  r"""
424
425
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -773,6 +774,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
773
774
  output_attentions: Optional[bool] = None,
774
775
  output_hidden_states: Optional[bool] = None,
775
776
  return_dict: Optional[bool] = None,
777
+ **kwargs,
776
778
  ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
777
779
  r"""
778
780
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -894,6 +896,7 @@ class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
894
896
  output_attentions: Optional[bool] = None,
895
897
  output_hidden_states: Optional[bool] = None,
896
898
  return_dict: Optional[bool] = None,
899
+ **kwargs,
897
900
  ) -> Union[tuple, TokenClassifierOutput]:
898
901
  r"""
899
902
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -974,6 +977,7 @@ class GPTNeoForQuestionAnswering(GPTNeoPreTrainedModel):
974
977
  output_attentions: Optional[bool] = None,
975
978
  output_hidden_states: Optional[bool] = None,
976
979
  return_dict: Optional[bool] = None,
980
+ **kwargs,
977
981
  ) -> Union[tuple, QuestionAnsweringModelOutput]:
978
982
  r"""
979
983
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -28,7 +28,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
28
28
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
29
29
  from ...processing_utils import Unpack
30
30
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
31
- from ...utils.generic import check_model_inputs
31
+ from ...utils.generic import check_model_inputs, maybe_autocast
32
32
  from .configuration_gpt_neox import GPTNeoXConfig
33
33
 
34
34
 
@@ -107,7 +107,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
107
107
  position_ids_expanded = position_ids[:, None, :].float()
108
108
 
109
109
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
110
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
110
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
111
111
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
112
112
  emb = torch.cat((freqs, freqs), dim=-1)
113
113
  cos = emb.cos() * self.attention_scaling
@@ -645,6 +645,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
645
645
  use_cache: Optional[bool] = None,
646
646
  output_attentions: Optional[bool] = None,
647
647
  output_hidden_states: Optional[bool] = None,
648
+ **kwargs,
648
649
  ) -> SequenceClassifierOutputWithPast:
649
650
  r"""
650
651
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -724,6 +725,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
724
725
  use_cache: Optional[bool] = None,
725
726
  output_attentions: Optional[bool] = None,
726
727
  output_hidden_states: Optional[bool] = None,
728
+ **kwargs,
727
729
  ) -> TokenClassifierOutput:
728
730
  r"""
729
731
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -783,6 +785,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel):
783
785
  end_positions: Optional[torch.LongTensor] = None,
784
786
  output_attentions: Optional[bool] = None,
785
787
  output_hidden_states: Optional[bool] = None,
788
+ **kwargs,
786
789
  ) -> QuestionAnsweringModelOutput:
787
790
  outputs: BaseModelOutputWithPast = self.gpt_neox(
788
791
  input_ids,
@@ -518,6 +518,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
518
518
  use_cache: Optional[bool] = None,
519
519
  output_attentions: Optional[bool] = None,
520
520
  output_hidden_states: Optional[bool] = None,
521
+ **kwargs,
521
522
  ) -> SequenceClassifierOutputWithPast:
522
523
  r"""
523
524
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -597,6 +598,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
597
598
  use_cache: Optional[bool] = None,
598
599
  output_attentions: Optional[bool] = None,
599
600
  output_hidden_states: Optional[bool] = None,
601
+ **kwargs,
600
602
  ) -> TokenClassifierOutput:
601
603
  r"""
602
604
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -656,6 +658,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel):
656
658
  end_positions: Optional[torch.LongTensor] = None,
657
659
  output_attentions: Optional[bool] = None,
658
660
  output_hidden_states: Optional[bool] = None,
661
+ **kwargs,
659
662
  ) -> QuestionAnsweringModelOutput:
660
663
  outputs: BaseModelOutputWithPast = self.gpt_neox(
661
664
  input_ids,
@@ -14,7 +14,7 @@
14
14
  # limitations under the License.
15
15
  """Tokenization classes for GPTNeoX."""
16
16
 
17
- from typing import Optional
17
+ from typing import Optional, Union
18
18
 
19
19
  from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
20
20
  from tokenizers.models import BPE
@@ -87,51 +87,34 @@ class GPTNeoXTokenizer(TokenizersBackend):
87
87
  Whether or not to add an `eos_token` at the end of sequences.
88
88
  trim_offsets (`bool`, *optional*, defaults to `True`):
89
89
  Whether or not the post-processing step should trim offsets to avoid including whitespaces.
90
- vocab (`dict`, *optional*):
91
- Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
92
- merges (`list`, *optional*):
93
- Custom merges list. If not provided, merges are loaded from merges_file.
90
+ vocab (`str` or `dict[str, int]`, *optional*):
91
+ Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
92
+ merges (`str` or `list[str]`, *optional*):
93
+ Custom merges list. If not provided, merges are loaded from `merges_file`.
94
94
  """
95
95
 
96
96
  vocab_files_names = VOCAB_FILES_NAMES
97
97
  model_input_names = ["input_ids", "attention_mask"]
98
- slow_tokenizer_class = None
98
+ model = BPE
99
99
 
100
100
  def __init__(
101
101
  self,
102
+ vocab: Optional[Union[str, dict[str, int]]] = None,
103
+ merges: Optional[Union[str, list[str]]] = None,
102
104
  errors: str = "replace",
103
105
  unk_token: str = "<|endoftext|>",
104
106
  bos_token: str = "<|endoftext|>",
105
107
  eos_token: str = "<|endoftext|>",
106
108
  pad_token: str = "<|padding|>",
107
- add_bos_token: bool = False,
108
- add_eos_token: bool = False,
109
109
  add_prefix_space: bool = False,
110
110
  trim_offsets: bool = True,
111
- vocab: Optional[dict] = None,
112
- merges: Optional[list] = None,
113
111
  **kwargs,
114
112
  ):
115
- self._add_bos_token = add_bos_token
116
- self._add_eos_token = add_eos_token
117
113
  self.add_prefix_space = add_prefix_space
118
114
  self.trim_offsets = trim_offsets
119
115
 
120
- if vocab is not None:
121
- self._vocab = (
122
- {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
123
- )
124
- else:
125
- self._vocab = {
126
- str(unk_token): 0,
127
- str(pad_token): 1,
128
- }
129
-
130
- if merges is not None:
131
- self._merges = merges
132
- else:
133
- self._merges = []
134
-
116
+ self._vocab = vocab if vocab is not None else {str(unk_token): 0, str(pad_token): 1}
117
+ self._merges = merges or []
135
118
  self._tokenizer = Tokenizer(
136
119
  BPE(
137
120
  vocab=self._vocab,
@@ -149,38 +132,16 @@ class GPTNeoXTokenizer(TokenizersBackend):
149
132
  )
150
133
  self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False, trim_offsets=True)
151
134
 
152
- tokenizer_object = self._tokenizer
153
-
154
135
  super().__init__(
155
- tokenizer_object=tokenizer_object,
156
136
  errors=errors,
157
137
  unk_token=unk_token,
158
138
  bos_token=bos_token,
159
139
  eos_token=eos_token,
160
140
  pad_token=pad_token,
161
- add_bos_token=add_bos_token,
162
- add_eos_token=add_eos_token,
163
141
  add_prefix_space=add_prefix_space,
164
142
  trim_offsets=trim_offsets,
165
143
  **kwargs,
166
144
  )
167
145
 
168
- self.update_post_processor()
169
-
170
- def _post_init(self):
171
- """Post-initialization to ensure tokenizer settings are applied correctly."""
172
- # Re-apply settings to ensure they're correct after loading from pretrained
173
- self._tokenizer.normalizer = normalizers.NFC()
174
- self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(
175
- add_prefix_space=self.add_prefix_space, trim_offsets=self.trim_offsets
176
- )
177
- self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False, trim_offsets=True)
178
-
179
- # Call parent to handle AddedToken properties
180
- super()._post_init()
181
-
182
- # Update post processor with current bos/eos settings
183
- self.update_post_processor()
184
-
185
146
 
186
147
  __all__ = ["GPTNeoXTokenizer"]
@@ -30,6 +30,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
30
30
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
31
31
  from ...modeling_utils import PreTrainedModel
32
32
  from ...utils import auto_docstring, is_torch_flex_attn_available, logging
33
+ from ...utils.generic import maybe_autocast
33
34
  from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig
34
35
 
35
36
 
@@ -116,7 +117,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
116
117
  position_ids_expanded = position_ids[:, None, :].float()
117
118
 
118
119
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
119
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
120
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
120
121
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
121
122
  emb = torch.cat((freqs, freqs), dim=-1)
122
123
  cos = emb.cos() * self.attention_scaling
@@ -431,6 +432,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
431
432
  output_hidden_states: Optional[bool] = None,
432
433
  return_dict: Optional[bool] = None,
433
434
  cache_position: Optional[torch.LongTensor] = None,
435
+ **kwargs,
434
436
  ) -> Union[tuple, BaseModelOutputWithPast]:
435
437
  r"""
436
438
  Example:
@@ -28,6 +28,7 @@ from torch.nn import functional as F
28
28
  from ... import initialization as init
29
29
  from ...cache_utils import Cache, DynamicCache
30
30
  from ...generation import GenerationMixin
31
+ from ...integrations import use_kernelized_func
31
32
  from ...integrations.hub_kernels import use_kernel_forward_from_hub
32
33
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
33
34
  from ...modeling_layers import (
@@ -40,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
40
41
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
41
42
  from ...processing_utils import Unpack
42
43
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
43
- from ...utils.generic import OutputRecorder, check_model_inputs
44
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
44
45
  from .configuration_gpt_oss import GptOssConfig
45
46
 
46
47
 
@@ -235,7 +236,7 @@ class GptOssRotaryEmbedding(nn.Module):
235
236
  position_ids_expanded = position_ids[:, None, :].float()
236
237
 
237
238
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
238
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
239
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
239
240
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
240
241
  emb = freqs
241
242
  cos = emb.cos() * self.attention_scaling
@@ -301,12 +302,13 @@ def eager_attention_forward(
301
302
  combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
302
303
  probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
303
304
  scores = probs[..., :-1] # we drop the sink here
304
- attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
305
+ attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype)
305
306
  attn_output = torch.matmul(attn_weights, value_states)
306
307
  attn_output = attn_output.transpose(1, 2).contiguous()
307
308
  return attn_output, attn_weights
308
309
 
309
310
 
311
+ @use_kernelized_func(apply_rotary_pos_emb)
310
312
  class GptOssAttention(nn.Module):
311
313
  """Multi-headed attention from 'Attention Is All You Need' paper"""
312
314
 
@@ -332,7 +334,6 @@ class GptOssAttention(nn.Module):
332
334
  self.o_proj = nn.Linear(
333
335
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
334
336
  )
335
- self.rotary_fn = apply_rotary_pos_emb
336
337
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
337
338
  self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
338
339
 
@@ -343,7 +344,6 @@ class GptOssAttention(nn.Module):
343
344
  attention_mask: Optional[torch.Tensor],
344
345
  past_key_values: Optional[Cache] = None,
345
346
  cache_position: Optional[torch.LongTensor] = None,
346
- position_ids: Optional[torch.LongTensor] = None,
347
347
  **kwargs: Unpack[TransformersKwargs],
348
348
  ) -> tuple[torch.Tensor, torch.Tensor]:
349
349
  input_shape = hidden_states.shape[:-1]
@@ -373,7 +373,6 @@ class GptOssAttention(nn.Module):
373
373
  dropout=0.0 if not self.training else self.attention_dropout,
374
374
  scaling=self.scaling,
375
375
  sliding_window=self.sliding_window,
376
- position_ids=position_ids,
377
376
  s_aux=self.sinks, # diff with Llama
378
377
  **kwargs,
379
378
  )
@@ -34,7 +34,7 @@ from ...utils import (
34
34
  auto_docstring,
35
35
  logging,
36
36
  )
37
- from ...utils.generic import OutputRecorder, check_model_inputs
37
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
38
38
  from ..llama.modeling_llama import (
39
39
  LlamaDecoderLayer,
40
40
  LlamaPreTrainedModel,
@@ -185,7 +185,7 @@ class GptOssRotaryEmbedding(Qwen2RotaryEmbedding):
185
185
  position_ids_expanded = position_ids[:, None, :].float()
186
186
 
187
187
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
188
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
188
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
189
189
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
190
190
  emb = freqs
191
191
  cos = emb.cos() * self.attention_scaling
@@ -239,7 +239,7 @@ def eager_attention_forward(
239
239
  combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
240
240
  probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
241
241
  scores = probs[..., :-1] # we drop the sink here
242
- attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
242
+ attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype)
243
243
  attn_output = torch.matmul(attn_weights, value_states)
244
244
  attn_output = attn_output.transpose(1, 2).contiguous()
245
245
  return attn_output, attn_weights
@@ -269,7 +269,6 @@ class GptOssAttention(Qwen2Attention):
269
269
  attention_mask: Optional[torch.Tensor],
270
270
  past_key_values: Optional[Cache] = None,
271
271
  cache_position: Optional[torch.LongTensor] = None,
272
- position_ids: Optional[torch.LongTensor] = None,
273
272
  **kwargs: Unpack[TransformersKwargs],
274
273
  ) -> tuple[torch.Tensor, torch.Tensor]:
275
274
  input_shape = hidden_states.shape[:-1]
@@ -299,7 +298,6 @@ class GptOssAttention(Qwen2Attention):
299
298
  dropout=0.0 if not self.training else self.attention_dropout,
300
299
  scaling=self.scaling,
301
300
  sliding_window=self.sliding_window,
302
- position_ids=position_ids,
303
301
  s_aux=self.sinks, # diff with Llama
304
302
  **kwargs,
305
303
  )
@@ -482,6 +482,7 @@ class GPTJModel(GPTJPreTrainedModel):
482
482
  output_hidden_states: Optional[bool] = None,
483
483
  return_dict: Optional[bool] = None,
484
484
  cache_position: Optional[torch.LongTensor] = None,
485
+ **kwargs,
485
486
  ) -> Union[tuple, BaseModelOutputWithPast]:
486
487
  r"""
487
488
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
@@ -819,6 +820,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
819
820
  output_attentions: Optional[bool] = None,
820
821
  output_hidden_states: Optional[bool] = None,
821
822
  return_dict: Optional[bool] = None,
823
+ **kwargs,
822
824
  ) -> Union[tuple, SequenceClassifierOutputWithPast]:
823
825
  r"""
824
826
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
@@ -930,6 +932,7 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
930
932
  output_attentions: Optional[bool] = None,
931
933
  output_hidden_states: Optional[bool] = None,
932
934
  return_dict: Optional[bool] = None,
935
+ **kwargs,
933
936
  ) -> Union[tuple, QuestionAnsweringModelOutput]:
934
937
  r"""
935
938
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
@@ -28,7 +28,7 @@ from torch import nn
28
28
  from ...activations import ACT2FN
29
29
  from ...cache_utils import Cache, DynamicCache
30
30
  from ...generation import GenerationMixin
31
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
31
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
32
32
  from ...masking_utils import create_causal_mask
33
33
  from ...modeling_layers import GradientCheckpointingLayer
34
34
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -36,7 +36,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
36
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
37
  from ...processing_utils import Unpack
38
38
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
39
- from ...utils.generic import check_model_inputs
39
+ from ...utils.generic import check_model_inputs, maybe_autocast
40
40
  from .configuration_granite import GraniteConfig
41
41
 
42
42
 
@@ -116,6 +116,7 @@ def eager_attention_forward(
116
116
  return attn_output, attn_weights
117
117
 
118
118
 
119
+ @use_kernelized_func(apply_rotary_pos_emb)
119
120
  class GraniteAttention(nn.Module):
120
121
  """Multi-headed attention from 'Attention Is All You Need' paper"""
121
122
 
@@ -141,7 +142,6 @@ class GraniteAttention(nn.Module):
141
142
  self.o_proj = nn.Linear(
142
143
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
143
144
  )
144
- self.rotary_fn = apply_rotary_pos_emb
145
145
 
146
146
  def forward(
147
147
  self,
@@ -376,7 +376,7 @@ class GraniteRotaryEmbedding(nn.Module):
376
376
  position_ids_expanded = position_ids[:, None, :].float()
377
377
 
378
378
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
379
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
379
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
380
380
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
381
381
  emb = torch.cat((freqs, freqs), dim=-1)
382
382
  cos = emb.cos() * self.attention_scaling
@@ -30,7 +30,7 @@ from ... import initialization as init
30
30
  from ...activations import ACT2FN
31
31
  from ...cache_utils import Cache, DynamicCache
32
32
  from ...generation import GenerationMixin
33
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
33
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
34
34
  from ...masking_utils import create_causal_mask
35
35
  from ...modeling_layers import GradientCheckpointingLayer
36
36
  from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
38
38
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
39
  from ...processing_utils import Unpack
40
40
  from ...utils import TransformersKwargs, auto_docstring
41
- from ...utils.generic import can_return_tuple, check_model_inputs
41
+ from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast
42
42
  from .configuration_granitemoe import GraniteMoeConfig
43
43
 
44
44
 
@@ -119,7 +119,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
119
119
  position_ids_expanded = position_ids[:, None, :].float()
120
120
 
121
121
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
122
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
122
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
123
123
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
124
124
  emb = torch.cat((freqs, freqs), dim=-1)
125
125
  cos = emb.cos() * self.attention_scaling
@@ -338,6 +338,7 @@ def eager_attention_forward(
338
338
  return attn_output, attn_weights
339
339
 
340
340
 
341
+ @use_kernelized_func(apply_rotary_pos_emb)
341
342
  class GraniteMoeAttention(nn.Module):
342
343
  """Multi-headed attention from 'Attention Is All You Need' paper"""
343
344
 
@@ -363,7 +364,6 @@ class GraniteMoeAttention(nn.Module):
363
364
  self.o_proj = nn.Linear(
364
365
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
365
366
  )
366
- self.rotary_fn = apply_rotary_pos_emb
367
367
 
368
368
  def forward(
369
369
  self,
@@ -714,8 +714,6 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
714
714
 
715
715
  loss = None
716
716
  if labels is not None:
717
- # Upcast to float if we need to compute the loss to avoid potential precision issues
718
- logits = logits.float()
719
717
  # Flatten the tokens
720
718
  loss = self.loss_function(
721
719
  logits,
@@ -295,8 +295,6 @@ class GraniteMoeForCausalLM(MixtralForCausalLM):
295
295
 
296
296
  loss = None
297
297
  if labels is not None:
298
- # Upcast to float if we need to compute the loss to avoid potential precision issues
299
- logits = logits.float()
300
298
  # Flatten the tokens
301
299
  loss = self.loss_function(
302
300
  logits,
@@ -31,7 +31,7 @@ from transformers.activations import ACT2FN
31
31
  from ... import initialization as init
32
32
  from ...cache_utils import Cache
33
33
  from ...generation import GenerationMixin
34
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
34
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
35
35
  from ...masking_utils import create_causal_mask
36
36
  from ...modeling_layers import GradientCheckpointingLayer
37
37
  from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -39,7 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
39
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
40
40
  from ...processing_utils import Unpack
41
41
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
42
- from ...utils.generic import check_model_inputs
42
+ from ...utils.generic import check_model_inputs, maybe_autocast
43
43
  from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
44
44
  from .configuration_granitemoehybrid import GraniteMoeHybridConfig
45
45
 
@@ -132,6 +132,7 @@ def eager_attention_forward(
132
132
  return attn_output, attn_weights
133
133
 
134
134
 
135
+ @use_kernelized_func(apply_rotary_pos_emb)
135
136
  class GraniteMoeHybridAttention(nn.Module):
136
137
  """Multi-headed attention from 'Attention Is All You Need' paper"""
137
138
 
@@ -157,7 +158,6 @@ class GraniteMoeHybridAttention(nn.Module):
157
158
  self.o_proj = nn.Linear(
158
159
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
159
160
  )
160
- self.rotary_fn = apply_rotary_pos_emb
161
161
 
162
162
  def forward(
163
163
  self,
@@ -954,7 +954,7 @@ class GraniteMoeHybridRotaryEmbedding(nn.Module):
954
954
  position_ids_expanded = position_ids[:, None, :].float()
955
955
 
956
956
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
957
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
957
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
958
958
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
959
959
  emb = torch.cat((freqs, freqs), dim=-1)
960
960
  cos = emb.cos() * self.attention_scaling
@@ -1510,8 +1510,6 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMix
1510
1510
 
1511
1511
  loss = None
1512
1512
  if labels is not None:
1513
- # Upcast to float if we need to compute the loss to avoid potential precision issues
1514
- logits = logits.float()
1515
1513
  # Flatten the tokens
1516
1514
  loss = self.loss_function(
1517
1515
  logits,