transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/gpt_neo/modeling_gpt_neo.py

@@ -20,6 +20,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
@@ -70,11 +71,11 @@ class GPTNeoSelfAttention(nn.Module):
         # local causal self attention is a sliding window where each token can only attend to the previous
         # window_size tokens. This is implemented by updating the causal mask such that for each token
         # all other tokens are masked except the previous window_size tokens.
+        self.attention_type = attention_type
         if attention_type == "local":
             bias = torch.bitwise_xor(bias, torch.tril(bias, -config.window_size))
 
         self.register_buffer("bias", bias, persistent=False)
-        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
 
         self.attn_dropout = nn.Dropout(float(config.attention_dropout))
         self.resid_dropout = nn.Dropout(float(config.resid_dropout))
@@ -237,8 +238,8 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
                     else torch.get_autocast_gpu_dtype()
                 )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
 
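Note: the FlashAttention dtype fallback above now keys off the presence of `quantization_config` instead of the removed `_pre_quantization_dtype` attribute. A minimal sketch of the resulting selection chain, restated as a hypothetical standalone helper (illustration only; in the real code this logic sits inline in the attention forward):

    import torch

    def pick_flash_target_dtype(config, q_proj_weight: torch.Tensor) -> torch.dtype:
        # Hypothetical helper mirroring the chain above, for illustration only.
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()   # autocast setting wins when active
        if hasattr(config, "quantization_config"):  # quantized model detected
            return config.dtype                     # use the config-level compute dtype
        return q_proj_weight.dtype                  # default: follow the projection weights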
@@ -382,6 +383,17 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _can_compile_fullgraph = False  # TODO: needs a hybrid cache
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, GPTNeoSelfAttention):
+            max_positions = module.config.max_position_embeddings
+            bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view(
+                1, 1, max_positions, max_positions
+            )
+            if module.attention_type == "local":
+                bias = torch.bitwise_xor(bias, torch.tril(bias, -module.config.window_size))
+            init.copy_(module.bias, bias)
+
 
 @auto_docstring
 class GPTNeoModel(GPTNeoPreTrainedModel):
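The new `_init_weights` override rebuilds the non-persistent `bias` buffer with the same recipe as `__init__`: a lower-triangular causal mask, and for `attention_type == "local"` an XOR with that mask shifted down by `window_size`, which leaves only the last `window_size` positions visible per row. A toy-sized worked example (sizes are made up for illustration):

    import torch

    max_positions, window_size = 6, 3  # toy sizes, not real config values
    bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool))
    local = torch.bitwise_xor(bias, torch.tril(bias, -window_size))
    print(local.long())
    # tensor([[1, 0, 0, 0, 0, 0],
    #         [1, 1, 0, 0, 0, 0],
    #         [1, 1, 1, 0, 0, 0],
    #         [0, 1, 1, 1, 0, 0],
    #         [0, 0, 1, 1, 1, 0],
    #         [0, 0, 0, 1, 1, 1]])

Each query row attends to at most `window_size` past positions (itself included), which is exactly the sliding-window behavior the surrounding comment describes.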
@@ -419,6 +431,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
  ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
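The added `**kwargs,` repeats across nearly every forward signature in this section (GPT-Neo, GPT-NeoX, GPT-NeoX-Japanese and GPT-J models and task heads alike): it makes the signatures tolerant of extra keyword arguments that shared call paths now pass to every model uniformly. A minimal illustration:

    # Illustration only: without **kwargs, a uniformly-forwarded key is a TypeError.
    def forward_old(input_ids=None, labels=None):
        return input_ids, labels

    def forward_new(input_ids=None, labels=None, **kwargs):
        return input_ids, labels

    forward_new(input_ids=[1], labels=[0], cache_position=None)    # accepted
    # forward_old(input_ids=[1], labels=[0], cache_position=None)  # TypeError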
@@ -773,6 +786,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -894,6 +908,7 @@ class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, TokenClassifierOutput]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -974,6 +989,7 @@ class GPTNeoForQuestionAnswering(GPTNeoPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, QuestionAnsweringModelOutput]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -28,7 +28,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import check_model_inputs
+ from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_gpt_neox import GPTNeoXConfig
 
 
@@ -66,7 +66,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
  @staticmethod
  def compute_default_rope_parameters(
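The `original_inv_freq` change repeats for GPT-NeoX-Japanese, GPT-OSS and Granite below: a plain tensor attribute becomes a cloned, non-persistent buffer. Buffers follow `module.to(...)` device and dtype moves, and the `.clone()` decouples the pristine copy kept for dynamic-RoPE resets from later in-place updates to `inv_freq`; a bare attribute would do neither. A minimal sketch of the difference:

    import torch
    from torch import nn

    class Rope(nn.Module):
        def __init__(self, inv_freq: torch.Tensor):
            super().__init__()
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
            self.plain_copy = inv_freq  # a bare attribute is ignored by .to(...)

    rope = Rope(torch.arange(4, dtype=torch.float32))
    rope.to(torch.float64)
    print(rope.original_inv_freq.dtype)  # torch.float64 -- moved with the module
    print(rope.plain_copy.dtype)         # torch.float32 -- left behind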
@@ -107,7 +107,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
  position_ids_expanded = position_ids[:, None, :].float()
 
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = torch.cat((freqs, freqs), dim=-1)
  cos = emb.cos() * self.attention_scaling
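`maybe_autocast` replaces `torch.autocast` in every rotary-embedding forward in this section. Its real body lives in `transformers/utils/generic.py`; a plausible minimal stand-in (an assumption, not the shipped code) degrades to a no-op where autocast is unsupported for the device type:

    from contextlib import nullcontext

    import torch

    def maybe_autocast(device_type: str, enabled: bool = True):
        # Hypothetical stand-in: fall back to a null context where torch.autocast
        # rejects the device type instead of raising.
        try:
            return torch.autocast(device_type=device_type, enabled=enabled)
        except RuntimeError:
            return nullcontext()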
@@ -645,6 +645,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
  use_cache: Optional[bool] = None,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
+ **kwargs,
  ) -> SequenceClassifierOutputWithPast:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -724,6 +725,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
  use_cache: Optional[bool] = None,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
+ **kwargs,
  ) -> TokenClassifierOutput:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -783,6 +785,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel):
  end_positions: Optional[torch.LongTensor] = None,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
+ **kwargs,
  ) -> QuestionAnsweringModelOutput:
  outputs: BaseModelOutputWithPast = self.gpt_neox(
  input_ids,
@@ -518,6 +518,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
  use_cache: Optional[bool] = None,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
+ **kwargs,
  ) -> SequenceClassifierOutputWithPast:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -597,6 +598,7 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
  use_cache: Optional[bool] = None,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
+ **kwargs,
  ) -> TokenClassifierOutput:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -656,6 +658,7 @@ class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel):
  end_positions: Optional[torch.LongTensor] = None,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
+ **kwargs,
  ) -> QuestionAnsweringModelOutput:
  outputs: BaseModelOutputWithPast = self.gpt_neox(
  input_ids,
@@ -14,7 +14,7 @@
  # limitations under the License.
  """Tokenization classes for GPTNeoX."""
 
- from typing import Optional
+ from typing import Optional, Union
 
  from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
  from tokenizers.models import BPE
@@ -87,51 +87,34 @@ class GPTNeoXTokenizer(TokenizersBackend):
  Whether or not to add an `eos_token` at the end of sequences.
  trim_offsets (`bool`, *optional*, defaults to `True`):
  Whether or not the post-processing step should trim offsets to avoid including whitespaces.
- vocab (`dict`, *optional*):
- Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
- merges (`list`, *optional*):
- Custom merges list. If not provided, merges are loaded from merges_file.
+ vocab (`str` or `dict[str, int]`, *optional*):
+ Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
+ merges (`str` or `list[str]`, *optional*):
+ Custom merges list. If not provided, merges are loaded from `merges_file`.
  """
 
  vocab_files_names = VOCAB_FILES_NAMES
  model_input_names = ["input_ids", "attention_mask"]
- slow_tokenizer_class = None
+ model = BPE
 
  def __init__(
  self,
+ vocab: Optional[Union[str, dict[str, int]]] = None,
+ merges: Optional[Union[str, list[str]]] = None,
  errors: str = "replace",
  unk_token: str = "<|endoftext|>",
  bos_token: str = "<|endoftext|>",
  eos_token: str = "<|endoftext|>",
  pad_token: str = "<|padding|>",
- add_bos_token: bool = False,
- add_eos_token: bool = False,
  add_prefix_space: bool = False,
  trim_offsets: bool = True,
- vocab: Optional[dict] = None,
- merges: Optional[list] = None,
  **kwargs,
  ):
- self._add_bos_token = add_bos_token
- self._add_eos_token = add_eos_token
  self.add_prefix_space = add_prefix_space
  self.trim_offsets = trim_offsets
 
- if vocab is not None:
- self._vocab = (
- {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
- )
- else:
- self._vocab = {
- str(unk_token): 0,
- str(pad_token): 1,
- }
-
- if merges is not None:
- self._merges = merges
- else:
- self._merges = []
-
+ self._vocab = vocab if vocab is not None else {str(unk_token): 0, str(pad_token): 1}
+ self._merges = merges or []
  self._tokenizer = Tokenizer(
  BPE(
  vocab=self._vocab,
@@ -149,38 +132,16 @@
  )
  self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False, trim_offsets=True)
 
- tokenizer_object = self._tokenizer
-
  super().__init__(
- tokenizer_object=tokenizer_object,
  errors=errors,
  unk_token=unk_token,
  bos_token=bos_token,
  eos_token=eos_token,
  pad_token=pad_token,
- add_bos_token=add_bos_token,
- add_eos_token=add_eos_token,
  add_prefix_space=add_prefix_space,
  trim_offsets=trim_offsets,
  **kwargs,
  )
 
- self.update_post_processor()
-
- def _post_init(self):
- """Post-initialization to ensure tokenizer settings are applied correctly."""
- # Re-apply settings to ensure they're correct after loading from pretrained
- self._tokenizer.normalizer = normalizers.NFC()
- self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(
- add_prefix_space=self.add_prefix_space, trim_offsets=self.trim_offsets
- )
- self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False, trim_offsets=True)
-
- # Call parent to handle AddedToken properties
- super()._post_init()
-
- # Update post processor with current bos/eos settings
- self.update_post_processor()
-
 
  __all__ = ["GPTNeoXTokenizer"]
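Net effect of the two tokenizer hunks: `vocab` and `merges` move to the front of the signature (now typed to accept paths or in-memory objects), `add_bos_token`/`add_eos_token` and the `_post_init` post-processor plumbing are dropped, and the class declares its backend via `model = BPE` instead of a `slow_tokenizer_class` pointer. A hedged usage sketch of the resulting constructor (file names are hypothetical):

    from transformers import GPTNeoXTokenizer

    tok = GPTNeoXTokenizer()  # skeleton tokenizer: vocab holds only unk + pad
    tok = GPTNeoXTokenizer(vocab="vocab.json", merges="merges.txt")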
@@ -30,6 +30,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import PreTrainedModel
  from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+ from ...utils.generic import maybe_autocast
  from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig
 
 
@@ -77,7 +78,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
  @staticmethod
  def compute_default_rope_parameters(
@@ -116,7 +117,7 @@
  position_ids_expanded = position_ids[:, None, :].float()
 
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = torch.cat((freqs, freqs), dim=-1)
  cos = emb.cos() * self.attention_scaling
@@ -431,6 +432,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
  ) -> Union[tuple, BaseModelOutputWithPast]:
  r"""
  Example:
@@ -117,5 +117,22 @@ class GptOssConfig(PreTrainedConfig):
  **kwargs,
  )
 
+ def __setattr__(self, key, value):
+ """
+ Overwritten to allow checking for the proper attention implementation to be used.
+
+ Due to `set_attn_implementation` which internally assigns `_attn_implementation_internal = "..."`, simply overwriting
+ the specific attention setter is not enough. Using a property/setter for `_attn_implementation_internal` would result in
+ a recursive dependency (as `_attn_implementation` acts as a wrapper around `_attn_implementation_internal`) - hence, this
+ workaround.
+ """
+ if key in ("_attn_implementation", "_attn_implementation_internal"):
+ if value and "flash" in value and value.removeprefix("paged|") != "kernels-community/vllm-flash-attn3":
+ raise ValueError(
+ f"GPT-OSS model does not support the specified flash attention implementation: {value}. "
+ "Only `kernels-community/vllm-flash-attn3` is supported."
+ )
+ super().__setattr__(key, value)
+
 
  __all__ = ["GptOssConfig"]
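Illustrative behaviour of the guard above (assuming `GptOssConfig` is importable from the top-level `transformers` namespace, as its `__all__` entry suggests):

    from transformers import GptOssConfig

    config = GptOssConfig()
    config._attn_implementation = "kernels-community/vllm-flash-attn3"        # accepted
    config._attn_implementation = "paged|kernels-community/vllm-flash-attn3"  # accepted
    config._attn_implementation = "flash_attention_2"  # raises ValueError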
@@ -28,7 +28,7 @@ from torch.nn import functional as F
  from ... import initialization as init
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations.hub_kernels import use_kernel_forward_from_hub
+ from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  from ...modeling_layers import (
  GenericForSequenceClassification,
@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.generic import OutputRecorder, check_model_inputs
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
  from .configuration_gpt_oss import GptOssConfig
 
 
@@ -88,8 +88,8 @@ class GptOssExperts(nn.Module):
 
  Args:
  hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
- selected_experts (torch.Tensor): (batch_size * token_num, top_k)
- routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+ selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
+ routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
  Returns:
  torch.Tensor
  """
@@ -159,8 +159,8 @@ class GptOssTopKRouter(nn.Module):
 
  def forward(self, hidden_states):
  hidden_states = hidden_states.reshape(-1, self.hidden_dim)
- router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
- router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
+ router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k)
  router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
  router_scores = router_top_value
  return router_logits, router_scores, router_indices
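The comment rename (and the matching docstring fix in GptOssExperts above) pins down that the router operates on a flattened token axis, num_tokens = batch_size * seq_len, not on seq_len alone. A runnable shape walk-through:

    import torch
    import torch.nn.functional as F

    num_tokens, num_experts, top_k = 6, 8, 2
    router_logits = torch.randn(num_tokens, num_experts)  # (num_tokens, num_experts)
    router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1)
    router_scores = F.softmax(router_top_value, dim=1)    # (num_tokens, top_k)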
@@ -196,7 +196,7 @@ class GptOssRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
  @staticmethod
  def compute_default_rope_parameters(
@@ -235,7 +235,7 @@
  position_ids_expanded = position_ids[:, None, :].float()
 
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = freqs
  cos = emb.cos() * self.attention_scaling
@@ -301,12 +301,13 @@ def eager_attention_forward(
  combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
  probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
  scores = probs[..., :-1] # we drop the sink here
- attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+ attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype)
  attn_output = torch.matmul(attn_weights, value_states)
  attn_output = attn_output.transpose(1, 2).contiguous()
  return attn_output, attn_weights
 
 
+ @use_kernelized_func(apply_rotary_pos_emb)
  class GptOssAttention(nn.Module):
  """Multi-headed attention from 'Attention Is All You Need' paper"""
 
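The `.to(value_states.dtype)` cast matters because the sink-augmented softmax can leave the attention weights in a wider dtype than the values. A minimal repro of the hazard it avoids:

    import torch

    scores = torch.softmax(torch.randn(2, 4, 4, dtype=torch.float32), dim=-1)
    values = torch.randn(2, 4, 8, dtype=torch.bfloat16)
    out = torch.matmul(scores.to(values.dtype), values)  # OK
    # torch.matmul(scores, values) would raise: mismatched Float vs BFloat16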
@@ -332,7 +333,6 @@ class GptOssAttention(nn.Module):
  self.o_proj = nn.Linear(
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  )
- self.rotary_fn = apply_rotary_pos_emb
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
 
@@ -343,7 +343,6 @@
  attention_mask: Optional[torch.Tensor],
  past_key_values: Optional[Cache] = None,
  cache_position: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
  **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.Tensor, torch.Tensor]:
  input_shape = hidden_states.shape[:-1]
@@ -373,7 +372,6 @@
  dropout=0.0 if not self.training else self.attention_dropout,
  scaling=self.scaling,
  sliding_window=self.sliding_window,
- position_ids=position_ids,
  s_aux=self.sinks, # diff with Llama
  **kwargs,
  )
@@ -446,8 +444,6 @@ class GptOssPreTrainedModel(PreTrainedModel):
  "attentions": GptOssAttention,
  }
  _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
- _supports_flash_attention = False
- _supports_flex_attention = False
 
  @torch.no_grad()
  def _init_weights(self, module):
@@ -21,7 +21,7 @@ from torch.nn import functional as F
 
  from ... import initialization as init
  from ...cache_utils import Cache, DynamicCache
- from ...integrations.hub_kernels import use_kernel_forward_from_hub
+ from ...integrations import use_kernel_forward_from_hub
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  from ...modeling_outputs import (
  MoeModelOutputWithPast,
@@ -34,7 +34,7 @@ from ...utils import (
  auto_docstring,
  logging,
  )
- from ...utils.generic import OutputRecorder, check_model_inputs
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
  from ..llama.modeling_llama import (
  LlamaDecoderLayer,
  LlamaPreTrainedModel,
@@ -86,8 +86,8 @@ class GptOssExperts(nn.Module):
 
  Args:
  hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
- selected_experts (torch.Tensor): (batch_size * token_num, top_k)
- routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
+ selected_experts (torch.Tensor): (batch_size * seq_len, top_k)
+ routing_weights (torch.Tensor): (batch_size * seq_len, top_k)
  Returns:
  torch.Tensor
  """
@@ -157,8 +157,8 @@ class GptOssTopKRouter(nn.Module):
 
  def forward(self, hidden_states):
  hidden_states = hidden_states.reshape(-1, self.hidden_dim)
- router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
- router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
+ router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts)
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k)
  router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
  router_scores = router_top_value
  return router_logits, router_scores, router_indices
@@ -185,7 +185,7 @@ class GptOssRotaryEmbedding(Qwen2RotaryEmbedding):
  position_ids_expanded = position_ids[:, None, :].float()
 
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = freqs
  cos = emb.cos() * self.attention_scaling
@@ -239,7 +239,7 @@ def eager_attention_forward(
  combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
  probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
  scores = probs[..., :-1] # we drop the sink here
- attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+ attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype)
  attn_output = torch.matmul(attn_weights, value_states)
  attn_output = attn_output.transpose(1, 2).contiguous()
  return attn_output, attn_weights
@@ -269,7 +269,6 @@ class GptOssAttention(Qwen2Attention):
  attention_mask: Optional[torch.Tensor],
  past_key_values: Optional[Cache] = None,
  cache_position: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
  **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.Tensor, torch.Tensor]:
  input_shape = hidden_states.shape[:-1]
@@ -299,7 +298,6 @@
  dropout=0.0 if not self.training else self.attention_dropout,
  scaling=self.scaling,
  sliding_window=self.sliding_window,
- position_ids=position_ids,
  s_aux=self.sinks, # diff with Llama
  **kwargs,
  )
@@ -356,8 +354,6 @@ class GptOssDecoderLayer(LlamaDecoderLayer):
  class GptOssPreTrainedModel(LlamaPreTrainedModel):
  _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
  _supports_sdpa = False
- _supports_flash_attention = False
- _supports_flex_attention = False
  _can_record_outputs = {
  "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
  "hidden_states": GptOssDecoderLayer,
@@ -14,12 +14,14 @@
  # limitations under the License.
  """PyTorch GPT-J model."""
 
+ import math
  from typing import Optional, Union
 
  import torch
  from torch import nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
@@ -77,7 +79,7 @@ class GPTJAttention(nn.Module):
  def __init__(self, config, layer_idx=None):
  super().__init__()
  self.config = config
- max_positions = config.max_position_embeddings
+ self.max_positions = config.max_position_embeddings
 
  self.attn_dropout = nn.Dropout(config.attn_pdrop)
  self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -99,15 +101,17 @@
  f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
  f" `num_attention_heads`: {self.num_attention_heads})."
  )
- self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+ self.scale_attn = math.sqrt(self.head_dim)
 
  self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  self.rotary_dim = config.rotary_dim
- pos_embd_dim = self.rotary_dim or self.embed_dim
- self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
+ self.pos_embd_dim = self.rotary_dim or self.embed_dim
+ self.register_buffer(
+ "embed_positions", create_sinusoidal_positions(self.max_positions, self.pos_embd_dim), persistent=False
+ )
 
  def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
  """
@@ -334,8 +338,8 @@ class GPTJFlashAttention2(GPTJAttention):
  else torch.get_autocast_gpu_dtype()
  )
  # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
+ elif hasattr(self.config, "quantization_config"):
+ target_dtype = self.config.dtype
  else:
  target_dtype = self.q_proj.weight.dtype
 
@@ -444,6 +448,11 @@ class GPTJPreTrainedModel(PreTrainedModel):
  _supports_flash_attn = True
  _can_compile_fullgraph = True
 
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, GPTJAttention):
+ init.copy_(module.embed_positions, create_sinusoidal_positions(module.max_positions, module.pos_embd_dim))
+
 
  @auto_docstring
  class GPTJModel(GPTJPreTrainedModel):
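Because `embed_positions` is now a non-persistent buffer (see the GPTJAttention hunk above), it no longer ships in checkpoints and must be recomputed here, mirroring the GPT-Neo `_init_weights` earlier. A hedged approximation of the sinusoidal table being restored (the real `create_sinusoidal_positions` lives in the same module; this sketch only reproduces its general shape):

    import torch

    def create_sinusoidal_positions_sketch(num_pos: int, dim: int) -> torch.Tensor:
        # sin/cos halves concatenated on the last axis -> (num_pos, dim) for even dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        angles = torch.outer(torch.arange(num_pos, dtype=torch.float32), inv_freq)
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)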
@@ -482,6 +491,7 @@ class GPTJModel(GPTJPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
  ) -> Union[tuple, BaseModelOutputWithPast]:
  r"""
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
@@ -819,6 +829,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, SequenceClassifierOutputWithPast]:
  r"""
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
@@ -930,6 +941,7 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, QuestionAnsweringModelOutput]:
  r"""
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
@@ -28,7 +28,7 @@ from torch import nn
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  from ...masking_utils import create_causal_mask
  from ...modeling_layers import GradientCheckpointingLayer
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -36,7 +36,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import check_model_inputs
+ from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_granite import GraniteConfig
 
 
@@ -116,6 +116,7 @@ def eager_attention_forward(
  return attn_output, attn_weights
 
 
+ @use_kernelized_func(apply_rotary_pos_emb)
  class GraniteAttention(nn.Module):
  """Multi-headed attention from 'Attention Is All You Need' paper"""
 
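As with GptOssAttention above, the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment gives way to a class decorator. The decorator's real implementation ships in `transformers.integrations`; a hypothetical stand-in for its contract (an assumption, not the shipped code):

    def use_kernelized_func(fn):
        # Bind the rotary function at class level; the real helper presumably
        # leaves room to swap in an accelerated hub kernel instead of `fn`.
        def decorate(cls):
            cls.rotary_fn = staticmethod(fn)
            return cls
        return decorate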
@@ -141,7 +142,6 @@ class GraniteAttention(nn.Module):
  self.o_proj = nn.Linear(
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  )
- self.rotary_fn = apply_rotary_pos_emb
 
  def forward(
  self,
@@ -337,7 +337,7 @@ class GraniteRotaryEmbedding(nn.Module):
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
  self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = inv_freq
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
  @staticmethod
  def compute_default_rope_parameters(
@@ -376,7 +376,7 @@
  position_ids_expanded = position_ids[:, None, :].float()
 
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = torch.cat((freqs, freqs), dim=-1)
  cos = emb.cos() * self.attention_scaling