transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/granite_speech/modeling_granite_speech.py
@@ -293,6 +293,12 @@ class GraniteSpeechPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, GraniteSpeechEncoderProjector):
             init.normal_(module.query)
+        elif isinstance(module, GraniteSpeechCTCEncoder):
+            context_size = module.config.context_size
+            seq = torch.arange(context_size)
+            relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
+            attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + module.config.max_pos_emb
+            init.copy_(module.attention_dists, attention_dists)
 
 
 @auto_docstring(
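
The new GraniteSpeechCTCEncoder branch precomputes a table of clamped, shifted relative-position indices. The same arithmetic run standalone on toy values (context_size=3 and max_pos_emb=4 are illustrative, not the model defaults):

    import torch

    context_size, max_pos_emb = 3, 4  # toy values; the real ones come from the model config

    seq = torch.arange(context_size)
    relpos_dist = seq.view(-1, 1) - seq.view(1, -1)  # pairwise distances i - j
    # Clamp to [-context_size, context_size], then shift by max_pos_emb so every
    # index is non-negative and can address an embedding table.
    attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + max_pos_emb
    print(attention_dists)
    # tensor([[4, 3, 2],
    #         [5, 4, 3],
    #         [6, 5, 4]])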
@@ -322,6 +328,12 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
 
         self.post_init()
 
+    def set_decoder(self, decoder):
+        self.language_model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)
 
@@ -458,6 +470,7 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model
@@ -469,13 +482,14 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
         # If we're in cached decoding stage, input_features should be None because
         # input ids do not contain special audio token anymore Otherwise we need
         # input feature values to be passed to the model
-        if cache_position[0] == 0:
+        if is_first_iteration or not kwargs.get("use_cache", True):
             model_inputs["input_features"] = input_features
         return model_inputs
 
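The replaced `cache_position[0] == 0` test inferred the prefill step from cache state; the new code takes an explicit `is_first_iteration` flag from the generation loop and also covers cache-free generation, where every step must re-receive the audio features. A minimal sketch of the gating pattern (hypothetical helper, not part of the library):

    def gate_audio_features(input_features, is_first_iteration: bool, use_cache: bool = True):
        # Hypothetical illustration of the new condition: forward the features on
        # the prefill step, or on every step when no KV cache carries them forward.
        if is_first_iteration or not use_cache:
            return input_features
        return None  # cached decoding: the input ids no longer contain audio tokens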
transformers/models/granitemoe/modeling_granitemoe.py
@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
38
38
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
39
  from ...processing_utils import Unpack
40
40
  from ...utils import TransformersKwargs, auto_docstring
41
- from ...utils.generic import can_return_tuple, check_model_inputs
41
+ from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast
42
42
  from .configuration_granitemoe import GraniteMoeConfig
43
43
 
44
44
 
@@ -80,7 +80,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
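Replacing the bare attribute with a non-persistent buffer has two effects worth noting: the tensor now follows `module.to(...)` device/dtype moves, and `.clone()` decouples the saved original from later in-place updates to `inv_freq` (as dynamic RoPE scaling performs). A small illustration:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.ones(4)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Buffer: moves with .to(...)/.cuda(); persistent=False keeps it out of the state dict.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = Demo()
m.inv_freq.mul_(2)                              # in-place rescale, e.g. dynamic RoPE update
print(m.original_inv_freq)                      # unaffected, thanks to .clone()
print("original_inv_freq" in m.state_dict())    # False: not serialized
```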
@@ -119,7 +119,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
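`maybe_autocast` (imported from `transformers.utils.generic` above) presumably degrades to a no-op wherever `torch.autocast` cannot be constructed for the given device type. A hypothetical re-implementation just to pin down the semantics; the real helper may well differ:

```python
import contextlib
import torch

def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
    """Return torch.autocast when usable for `device_type`, else a no-op context.

    Hypothetical sketch for illustration only; the actual helper lives in
    transformers.utils.generic.
    """
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except (RuntimeError, ValueError):  # unsupported autocast device type
        return contextlib.nullcontext()

with maybe_autocast(device_type="cpu", enabled=False):
    x = torch.randn(2, 2).float() @ torch.randn(2, 2).float()  # forced float32 matmul
```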
@@ -338,6 +338,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteMoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
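`use_kernelized_func` replaces the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment that a later hunk deletes. A hypothetical sketch of the decorator pattern (the real `transformers.integrations.use_kernelized_func` may resolve an optimized hub kernel instead of always using the fallback):

```python
# Hypothetical re-implementation for illustration; names and behaviour are assumptions.
def use_kernelized_func(fallback_fn, attr_name="rotary_fn"):
    def decorator(cls):
        # In the real library, a faster hub kernel could be looked up here and
        # substituted for `fallback_fn` when one is available.
        setattr(cls, attr_name, staticmethod(fallback_fn))
        return cls
    return decorator

def apply_rotary_pos_emb(q, k, cos, sin):  # stand-in fallback
    return q, k

@use_kernelized_func(apply_rotary_pos_emb)
class Attention:
    pass

assert Attention().rotary_fn is apply_rotary_pos_emb  # instances see the class-level function
```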
@@ -363,7 +364,6 @@ class GraniteMoeAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
 
     def forward(
         self,
@@ -456,8 +456,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": GraniteMoeDecoderLayer,
@@ -714,8 +713,6 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
 
         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
             # Flatten the tokens
             loss = self.loss_function(
                 logits,
@@ -146,8 +146,7 @@ class GraniteMoePreTrainedModel(LlamaPreTrainedModel, PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
 
     @torch.no_grad()
     def _init_weights(self, module):
@@ -295,8 +294,6 @@ class GraniteMoeForCausalLM(MixtralForCausalLM):
 
         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
             # Flatten the tokens
             loss = self.loss_function(
                 logits,
@@ -92,6 +92,8 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
             allow the model to output the auxiliary loss.
         router_aux_loss_coef (`float`, *optional*, defaults to 0.001): router auxiliary loss coefficient
         shared_intermediate_size (`int`, *optional*, defaults to 1024): intermediate size for shared experts.
+        position_embedding_type (`str`, *optional*):
+            Positional embedding type to be used; defaults to None. Allowed options: `[None, "rope"]`
         layer_types (`List`, *optional*): list of strings to be used as layer types.
             Allowed choices: "mamba", "attention".
         mamba_n_heads (`int`, *optional*, defaults to 128):
@@ -159,6 +161,7 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
         output_router_logits: Optional[bool] = False,
         router_aux_loss_coef: Optional[float] = 0.001,
         shared_intermediate_size: Optional[int] = 1024,
+        position_embedding_type: Optional[str] = None,
         layer_types: Optional[list[str]] = None,
         mamba_n_heads: Optional[int] = 128,
         mamba_n_groups: Optional[int] = 1,
@@ -198,6 +201,7 @@ class GraniteMoeHybridConfig(PreTrainedConfig):
         self.output_router_logits = output_router_logits
         self.router_aux_loss_coef = router_aux_loss_coef
         self.shared_intermediate_size = shared_intermediate_size
+        self.position_embedding_type = position_embedding_type
         self.rope_parameters = rope_parameters
 
         mamba_intermediate = mamba_expand * hidden_size
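Threaded through docstring, signature, and attribute, the new field is used like any other config option. A brief usage sketch (assuming the config stays exported at the package top level):

```python
from transformers import GraniteMoeHybridConfig

# Default: no shared positional-embedding module is built (the model's rotary_emb is None).
cfg = GraniteMoeHybridConfig()
assert cfg.position_embedding_type is None

# Opting in to RoPE makes GraniteMoeHybridModel instantiate GraniteMoeHybridRotaryEmbedding.
cfg_rope = GraniteMoeHybridConfig(position_embedding_type="rope")
```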
@@ -31,7 +31,12 @@ from transformers.activations import ACT2FN
 from ... import initialization as init
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import (
+    lazy_load_kernel,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -39,23 +44,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig
 
 
-if is_mamba_2_ssm_available():
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
-    selective_state_update = None
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
-
 logger = logging.get_logger(__name__)
 
 
@@ -132,6 +124,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteMoeHybridAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -157,7 +150,6 @@ class GraniteMoeHybridAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
 
     def forward(
         self,
@@ -165,6 +157,7 @@ class GraniteMoeHybridAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # None or rope embeddings
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -174,6 +167,10 @@ class GraniteMoeHybridAttention(nn.Module):
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
         if past_key_values is not None:
             cache_kwargs = {"cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -371,9 +368,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
     return hidden_states
 
 
-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
 # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
 class GraniteMoeHybridMambaLayer(nn.Module):
     """
@@ -445,6 +439,20 @@ class GraniteMoeHybridMambaLayer(nn.Module):
 
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
 
+        global causal_conv1d_update, causal_conv1d_fn
+        causal_conv1d = lazy_load_kernel("causal-conv1d")
+        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
+        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
+
+        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+        mamba_ssm = lazy_load_kernel("mamba-ssm")
+        selective_state_update = getattr(mamba_ssm, "selective_state_update", None)
+        mamba_chunk_scan_combined = getattr(mamba_ssm, "mamba_chunk_scan_combined", None)
+        mamba_split_conv1d_scan_combined = getattr(mamba_ssm, "mamba_split_conv1d_scan_combined", None)
+
+        global is_fast_path_available
+        is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
@@ -915,7 +923,7 @@ class GraniteMoeHybridRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -954,7 +962,7 @@ class GraniteMoeHybridRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -1231,8 +1239,7 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": GraniteMoeHybridDecoderLayer,
@@ -1265,7 +1272,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
             [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config=config)
+        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None
         self.gradient_checkpointing = False
         self.embedding_multiplier = config.embedding_multiplier
 
@@ -1313,7 +1320,9 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
 
         # embed positions
         hidden_states = inputs_embeds
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        position_embeddings = None
+        if self.rotary_emb is not None:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         for decoder_layer in self.layers:
             # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
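The dispatch pattern in miniature: `(cos, sin)` is computed once per forward, and only when RoPE is configured; layers that do not use it (e.g. Mamba blocks) simply ignore the keyword. A toy sketch with stub layers (names are illustrative, not the library's):

```python
from typing import Optional

import torch

def run_layers(hidden_states, rotary_emb, position_ids, layers):
    # position_embeddings stays None unless a rotary module was built
    position_embeddings: Optional[tuple] = None
    if rotary_emb is not None:
        position_embeddings = rotary_emb(hidden_states, position_ids)
    for layer in layers:
        hidden_states = layer(hidden_states, position_embeddings=position_embeddings)
    return hidden_states

layers = [lambda h, position_embeddings=None: h + 1 for _ in range(2)]
out = run_layers(torch.zeros(1, 3, 4), None, None, layers)  # rotary_emb disabled
assert out.sum() == 24  # 12 elements, each incremented twice
```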
@@ -1510,8 +1519,6 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMix
 
         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
             # Flatten the tokens
             loss = self.loss_function(
                 logits,
@@ -1549,6 +1556,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMix
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -1581,7 +1589,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMix
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and empty_past_kv:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
@@ -39,6 +39,7 @@ from ..granitemoeshared.modeling_granitemoeshared import (
     GraniteMoeSharedModel,
     GraniteMoeSharedMoE,
     GraniteMoeSharedPreTrainedModel,
+    apply_rotary_pos_emb,
     eager_attention_forward,
 )
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig
@@ -57,6 +58,7 @@ class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # None or rope embeddings
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -66,6 +68,10 @@ class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
         if past_key_values is not None:
             cache_kwargs = {"cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -203,6 +209,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
             [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.embedding_multiplier = config.embedding_multiplier
+        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None
 
     @auto_docstring
     @check_model_inputs
@@ -245,7 +252,9 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
 
         # embed positions
         hidden_states = inputs_embeds
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        position_embeddings = None
+        if self.rotary_emb is not None:
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         for decoder_layer in self.layers:
             # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
@@ -300,6 +309,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
@@ -332,7 +342,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and empty_past_kv:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring
-from ...utils.generic import can_return_tuple, check_model_inputs
+from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast
 from .configuration_granitemoeshared import GraniteMoeSharedConfig
 
 
@@ -328,6 +328,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteMoeSharedAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -353,7 +354,6 @@ class GraniteMoeSharedAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
 
     def forward(
         self,
@@ -462,8 +462,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
     _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": GraniteMoeSharedDecoderLayer,
@@ -494,7 +493,7 @@ class GraniteMoeSharedRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -533,7 +532,7 @@ class GraniteMoeSharedRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -785,8 +784,6 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
 
         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
             # Flatten the tokens
             loss = self.loss_function(
                 logits,
@@ -34,7 +34,7 @@ class GroundingDinoConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`PreTrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `SwinConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -285,9 +285,8 @@ class GroundingDinoConfig(PreTrainedConfig):
         self.positional_embedding_temperature = positional_embedding_temperature
         self.init_std = init_std
         self.layer_norm_eps = layer_norm_eps
+
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
-        self.tie_encoder_decoder = True
 
 
 __all__ = ["GroundingDinoConfig"]
@@ -1415,7 +1415,7 @@ class GroundingDinoPreTrainedModel(PreTrainedModel):
         elif isinstance(module, GroundingDinoFusionLayer):
             init.constant_(module.vision_param, 1e-4)
             init.constant_(module.text_param, 1e-4)
-        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
@@ -1510,7 +1510,8 @@ class GroundingDinoEncoder(GroundingDinoPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-    ):
+        **kwargs,
+    ) -> Union[tuple, GroundingDinoEncoderOutput]:
         r"""
         Args:
             vision_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -1664,7 +1665,8 @@ class GroundingDinoDecoder(GroundingDinoPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-    ):
+        **kwargs,
+    ) -> Union[tuple, GroundingDinoDecoderOutput]:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
@@ -2056,7 +2058,8 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-    ):
+        **kwargs,
+    ) -> Union[tuple, GroundingDinoModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -2460,6 +2463,7 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[list[dict[str, Union[torch.LongTensor, torch.FloatTensor]]]] = None,
+        **kwargs,
     ):
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
@@ -758,14 +758,19 @@ class GroupViTPreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=init_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
-        elif isinstance(module, nn.LayerNorm):
+        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
 
         factor = self.config.initializer_factor
         if isinstance(module, GroupViTTextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, GroupViTAttention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
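Because BatchNorm running statistics are buffers rather than parameters, a weight-only init pass leaves them untouched; the `getattr` guard keeps the widened branch safe for `nn.LayerNorm`, which has no running stats. A small demonstration, with plain `torch.nn.init` standing in for transformers' `init` module:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn.running_mean.fill_(5.0)        # pretend stale running statistics
bn.num_batches_tracked.fill_(10)

# Mirror of the new branch (torch.nn.init in place of transformers' init.*):
if getattr(bn, "running_mean", None) is not None:
    nn.init.zeros_(bn.running_mean)
    nn.init.ones_(bn.running_var)
    nn.init.zeros_(bn.num_batches_tracked)

ln = nn.LayerNorm(4)
print(getattr(ln, "running_mean", None))  # None -> the guard skips LayerNorm
```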
@@ -1045,6 +1050,7 @@ class GroupViTTextModel(GroupViTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         Examples:
@@ -1145,6 +1151,7 @@ class GroupViTVisionModel(GroupViTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         Examples:
@@ -1297,6 +1304,7 @@ class GroupViTModel(GroupViTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_segmentation: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, GroupViTModelOutput]:
         r"""
         return_loss (`bool`, *optional*):
@@ -29,6 +29,7 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import (
     GenericForSequenceClassification,
@@ -40,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_helium import HeliumConfig
 
 
@@ -78,7 +79,7 @@ class HeliumRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -117,7 +118,7 @@ class HeliumRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -220,6 +221,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class HeliumAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -243,7 +245,6 @@ class HeliumAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb
 
     def forward(
         self,