transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0

transformers/models/granitemoeshared/modeling_granitemoeshared.py

@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring
-from ...utils.generic import can_return_tuple, check_model_inputs
+from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast
 from .configuration_granitemoeshared import GraniteMoeSharedConfig


@@ -328,6 +328,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class GraniteMoeSharedAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -353,7 +354,6 @@ class GraniteMoeSharedAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
@@ -533,7 +533,7 @@ class GraniteMoeSharedRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -785,8 +785,6 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
            # Flatten the tokens
            loss = self.loss_function(
                logits,
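
A recurring change in this release: `torch.autocast(...)` call sites are swapped for `maybe_autocast(...)`, newly exported from `transformers.utils.generic` (note `utils/generic.py` at +27 -1 in the file list). The shipped implementation lives in that file and is not reproduced in these hunks; purely as a hedged mental model, a tolerant wrapper along the following lines would motivate the rename, assuming the intent is to degrade to a no-op context on device types where autocast is unavailable:

```python
# Hypothetical sketch only, NOT the code shipped in utils/generic.py:
# a stand-in for torch.autocast that falls back to a no-op context
# manager instead of raising on unsupported device types.
from contextlib import nullcontext

import torch


def maybe_autocast(device_type: str = "cpu", **kwargs):
    try:
        return torch.autocast(device_type=device_type, **kwargs)
    except RuntimeError:
        # e.g. a backend with no autocast support at all
        return nullcontext()
```
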

transformers/models/grounding_dino/modeling_grounding_dino.py

@@ -1510,6 +1510,7 @@ class GroundingDinoEncoder(GroundingDinoPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -1664,6 +1665,7 @@ class GroundingDinoDecoder(GroundingDinoPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -2056,6 +2058,7 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
@@ -2460,6 +2463,7 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[list[dict[str, Union[torch.LongTensor, torch.FloatTensor]]]] = None,
+        **kwargs,
     ):
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
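
The four one-line GroundingDino hunks above are one instance of a pattern repeated throughout this diff (GroupViT, HGNetV2, Hiera, Hubert, IBert, Informer below): `forward` signatures gain a trailing `**kwargs,` so shared calling code can pass newer, implementation-specific keyword arguments without older models raising `TypeError`. A toy illustration of the behavioral difference, with made-up class names:

```python
import torch
from torch import nn


class StrictModel(nn.Module):
    def forward(self, pixel_values, return_dict=None):
        return pixel_values


class TolerantModel(nn.Module):
    def forward(self, pixel_values, return_dict=None, **kwargs):
        # Unknown keyword arguments are accepted (and here ignored)
        # instead of raising a TypeError at the call site.
        return pixel_values


x = torch.zeros(1, 3)
TolerantModel()(x, output_attentions=False)   # fine
# StrictModel()(x, output_attentions=False)   # TypeError: unexpected keyword
```
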

transformers/models/groupvit/modeling_groupvit.py

@@ -1045,6 +1045,7 @@ class GroupViTTextModel(GroupViTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         Examples:
@@ -1145,6 +1146,7 @@ class GroupViTVisionModel(GroupViTPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         Examples:
@@ -1297,6 +1299,7 @@ class GroupViTModel(GroupViTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_segmentation: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, GroupViTModelOutput]:
         r"""
         return_loss (`bool`, *optional*):

transformers/models/helium/modeling_helium.py

@@ -29,6 +29,7 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import (
     GenericForSequenceClassification,
@@ -40,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_helium import HeliumConfig


@@ -117,7 +118,7 @@ class HeliumRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -220,6 +221,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class HeliumAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -243,7 +245,6 @@ class HeliumAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
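
Helium shows the full shape of the rotary-kernel refactor seen across the release: the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment is deleted and the attention class is instead decorated with `@use_kernelized_func(apply_rotary_pos_emb)` from `transformers.integrations` (where `hub_kernels.py` gains 42 lines in this diff). The decorator's real body is not shown in these hunks; a minimal sketch of the pattern, assuming it simply attaches the function at class level so a hub-provided kernel can later be substituted in one place:

```python
# Hedged sketch of the decorator pattern, not the shipped hub_kernels code.
def use_kernelized_func(func):
    def decorator(cls):
        # One class-level hook replaces a per-instance assignment in
        # __init__; kernel machinery could rebind cls.rotary_fn later.
        cls.rotary_fn = staticmethod(func)
        return cls

    return decorator
```
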

transformers/models/herbert/tokenization_herbert.py

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import BPE
@@ -54,19 +54,20 @@ class HerbertTokenizer(TokenizersBackend):
             The mask token.
         sep_token (`str`, *optional*, defaults to `"</s>"`):
             The separator token.
-        vocab (`dict`, *optional*):
+        vocab (`str`, `dict` or `list`, *optional*):
             Custom vocabulary dictionary.
-        merges (`list`, *optional*):
+        merges (`str` or `list[str]`, *optional*):
             Custom merges list.
     """

     vocab_files_names = VOCAB_FILES_NAMES
-    slow_tokenizer_class = None
+    model_input_names = ["input_ids", "attention_mask"]
+    model = BPE

     def __init__(
         self,
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         cls_token: str = "<s>",
         unk_token: str = "<unk>",
         pad_token: str = "<pad>",
@@ -76,19 +77,8 @@ class HerbertTokenizer(TokenizersBackend):
         merges_file: Optional[str] = None,
         **kwargs,
     ):
-        if vocab is not None:
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {}
-
-        if merges is not None:
-            # Convert lists to tuples if necessary (happens when loading from JSON)
-            self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
-        else:
-            self._merges = []
-
+        self._vocab = vocab if vocab is not None else {str(unk_token): 0}
+        self._merges = merges or []
         self._tokenizer = Tokenizer(
             BPE(
                 vocab=self._vocab,
@@ -105,13 +95,7 @@ class HerbertTokenizer(TokenizersBackend):
         self._tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
         self._tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")

-        tokenizer_object = self._tokenizer
-
-        self.vocab_file = vocab_file
-        self.merges_file = merges_file
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             cls_token=cls_token,
             unk_token=unk_token,
             pad_token=pad_token,
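
The rewritten constructor drops the list-to-dict and list-to-tuple massaging and hands `vocab`/`merges` to the backend more directly, with `{str(unk_token): 0}` as the default vocabulary so the unknown token always has an id; the widened `Union[str, ...]` hints and the new `model = BPE` class attribute suggest the `TokenizersBackend` base now handles serialized inputs through the same parameters. A hedged usage sketch with a toy in-memory vocabulary:

```python
# Toy example with a made-up vocabulary; real checkpoints still load via
# HerbertTokenizer.from_pretrained(...). This only exercises the new defaults.
from transformers import HerbertTokenizer

vocab = {"<unk>": 0, "<s>": 1, "</s>": 2, "<pad>": 3, "<mask>": 4}
tok = HerbertTokenizer(vocab=vocab, merges=[])
print(tok.unk_token)  # "<unk>" is guaranteed an id even when vocab=None
```
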

transformers/models/hgnet_v2/modeling_hgnet_v2.py

@@ -347,7 +347,11 @@ class HGNetV2Backbone(HGNetV2PreTrainedModel, BackboneMixin):

     @auto_docstring
     def forward(
-        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+        self,
+        pixel_values: Tensor,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BackboneOutput:
         r"""
         Examples:
@@ -426,6 +430,7 @@ class HGNetV2ForImageClassification(HGNetV2PreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> ImageClassifierOutputWithNoAttention:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

transformers/models/hgnet_v2/modular_hgnet_v2.py

@@ -470,7 +470,11 @@ class HGNetV2Backbone(HGNetV2PreTrainedModel, BackboneMixin):

     @auto_docstring
     def forward(
-        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+        self,
+        pixel_values: Tensor,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BackboneOutput:
         r"""
         Examples:
@@ -549,6 +553,7 @@ class HGNetV2ForImageClassification(HGNetV2PreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> ImageClassifierOutputWithNoAttention:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

transformers/models/hiera/modeling_hiera.py

@@ -848,6 +848,7 @@ class HieraModel(HieraPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
@@ -1132,6 +1133,7 @@ class HieraForPreTraining(HieraPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, HieraForPreTrainingOutput]:
         r"""
         noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
@@ -1249,6 +1251,7 @@ class HieraForImageClassification(HieraPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, HieraForImageClassificationOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1325,6 +1328,7 @@ class HieraBackbone(HieraPreTrainedModel, BackboneMixin):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BackboneOutput:
         """
         Returns:

transformers/models/hubert/modeling_hubert.py

@@ -892,6 +892,7 @@ class HubertModel(HubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1038,6 +1039,7 @@ class HubertForCTC(HubertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
@@ -1149,6 +1151,7 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):

transformers/models/hubert/modular_hubert.py

@@ -226,6 +226,7 @@ class HubertModel(Wav2Vec2Model, HubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):

transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py

@@ -30,7 +30,7 @@ from transformers.cache_utils import Cache
 from ...activations import ACT2FN
 from ...cache_utils import DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_hunyuan_v1_dense import HunYuanDenseV1Config


@@ -153,6 +153,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class HunYuanDenseV1Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -178,7 +179,6 @@ class HunYuanDenseV1Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.query_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.key_layernorm = HunYuanDenseV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)

@@ -359,7 +359,7 @@ class HunYuanDenseV1RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py

@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_hunyuan_v1_moe import HunYuanMoEV1Config


@@ -152,6 +152,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class HunYuanMoEV1Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -177,7 +178,6 @@ class HunYuanMoEV1Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.query_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.key_layernorm = HunYuanMoEV1RMSNorm(self.head_dim, eps=config.rms_norm_eps)

@@ -452,7 +452,7 @@ class HunYuanMoEV1RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

transformers/models/ibert/modeling_ibert.py

@@ -653,6 +653,7 @@ class IBertModel(IBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple[torch.FloatTensor]]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -746,6 +747,7 @@ class IBertForMaskedLM(IBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -836,6 +838,7 @@ class IBertForSequenceClassification(IBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -916,6 +919,7 @@ class IBertForMultipleChoice(IBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[MultipleChoiceModelOutput, tuple[torch.FloatTensor]]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1018,6 +1022,7 @@ class IBertForTokenClassification(IBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[TokenClassifierOutput, tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1102,6 +1107,7 @@ class IBertForQuestionAnswering(IBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[QuestionAnsweringModelOutput, tuple[torch.FloatTensor]]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -1107,31 +1107,15 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
             bias=False,
             partially_freeze=config.freeze_lm_head,
         )
+        if config.additional_vocab_size > 0:
+            self._tied_weights_keys = {
+                "lm_head.weight": "model.embed_tokens.weight",
+                "lm_head.additional_fc.weight": "model.embed_tokens.additional_embedding.weight",
+            }

         # Initialize weights and apply final processing
         self.post_init()

-    def tie_weights(self, **kwargs):
-        """
-        Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of
-        IdeficsDecoupledLinear and IdeficsDecoupledEmbedding.
-        """
-        output_embeddings = self.get_output_embeddings()
-        input_embeddings = self.get_input_embeddings()
-
-        if getattr(self.config, "tie_word_embeddings", True):
-            output_embeddings.weight = input_embeddings.weight
-            if input_embeddings.num_additional_embeddings > 0:
-                assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
-                output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
-
-        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
-            output_embeddings.out_features = input_embeddings.num_embeddings
-        if hasattr(output_embeddings, "out_additional_features") and hasattr(
-            input_embeddings, "num_additional_embeddings"
-        ):
-            output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
-
     @can_return_tuple
     @auto_docstring
     def forward(
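The removed `tie_weights` override gives way to the declarative `_tied_weights_keys` mapping set in `__init__` above: instead of manually aliasing parameters, the model declares which output-embedding parameters share storage with which input-embedding parameters and lets the base-class tying machinery do the rest. A rough sketch of what such machinery might do (the real logic lives in `PreTrainedModel` and is not part of this diff):

```python
from torch import nn

def tie_from_mapping(model: nn.Module, mapping: dict[str, str]) -> None:
    # assumed semantics: {target_param_name: source_param_name}
    for target, source in mapping.items():
        src_param = model.get_parameter(source)
        owner_path, _, param_name = target.rpartition(".")
        # re-registering the source parameter on the target module
        # makes the two names share one tensor
        setattr(model.get_submodule(owner_path), param_name, src_param)
```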
@@ -38,6 +38,7 @@ from ...utils import (
     logging,
     torch_float,
 )
+from ...utils.generic import maybe_autocast
 from .configuration_imagegpt import ImageGPTConfig


@@ -150,7 +151,7 @@ class ImageGPTAttention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.autocast(query.device.type, enabled=False):
+        with maybe_autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
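`maybe_autocast`, imported from `...utils.generic` in the hunk above, is not defined in this diff. Judging from the call sites, it behaves like `torch.autocast` but degrades gracefully on device types where autocast is unavailable; a hypothetical stand-in:

```python
import contextlib
import torch

def maybe_autocast(device_type, enabled=True, **kwargs):
    # Hypothetical sketch, not the real implementation: fall back to a
    # no-op context on devices where torch.autocast rejects the
    # device_type (e.g. "meta").
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        return contextlib.nullcontext()
```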
@@ -879,6 +879,7 @@ class InformerEncoder(InformerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Args:
@@ -998,6 +999,7 @@ class InformerDecoder(InformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:
@@ -1296,6 +1298,7 @@ class InformerModel(InformerPreTrainedModel):
         use_cache: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[Seq2SeqTSModelOutput, tuple]:
         r"""
         past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
@@ -1573,6 +1576,7 @@ class InformerForPrediction(InformerPreTrainedModel):
         use_cache: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[Seq2SeqTSModelOutput, tuple]:
         r"""
         past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
@@ -415,6 +415,7 @@ class InformerEncoder(TimeSeriesTransformerEncoder):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Args:
@@ -208,7 +208,7 @@ class InternVLVisionPatchEmbeddings(nn.Module):
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )

-        embeddings = self.projection(pixel_values)
+        embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
         patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
         embeddings = embeddings.flatten(2).transpose(1, 2)

@@ -449,9 +449,7 @@ class InternVLVisionModel(InternVLVisionPreTrainedModel):
     @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
-        self,
-        pixel_values: torch.Tensor,
-        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None, **kwargs
     ) -> Union[tuple, InternVLVisionModelOutputWithPooling]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
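The `pixel_values.to(self.projection.weight.dtype)` cast guards against the common case where the vision tower is loaded in half precision but the image processor hands back float32 tensors. Reproducing the failure mode in isolation:

```python
import torch
from torch import nn

proj = nn.Conv2d(3, 8, kernel_size=4, stride=4).half()   # model loaded in fp16
pixel_values = torch.rand(1, 3, 32, 32)                   # processors emit fp32

# proj(pixel_values) would raise a dtype-mismatch RuntimeError on most backends
out = proj(pixel_values.to(proj.weight.dtype))            # the patched behaviour
```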
@@ -176,7 +176,7 @@ class InternVLVisionPatchEmbeddings(nn.Module):
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )

-        embeddings = self.projection(pixel_values)
+        embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
         patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
         embeddings = embeddings.flatten(2).transpose(1, 2)

@@ -406,9 +406,7 @@ class InternVLVisionModel(InternVLVisionPreTrainedModel):
     @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
-        self,
-        pixel_values: torch.Tensor,
-        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None, **kwargs
     ) -> Union[tuple, InternVLVisionModelOutputWithPooling]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
@@ -32,7 +32,7 @@ from torch import nn
 from ... import initialization as init
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -248,6 +248,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class JambaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -264,7 +265,6 @@ class JambaAttention(nn.Module):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
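Replacing the per-instance `self.rotary_fn = apply_rotary_pos_emb` with the class-level `@use_kernelized_func(apply_rotary_pos_emb)` decorator moves the choice of rotary implementation to a single hook where a hub kernel can be substituted. The decorator's implementation is not shown in this diff; one plausible shape, with the fallback behaviour as an assumption:

```python
from torch import nn

def apply_rotary_pos_emb(q, k, cos, sin):  # stand-in for the real helper
    return q, k

def use_kernelized_func(func):
    # Hypothetical sketch: attach the function at class level, where the
    # kernel-hub integration could later swap in a fused implementation.
    def decorator(cls):
        cls.rotary_fn = staticmethod(func)
        return cls
    return decorator

@use_kernelized_func(apply_rotary_pos_emb)
class JambaAttention(nn.Module):
    pass
```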
@@ -1007,6 +1007,7 @@ class JanusVQVAE(JanusPreTrainedModel):
     def forward(
         self,
         pixel_values: torch.FloatTensor,
+        **kwargs,
     ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
         batch_size = pixel_values.shape[0]
         quant, embedding_loss, indices = self.encode(pixel_values)
@@ -823,6 +823,7 @@ class JanusVQVAE(ChameleonVQVAE):
     def forward(
         self,
         pixel_values: torch.FloatTensor,
+        **kwargs,
     ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
         batch_size = pixel_values.shape[0]
         quant, embedding_loss, indices = self.encode(pixel_values)