transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/models/dinov3_vit/modeling_dinov3_vit.py

@@ -36,7 +36,7 @@ from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.backbone_utils import BackboneMixin
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dinov3_vit import DINOv3ViTConfig
 
 
@@ -156,7 +156,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
         device = pixel_values.device
         device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
 
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             # Although we could precompute static patch_coords from image_size and patch_size in the config,
             # the model was trained with random_scale, so it can process images of varying sizes.
             # Therefore, it's better to compute patch_coords dynamically (with lru_cache).
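A recurring change in this release, first visible in the hunk above and repeated in many files below: call sites that forced float32 via torch.autocast(..., enabled=False) now go through a maybe_autocast helper imported from transformers/utils/generic.py. The helper itself is not shown in this diff; as a rough sketch of the idea (an assumption, not the shipped implementation), it forwards to torch.autocast where autocast is supported and degrades to a no-op context elsewhere:

    import contextlib
    import torch

    def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
        # Hypothetical fallback: device types without autocast support get a
        # no-op context instead of a hard error. The real helper may handle
        # more cases (e.g. behaviour under torch.compile).
        if device_type not in ("cuda", "cpu", "xpu"):
            return contextlib.nullcontext()
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)

    # Usage mirroring the call sites: disable autocast to force float32 math.
    with maybe_autocast(device_type="cpu", enabled=False):
        out = torch.randn(2, 2).float() @ torch.randn(2, 2).float()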
transformers/models/dinov3_vit/modular_dinov3_vit.py

@@ -40,7 +40,7 @@ from ...processing_utils import Unpack
 from ...pytorch_utils import compile_compatible_method_lru_cache
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.backbone_utils import BackboneMixin
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dinov3_vit import DINOv3ViTConfig
 
 
@@ -163,7 +163,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
         device = pixel_values.device
         device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
 
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             # Although we could precompute static patch_coords from image_size and patch_size in the config,
             # the model was trained with random_scale, so it can process images of varying sizes.
             # Therefore, it's better to compute patch_coords dynamically (with lru_cache).

transformers/models/distilbert/tokenization_distilbert.py

@@ -23,6 +23,19 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
 class DistilBertTokenizer(BertTokenizer):
     model_input_names = ["input_ids", "attention_mask"]
 
+    def __init__(self, *args, do_lower_case: bool = True, **kwargs):
+        """
+        Construct a DistilBERT tokenizer (backed by HuggingFace's tokenizers library). Based on WordPiece.
+
+        This tokenizer inherits from [`BertTokenizer`] which contains most of the main methods. Users should refer to
+        this superclass for more information regarding those methods.
+
+        Args:
+            do_lower_case (`bool`, *optional*, defaults to `True`):
+                Whether or not to lowercase the input when tokenizing.
+        """
+        super().__init__(*args, do_lower_case=do_lower_case, **kwargs)
+
 
 # DistilBertTokenizerFast is an alias for DistilBertTokenizer (since BertTokenizer is already a fast tokenizer)
 DistilBertTokenizerFast = DistilBertTokenizer
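The practical effect of the new __init__ is that do_lower_case=True becomes an explicit, forwarded default rather than an implicitly inherited one. A quick check (needs network access to fetch the vocab; "distilbert-base-uncased" is just an example checkpoint):

    from transformers import DistilBertTokenizer

    tok = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    print(tok.do_lower_case)            # True
    print(tok.tokenize("Hello World"))  # ['hello', 'world']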
transformers/models/doge/modeling_doge.py

@@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import AttentionInterface, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_doge import DogeConfig
 
 
@@ -127,7 +127,7 @@ class DogeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -297,7 +297,6 @@ class DogeAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]

transformers/models/doge/modular_doge.py

@@ -321,7 +321,6 @@ class DogeAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
transformers/models/donut/modeling_donut_swin.py

@@ -837,6 +837,7 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DonutSwinModelOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
@@ -923,6 +924,7 @@ class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, DonutSwinImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
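Many forward signatures in this release gain a trailing **kwargs (DonutSwin above; DPR, EdgeTam and EfficientNet below), presumably so shared framework code can forward newer keyword arguments to older models without raising. A toy illustration of the failure mode being avoided (all names invented):

    def strict_forward(pixel_values=None, return_dict=None):
        return pixel_values

    def permissive_forward(pixel_values=None, return_dict=None, **kwargs):
        return pixel_values  # extra kwargs are absorbed and ignored

    call_args = {"pixel_values": 1, "return_dict": True, "output_attentions": False}
    permissive_forward(**call_args)  # fine
    # strict_forward(**call_args)   # TypeError: unexpected keyword argument 'output_attentions'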
transformers/models/dots1/modeling_dots1.py

@@ -29,7 +29,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dots1 import Dots1Config
 
 
@@ -119,7 +119,7 @@ class Dots1RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -201,6 +201,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class Dots1Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -227,7 +228,6 @@ class Dots1Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
@@ -369,9 +369,11 @@ class Dots1MoE(nn.Module):
 
     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()  # main diff with deepseekv3
-        router_logits = router_logits + self.gate.e_score_correction_bias
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)
@@ -381,7 +383,7 @@ class Dots1MoE(nn.Module):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
@@ -467,6 +469,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
         "hidden_states": Dots1DecoderLayer,
         "attentions": Dots1Attention,
     }
+    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
 
     @torch.no_grad()
     def _init_weights(self, module):
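The use_kernelized_func decorator seen above (and again in Emu3 and Ernie 4.5 below) replaces the per-instance self.rotary_fn = apply_rotary_pos_emb assignment. Its implementation is not part of this diff; a minimal sketch of the apparent contract (an assumption: it attaches the default function at class level so a hub kernel can swap it out in one place) might look like:

    def use_kernelized_func(default_func):
        # Hypothetical sketch, not the transformers implementation: expose the
        # default as a class attribute so `self.rotary_fn(q, k, cos, sin)` keeps
        # working while a kernel registry can override cls.rotary_fn globally.
        def decorator(cls):
            cls.rotary_fn = staticmethod(default_func)
            return cls
        return decorator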
transformers/models/dots1/modular_dots1.py

@@ -61,9 +61,11 @@ class Dots1TopkRouter(DeepseekV3TopkRouter):
 class Dots1MoE(DeepseekV3MoE):
     def route_tokens_to_experts(self, router_logits):
         router_logits = router_logits.sigmoid()  # main diff with deepseekv3
-        router_logits = router_logits + self.gate.e_score_correction_bias
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
         group_scores = (
-            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
         )
         group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
         group_mask = torch.zeros_like(group_scores)
@@ -73,7 +75,7 @@ class Dots1MoE(DeepseekV3MoE):
             .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
         )
-        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
         topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
         topk_weights = router_logits.gather(1, topk_indices)
         if self.norm_topk_prob:
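The Dots1MoE change (in both the modeling and modular files) is a correctness fix: e_score_correction_bias should steer which experts are selected, while the weights applied to the selected experts must come from the raw sigmoid scores. Note that topk_weights is still gathered from the un-renamed router_logits. A toy tensor walk-through with invented numbers:

    import torch

    scores = torch.tensor([[0.30, 0.20, 0.90, 0.10]])  # sigmoid(router logits)
    bias = torch.tensor([0.50, 0.00, 0.00, 0.00])      # e_score_correction_bias
    scores_for_choice = scores + bias                  # biased copy, used for selection only
    topk_idx = torch.topk(scores_for_choice, k=2, dim=-1, sorted=False).indices
    topk_weights = scores.gather(1, topk_idx)          # unbiased weights, as in the fix
    # Experts 0 and 2 are selected; their weights are 0.30 and 0.90.
    # Before the fix, expert 0's weight would have been the biased 0.80.
    print(topk_idx, topk_weights)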
transformers/models/dpr/modeling_dpr.py

@@ -129,6 +129,7 @@ class DPREncoder(DPRPreTrainedModel):
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = False,
+        **kwargs,
     ) -> Union[BaseModelOutputWithPooling, tuple[Tensor, ...]]:
         outputs = self.bert_model(
             input_ids=input_ids,
@@ -181,6 +182,7 @@ class DPRSpanPredictor(DPRPreTrainedModel):
         output_attentions: bool = False,
         output_hidden_states: bool = False,
         return_dict: bool = False,
+        **kwargs,
     ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
         # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
         n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]
@@ -282,6 +284,7 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRContextEncoderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -387,6 +390,7 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRQuestionEncoderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -492,6 +496,7 @@ class DPRReader(DPRPretrainedReader):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
         r"""
         input_ids (`tuple[torch.LongTensor]` of shapes `(n_passages, sequence_length)`):
transformers/models/dpr/tokenization_dpr.py

@@ -39,6 +39,10 @@ class DPRContextEncoderTokenizer(BertTokenizer):
 
     vocab_files_names = VOCAB_FILES_NAMES
 
+    def __init__(self, *args, do_lower_case=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.do_lower_case = do_lower_case
+
 
 class DPRQuestionEncoderTokenizer(BertTokenizer):
     r"""
@@ -52,6 +56,10 @@ class DPRQuestionEncoderTokenizer(BertTokenizer):
 
     vocab_files_names = VOCAB_FILES_NAMES
 
+    def __init__(self, *args, do_lower_case=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.do_lower_case = do_lower_case
+
 
 DPRSpanPrediction = collections.namedtuple(
     "DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"]
@@ -316,5 +324,9 @@ class DPRReaderTokenizer(CustomDPRReaderTokenizerMixin, BertTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
 
+    def __init__(self, *args, do_lower_case=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.do_lower_case = do_lower_case
+
 
 __all__ = ["DPRContextEncoderTokenizer", "DPRQuestionEncoderTokenizer", "DPRReaderOutput", "DPRReaderTokenizer"]
transformers/models/edgetam/modeling_edgetam.py

@@ -393,7 +393,7 @@ class EdgeTamVisionNeck(nn.Module):
         n = len(self.convs) - 1
         for i in range(n, -1, -1):
             lateral_features = hidden_states[i].permute(0, 3, 1, 2)
-            lateral_features = self.convs[n - i](lateral_features)
+            lateral_features = self.convs[n - i](lateral_features.to(self.convs[i].weight.dtype))
             if i not in self.fpn_top_down_levels or i == n:
                 prev_features = lateral_features
             else:
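The EdgeTam cast guards against a dtype mismatch between half-precision conv weights and float32 lateral features; without the cast, the convolution raises. A minimal reproduction of the failure mode (illustrative only; bfloat16 chosen so it runs on CPU):

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(4, 4, kernel_size=1).to(torch.bfloat16)
    x = torch.randn(1, 4, 8, 8)          # float32 lateral features
    # conv(x) would raise: input type and weight type must match.
    y = conv(x.to(conv.weight.dtype))    # cast first, as in the fix
    print(y.dtype)                       # torch.bfloat16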
transformers/models/edgetam_video/modeling_edgetam_video.py

@@ -2117,6 +2117,7 @@ class EdgeTamVideoModel(EdgeTamVideoPreTrainedModel):
         frame_idx: Optional[int] = None,
         frame: Optional[torch.Tensor] = None,
         reverse: bool = False,
+        **kwargs,
     ) -> EdgeTamVideoSegmentationOutput:
         r"""
         inference_session (`EdgeTamVideoInferenceSession`):

transformers/models/edgetam_video/modular_edgetam_video.py

@@ -1256,6 +1256,7 @@ class EdgeTamVideoModel(Sam2VideoModel):
         frame_idx: Optional[int] = None,
         frame: Optional[torch.Tensor] = None,
         reverse: bool = False,
+        **kwargs,
     ) -> EdgeTamVideoSegmentationOutput:
         r"""
         inference_session (`EdgeTamVideoInferenceSession`):
transformers/models/efficientloftr/modeling_efficientloftr.py

@@ -33,7 +33,7 @@ from ...utils import (
     can_return_tuple,
     torch_int,
 )
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_efficientloftr import EfficientLoFTRConfig
 
 
@@ -147,7 +147,7 @@ class EfficientLoFTRRotaryEmbedding(nn.Module):
         embed_height = (feats_height - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
         embed_width = (feats_width - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             emb = compute_embeddings(self.inv_freq, embed_height, embed_width, self.config.hidden_size)
             sin = emb.sin()
             cos = emb.cos()
transformers/models/efficientnet/modeling_efficientnet.py

@@ -471,6 +471,7 @@ class EfficientNetModel(EfficientNetPreTrainedModel):
         pixel_values: Optional[torch.FloatTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -529,6 +530,7 @@ class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
transformers/models/emu3/modeling_emu3.py

@@ -33,7 +33,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
 
 
@@ -118,6 +118,7 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class Emu3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -143,7 +144,6 @@ class Emu3Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
 
     def forward(
         self,
@@ -1167,7 +1167,7 @@ class Emu3RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
transformers/models/eomt/image_processing_eomt.py

@@ -815,7 +815,19 @@ class EomtImageProcessor(BaseImageProcessor):
 
         segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
 
-        output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
+        if patch_offsets:
+            output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
+        else:
+            output_logits = []
+
+            for idx in range(len(segmentation_logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    segmentation_logits[idx].unsqueeze(dim=0),
+                    size=target_sizes[idx],
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                output_logits.append(resized_logits[0])
 
         preds = [logit.argmax(dim=0) for logit in output_logits]
         return preds
transformers/models/eomt/image_processing_eomt_fast.py

@@ -239,7 +239,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
         for shape, stacked_images in grouped_images.items():
             if do_resize:
                 stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
-                resized_images_grouped[shape] = stacked_images
+            resized_images_grouped[shape] = stacked_images
         images = reorder_images(resized_images_grouped, grouped_images_index)
 
         # Group images by size for batched resizing, Needed in case do_resize is False.
@@ -385,7 +385,19 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
 
         segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
 
-        output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
+        if patch_offsets:
+            output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
+        else:
+            output_logits = []
+
+            for idx in range(len(segmentation_logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    segmentation_logits[idx].unsqueeze(dim=0),
+                    size=target_sizes[idx],
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                output_logits.append(resized_logits[0])
 
         preds = [logit.argmax(dim=0) for logit in output_logits]
         return preds
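Both EOMT processors gain the same fallback: when no patch offsets are available (i.e. the inputs were not split into patches), each image's logits are simply resized to its target size instead of being merged. A standalone sketch of that path, with invented shapes:

    import torch
    import torch.nn.functional as F

    segmentation_logits = torch.randn(2, 5, 64, 64)  # (batch, classes, h, w)
    target_sizes = [(128, 96), (100, 100)]
    output_logits = [
        F.interpolate(logit.unsqueeze(0), size=size, mode="bilinear", align_corners=False)[0]
        for logit, size in zip(segmentation_logits, target_sizes)
    ]
    preds = [logit.argmax(dim=0) for logit in output_logits]
    print([tuple(p.shape) for p in preds])  # [(128, 96), (100, 100)]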
@@ -27,7 +27,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -35,7 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_ernie4_5 import Ernie4_5Config


@@ -95,7 +95,7 @@ class Ernie4_5RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
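These rotary-embedding forwards deliberately disable autocast so the frequency table is built in float32: in half precision, the `inv_freq @ position_ids` outer product loses accuracy at large positions and degrades the cos/sin tables. A worked standalone version of the computation shown in the hunk; `head_dim`, `base`, and the scaling are example values:

    import torch

    base, head_dim, seq_len = 10000.0, 64, 8
    attention_scaling = 1.0  # example value

    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    position_ids = torch.arange(seq_len)[None, :]  # (batch=1, seq)

    inv_freq_expanded = inv_freq[None, :, None].float()       # (1, dim/2, 1)
    position_ids_expanded = position_ids[:, None, :].float()  # (1, 1, seq)

    # Outer product in float32, exactly as the forward above forces it.
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (1, seq, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                              # (1, seq, dim)
    cos, sin = emb.cos() * attention_scaling, emb.sin() * attention_scaling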
@@ -203,6 +203,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed.to(original_dtype), k_embed.to(original_dtype)


+@use_kernelized_func(apply_rotary_pos_emb)
 class Ernie4_5Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -221,7 +222,6 @@ class Ernie4_5Attention(nn.Module):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
@@ -18,6 +18,7 @@ from torch import nn

 from ...modeling_rope_utils import dynamic_rope_update
 from ...utils import auto_docstring, can_return_tuple
+from ...utils.generic import maybe_autocast
 from ..glm.modeling_glm import rotate_half
 from ..llama.modeling_llama import (
     LlamaAttention,
@@ -36,7 +37,7 @@ class Ernie4_5RotaryEmbedding(OlmoRotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -29,7 +29,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -37,7 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig


@@ -135,7 +135,7 @@ class Ernie4_5_MoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -226,6 +226,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Ernie4_5_MoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -244,7 +245,6 @@ class Ernie4_5_MoeAttention(nn.Module):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
@@ -371,7 +371,7 @@ class Ernie4_5_MoeTopKRouter(nn.Module):
             else "cpu"
         )

-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = F.linear(hidden_states.float(), self.weight)
             router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
             router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
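The router keeps its logits, softmax, and top-k in float32 even when the rest of the model runs under autocast, because tiny probability gaps between experts can flip expert assignment in half precision. A standalone sketch of the pattern; the weight shape and `top_k` are example values, and `moe_statics` is omitted for simplicity:

    import torch
    import torch.nn.functional as F

    hidden_states = torch.randn(4, 1024, dtype=torch.bfloat16)  # (tokens, hidden)
    weight = torch.randn(64, 1024)  # (num_experts, hidden), example shapes
    top_k = 2

    with torch.autocast(device_type="cpu", enabled=False):  # force float32 routing
        router_logits = F.linear(hidden_states.float(), weight)
        router_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
        router_top_value, router_indices = torch.topk(router_probs, top_k, dim=-1)

    print(router_indices.shape)  # (4, 2): chosen expert ids per token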
@@ -26,7 +26,7 @@ from ...modeling_outputs import MoeModelOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half  # noqa: F401
 from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm
 from ..mixtral.modeling_mixtral import (
@@ -146,7 +146,7 @@ class Ernie4_5_MoeTopKRouter(nn.Module):
             else "cpu"
         )

-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             router_logits = F.linear(hidden_states.float(), self.weight)
             router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
             router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
@@ -32,6 +32,7 @@ from ...utils import (
     auto_docstring,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from .modeling_esm import EsmModel, EsmPreTrainedModel
 from .openfold_utils import (
     OFProtein,
@@ -267,7 +268,7 @@ class EsmFoldLayerNorm(nn.Module):
     def forward(self, x):
         d = x.dtype
         if d is torch.bfloat16 and not is_deepspeed_initialized():
-            with torch.autocast(device_type="cuda", enabled=False):
+            with maybe_autocast(device_type="cuda", enabled=False):
                 out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps)
         else:
             out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)
@@ -282,7 +283,7 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     """
     d = t.dtype
     if d is torch.bfloat16 and not is_deepspeed_initialized():
-        with torch.autocast(device_type="cuda", enabled=False):
+        with maybe_autocast(device_type="cuda", enabled=False):
             s = torch.nn.functional.softmax(t, dim=dim)
     else:
         s = torch.nn.functional.softmax(t, dim=dim)
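`softmax_no_cast` exists to keep a bfloat16 softmax in bfloat16: under autocast, softmax would otherwise be upcast to float32, roughly doubling the memory of very large attention maps. A self-contained demonstration of the dtype behavior; CPU is used only so the example runs anywhere:

    import torch

    t = torch.randn(2, 8, dtype=torch.bfloat16)

    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        # Disabling autocast inside the region pins this softmax to the
        # input's own dtype instead of letting autocast upcast it.
        with torch.autocast(device_type="cpu", enabled=False):
            s = torch.nn.functional.softmax(t, dim=-1)

    print(s.dtype)  # torch.bfloat16 - no hidden upcast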
@@ -868,7 +869,7 @@ class EsmFoldTriangleMultiplicativeUpdate(nn.Module):

         device_type = a.device.type if a.device.type != "mps" else "cpu"
         if is_fp16_enabled(device_type):
-            with torch.autocast(device_type=device_type, enabled=False):
+            with maybe_autocast(device_type=device_type, enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
             x = self._combine_projections(a, b)
@@ -1491,7 +1492,7 @@ class EsmFoldInvariantPointAttention(nn.Module):
         # [*, H, N_res, N_res]
         device_type = q.device.type if q.device.type != "mps" else "cpu"
         if is_fp16_enabled(device_type):
-            with torch.autocast(device_type=device_type, enabled=False):
+            with maybe_autocast(device_type=device_type, enabled=False):
                 a = torch.matmul(
                     permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                     permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
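The EsmFold sites all follow one guard pattern: when fp16 autocast is active, drop out of autocast and run the numerically sensitive matmul in float32 explicitly. A minimal reproduction of the pattern; the shapes are examples, and the guard is simplified relative to `is_fp16_enabled` / `permute_final_dims`:

    import torch

    q = torch.randn(2, 8, 16, 32)  # [batch, heads, N_res, C_hidden]
    k = torch.randn(2, 8, 16, 32)

    def fp16_safe_scores(q, k):
        # Mirror of the guard above: when autocast is on, leave it and
        # accumulate the QK^T scores in float32 - fp16 accumulation over
        # C_hidden is the usual overflow culprit for long sequences.
        if torch.is_autocast_enabled():
            with torch.autocast(device_type=q.device.type, enabled=False):
                return torch.matmul(q.float(), k.float().transpose(-1, -2))
        return torch.matmul(q, k.transpose(-1, -2))

    scores = fp16_safe_scores(q, k)  # [batch, heads, N_res, N_res]
    print(scores.shape, scores.dtype)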