transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -216,7 +216,7 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )
 
-        embeddings = self.projection(pixel_values)
+        embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
         patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
         embeddings = embeddings.flatten(2).transpose(1, 2)
 
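The patch-embedding change above casts `pixel_values` to the dtype of the projection weights before the convolution runs. A minimal, self-contained sketch in plain PyTorch (not transformers code) of the mismatch this avoids when inputs and weights end up in different precisions:

```python
import torch
import torch.nn as nn

proj = nn.Conv2d(3, 8, kernel_size=4, stride=4)              # weights are float32 by default
pixel_values = torch.randn(1, 3, 32, 32, dtype=torch.float16)

try:
    proj(pixel_values)                                        # half input vs. float32 weights
except RuntimeError as err:
    print("dtype mismatch:", err)

out = proj(pixel_values.to(proj.weight.dtype))                # same cast as the hunk above
print(out.dtype)                                              # torch.float32
```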
@@ -741,6 +741,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
@@ -828,6 +829,7 @@ class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1173,6 +1175,7 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SemanticSegmenterOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
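The three hunks above, and many similar ones throughout this release, add a trailing `**kwargs` to `forward` signatures. A standalone illustration (toy functions, not transformers code) of what that changes: keyword arguments the signature does not name are absorbed instead of raising `TypeError`, so generic wrappers can forward extra arguments safely:

```python
def forward_old(input_ids, attention_mask=None, return_dict=None):
    return input_ids

def forward_new(input_ids, attention_mask=None, return_dict=None, **kwargs):
    return input_ids

forward_new([1, 2, 3], output_attentions=False)      # extra kwarg absorbed by **kwargs
try:
    forward_old([1, 2, 3], output_attentions=False)  # raises: unexpected keyword argument
except TypeError as err:
    print(err)
```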
@@ -37,7 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dbrx import DbrxConfig
 
 
@@ -97,7 +97,7 @@ class DbrxRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
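Several rotary-embedding and attention modules in this release replace `torch.autocast(..., enabled=False)` with a `maybe_autocast` helper imported from `transformers/utils/generic.py`. The helper's implementation is not part of this excerpt; the sketch below is only a guess at the general idea (same call signature, falling back to a no-op context when autocast is unavailable for the device), not the actual code:

```python
import contextlib
import torch

def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
    # Hypothetical sketch: degrade to a no-op context manager when the backend
    # rejects autocast for this device type, otherwise behave like torch.autocast.
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        return contextlib.nullcontext()

# Usage mirroring the rotary-embedding hunks: keep float32 math inside the block.
with maybe_autocast(device_type="cpu", enabled=False):
    freqs = torch.randn(2, 4).float() @ torch.randn(4, 3).float()
print(freqs.dtype)  # torch.float32
```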
@@ -655,6 +655,7 @@ class DebertaModel(DebertaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -860,6 +861,7 @@ class DebertaForMaskedLM(DebertaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -969,6 +971,7 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1063,6 +1066,7 @@ class DebertaForTokenClassification(DebertaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1125,6 +1129,7 @@ class DebertaForQuestionAnswering(DebertaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -14,6 +14,8 @@
 # limitations under the License.
 """Fast Tokenization class for model DeBERTa."""
 
+from typing import Optional, Union
+
 from tokenizers import AddedToken, Tokenizer, decoders, pre_tokenizers, processors
 from tokenizers.models import BPE
 
@@ -93,12 +95,12 @@ class DebertaTokenizer(TokenizersBackend):
 
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+    model = BPE
 
     def __init__(
         self,
-        vocab_file=None,
-        vocab=None,
-        merges=None,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         errors="replace",
         bos_token="[CLS]",
         eos_token="[SEP]",
@@ -110,26 +112,21 @@ class DebertaTokenizer(TokenizersBackend):
         add_prefix_space=False,
         **kwargs,
     ):
-        self.vocab_file = vocab_file
         self.add_prefix_space = add_prefix_space
 
-        if vocab is not None:
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {
+        self._vocab = (
+            vocab
+            if vocab is not None
+            else {
                 str(unk_token): 0,
                 str(cls_token): 1,
                 str(sep_token): 2,
                 str(pad_token): 3,
                 str(mask_token): 4,
             }
+        )
 
-        if merges is not None and isinstance(merges, list) and len(merges) > 0:
-            self._merges = [tuple(m) if isinstance(m, list) else m for m in merges]
-        else:
-            self._merges = []
+        self._merges = merges or []
 
         self._tokenizer = Tokenizer(
             BPE(
@@ -148,10 +145,7 @@ class DebertaTokenizer(TokenizersBackend):
         self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
         self._tokenizer.decoder = decoders.ByteLevel()
 
-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             errors=errors,
             bos_token=bos_token,
             eos_token=eos_token,
@@ -163,7 +157,6 @@ class DebertaTokenizer(TokenizersBackend):
             add_prefix_space=add_prefix_space,
             **kwargs,
         )
-
         self._tokenizer.post_processor = processors.TemplateProcessing(
             single=f"{self.cls_token} $A {self.sep_token}",
             pair=f"{self.cls_token} $A {self.sep_token} {self.sep_token} $B {self.sep_token}",
@@ -173,8 +166,6 @@ class DebertaTokenizer(TokenizersBackend):
             ],
         )
 
-        self._post_init()
-
     @property
     def mask_token(self) -> str:
         """
@@ -732,6 +732,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -936,6 +937,7 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1047,6 +1049,7 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1142,6 +1145,7 @@ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1205,6 +1209,7 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -1293,6 +1298,7 @@ class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MultipleChoiceModelOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -14,6 +14,8 @@
 # limitations under the License.
 """Tokenization class for model DeBERTa-v2."""
 
+from typing import Optional, Union
+
 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import Unigram
 
@@ -26,13 +28,6 @@ logger = logging.get_logger(__name__)
 VOCAB_FILES_NAMES = {"vocab_file": "spm.model", "tokenizer_file": "tokenizer.json"}
 
 
-def _get_prepend_scheme(add_prefix_space: bool) -> str:
-    if add_prefix_space:
-        return "always"
-    else:
-        return "first"
-
-
 class DebertaV2Tokenizer(TokenizersBackend):
     """
     Construct a DeBERTa-v2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on Unigram tokenization.
@@ -43,7 +38,7 @@ class DebertaV2Tokenizer(TokenizersBackend):
     Args:
         vocab_file (`str`, *optional*):
             Path to the vocabulary file (SentencePiece model file). Not used directly but kept for compatibility.
-        vocab (`list`, *optional*):
+        vocab (`str`, `dict` or `list`, *optional*):
            List of tuples (piece, score) for the vocabulary.
         precompiled_charsmap (`bytes`, *optional*):
            Precompiled character map for normalization.
@@ -79,11 +74,11 @@ class DebertaV2Tokenizer(TokenizersBackend):
 
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+    model = Unigram
 
     def __init__(
         self,
-        vocab_file=None,
-        vocab=None,
+        vocab: Optional[Union[str, dict, list]] = None,
         do_lower_case=False,
         split_by_punct=False,
         bos_token="[CLS]",
@@ -94,16 +89,15 @@ class DebertaV2Tokenizer(TokenizersBackend):
         cls_token="[CLS]",
         mask_token="[MASK]",
         add_prefix_space=True,
-        unk_id=3,
+        unk_id=1,
         **kwargs,
     ):
-        self.vocab_file = vocab_file
         self.do_lower_case = do_lower_case
         self.split_by_punct = split_by_punct
         self.add_prefix_space = add_prefix_space
 
         if vocab is None:
-            self._vocab = [
+            vocab = [
                 (str(pad_token), 0.0),
                 (str(unk_token), 0.0),
                 (str(bos_token), 0.0),
@@ -112,12 +106,11 @@ class DebertaV2Tokenizer(TokenizersBackend):
                 (str(cls_token), 0.0),
                 (str(mask_token), 0.0),
             ]
+            unk_id = 1
+        elif isinstance(vocab, list):
+            unk_id = vocab.index((str(unk_token), 0.0)) if (str(unk_token), 0.0) in vocab else unk_id
 
-        else:
-            self._vocab = [tuple(item) if not isinstance(item, tuple) else item for item in vocab]
-            computed_unk_id = {piece: i for i, (piece, _score) in enumerate(self._vocab)}
-            unk_id = computed_unk_id.get(str(unk_token))
-
+        self._vocab = vocab
         self._tokenizer = Tokenizer(
             Unigram(
                 self._vocab,
@@ -132,10 +125,7 @@ class DebertaV2Tokenizer(TokenizersBackend):
 
         list_normalizers.extend(
             [
-                normalizers.Replace("\n", " "),
-                normalizers.Replace("\r", " "),
-                normalizers.Replace("\t", " "),
-                normalizers.Replace(Regex(r" {2,}"), " "),
+                normalizers.Replace(Regex(r"\s{2,}|[\n\r\t]"), " "),
                 normalizers.NFC(),
                 normalizers.Strip(left=False, right=True),
             ]
@@ -146,17 +136,12 @@ class DebertaV2Tokenizer(TokenizersBackend):
         if split_by_punct:
             list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
 
-        prepend_scheme = _get_prepend_scheme(add_prefix_space)
+        prepend_scheme = "always" if add_prefix_space else "first"
         list_pretokenizers.append(pre_tokenizers.Metaspace(replacement="▁", prepend_scheme=prepend_scheme))
 
         self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(list_pretokenizers)
-
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme=prepend_scheme)
-
-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             bos_token=bos_token,
             eos_token=eos_token,
             unk_token=unk_token,
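One detail of the `DebertaV2Tokenizer` rewrite is that four chained `Replace` normalizers collapse into a single regex. A quick toy check with the `tokenizers` normalizers API that the consolidated pattern squashes tabs, newlines, and runs of whitespace the same way the old chain did:

```python
from tokenizers import Regex, normalizers

# Old chain: one Replace per character class, plus a final multi-space squash.
old = normalizers.Sequence(
    [
        normalizers.Replace("\n", " "),
        normalizers.Replace("\r", " "),
        normalizers.Replace("\t", " "),
        normalizers.Replace(Regex(r" {2,}"), " "),
    ]
)
# New: a single regex covering whitespace runs and individual \n, \r, \t.
new = normalizers.Replace(Regex(r"\s{2,}|[\n\r\t]"), " ")

text = "a\tb\r\nc   d"
print(old.normalize_str(text))  # "a b c d"
print(new.normalize_str(text))  # "a b c d"
```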
@@ -34,6 +34,7 @@ from ...utils import (
     auto_docstring,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from .configuration_decision_transformer import DecisionTransformerConfig
 
 
@@ -141,7 +142,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)
 
         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.autocast(query.device.type, enabled=False):
+        with maybe_autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
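The attention hunk above keeps the GPT-2-style float32 upcast: scores are computed with `torch.baddbmm` on `.float()` inputs while autocast is disabled, so half-precision storage does not degrade the softmax inputs. A standalone sketch of just that upcast step (shapes and scaling chosen arbitrarily):

```python
import torch

bsz_heads, q_len, k_len, dk = 2, 3, 3, 4
query = torch.randn(bsz_heads, q_len, dk, dtype=torch.float16)
key_t = torch.randn(bsz_heads, dk, k_len, dtype=torch.float16)

# beta=0 ignores the (uninitialized) input buffer; alpha applies the 1/sqrt(dk) scaling.
attn_weights = torch.empty(bsz_heads, q_len, k_len, dtype=torch.float32)
attn_weights = torch.baddbmm(attn_weights, query.float(), key_t.float(), beta=0, alpha=1.0 / dk**0.5)
print(attn_weights.dtype)  # torch.float32
```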
@@ -431,6 +432,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -656,6 +658,7 @@ class DecisionTransformerModel(DecisionTransformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], DecisionTransformerOutput]:
         r"""
         states (`torch.FloatTensor` of shape `(batch_size, episode_length, state_dim)`):
@@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_deepseek_v2 import DeepseekV2Config
 
 
@@ -223,7 +223,7 @@ class DeepseekV2RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
             freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Convert to complex representation
             freqs_cis = freqs_cis * self.attention_scaling
@@ -342,7 +342,6 @@ class DeepseekV2Attention(nn.Module):
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        position_ids: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         batch_size, seq_length = hidden_states.shape[:-1]
@@ -25,6 +25,7 @@ from ...cache_utils import Cache
 from ...modeling_rope_utils import RopeParameters, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...utils import logging
+from ...utils.generic import maybe_autocast
 from ..llama.configuration_llama import LlamaConfig
 from ..llama.modeling_llama import (
     LlamaDecoderLayer,
@@ -303,7 +304,7 @@ class DeepseekV2RotaryEmbedding(LlamaRotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
             freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Convert to complex representation
             freqs_cis = freqs_cis * self.attention_scaling
@@ -368,7 +369,6 @@ class DeepseekV2Attention(nn.Module):
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        position_ids: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         batch_size, seq_length = hidden_states.shape[:-1]
@@ -29,7 +29,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_deepseek_v3 import DeepseekV3Config
 
 
@@ -110,7 +110,7 @@ class DeepseekV3RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -548,6 +548,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
         "hidden_states": DeepseekV3DecoderLayer,
         "attentions": DeepseekV3Attention,
     }
+    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
 
     @torch.no_grad()
     def _init_weights(self, module):
@@ -304,6 +304,7 @@ class DeepseekV3DecoderLayer(LlamaDecoderLayer):
 
 class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
     _can_compile_fullgraph = False
+    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
 
     @torch.no_grad()
     def _init_weights(self, module):
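Both the modeling and modular DeepseekV3 files add `e_score_correction_bias` to `_keep_in_fp32_modules_strict`. In transformers, this family of class attributes marks parameters that should stay in float32 even when the model is loaded in lower precision; the exact loading logic lives in `modeling_utils.py` and is not shown here. An illustrative sketch of the general idea only, not the library's implementation:

```python
import torch
import torch.nn as nn

keep_in_fp32 = ["e_score_correction_bias"]  # mirrors the class attribute added above

class ToyGate(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        self.e_score_correction_bias = nn.Parameter(torch.zeros(4))

gate = ToyGate().to(torch.bfloat16)  # pretend the checkpoint was loaded in bf16
for name, param in gate.named_parameters():
    if any(key in name for key in keep_in_fp32):
        param.data = param.data.float()  # illustrative: pin the matching param to fp32

print(gate.weight.dtype, gate.e_score_correction_bias.dtype)  # torch.bfloat16 torch.float32
```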
@@ -1036,6 +1036,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
1036
1036
  output_attentions=None,
1037
1037
  output_hidden_states=None,
1038
1038
  return_dict=None,
1039
+ **kwargs,
1039
1040
  ):
1040
1041
  r"""
1041
1042
  Args:
@@ -1151,6 +1152,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
1151
1152
  output_attentions=None,
1152
1153
  output_hidden_states=None,
1153
1154
  return_dict=None,
1155
+ **kwargs,
1154
1156
  ):
1155
1157
  r"""
1156
1158
  Args:
@@ -1468,6 +1470,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
1468
1470
  output_attentions: Optional[bool] = None,
1469
1471
  output_hidden_states: Optional[bool] = None,
1470
1472
  return_dict: Optional[bool] = None,
1473
+ **kwargs,
1471
1474
  ) -> Union[tuple[torch.FloatTensor], DeformableDetrModelOutput]:
1472
1475
  r"""
1473
1476
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1745,6 +1748,7 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
1745
1748
  output_attentions: Optional[bool] = None,
1746
1749
  output_hidden_states: Optional[bool] = None,
1747
1750
  return_dict: Optional[bool] = None,
1751
+ **kwargs,
1748
1752
  ) -> Union[tuple[torch.FloatTensor], DeformableDetrObjectDetectionOutput]:
1749
1753
  r"""
1750
1754
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -337,6 +337,7 @@ class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.Tensor], DepthEstimatorOutput]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -645,6 +645,7 @@ class DepthProModel(DepthProPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, DepthProOutput]:
  r"""
  Examples:
@@ -1027,6 +1028,7 @@ class DepthProForDepthEstimation(DepthProPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.Tensor], DepthProDepthEstimatorOutput]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -788,6 +788,7 @@ class DetrEncoder(DetrPreTrainedModel):
  output_attentions=None,
  output_hidden_states=None,
  return_dict=None,
+ **kwargs,
  ):
  r"""
  Args:
@@ -905,6 +906,7 @@ class DetrDecoder(DetrPreTrainedModel):
  output_attentions=None,
  output_hidden_states=None,
  return_dict=None,
+ **kwargs,
  ):
  r"""
  Args:
@@ -1078,6 +1080,7 @@ class DetrModel(DetrPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], DetrModelOutput]:
  r"""
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1258,6 +1261,7 @@ class DetrForObjectDetection(DetrPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], DetrObjectDetectionOutput]:
  r"""
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1404,6 +1408,7 @@ class DetrForSegmentation(DetrPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], DetrSegmentationOutput]:
  r"""
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -27,7 +27,7 @@ from torch import nn
 
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  from ...masking_utils import create_bidirectional_mask, create_causal_mask
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import GradientCheckpointingLayer
@@ -41,6 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+ from ...utils.generic import maybe_autocast
  from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
  from .generation_dia import DiaGenerationMixin
 
@@ -184,7 +185,7 @@ class DiaRotaryEmbedding(nn.Module):
  position_ids_expanded = position_ids[:, None, :].float()
 
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = torch.cat((freqs, freqs), dim=-1)
  cos = emb.cos() * self.attention_scaling
@@ -266,6 +267,7 @@ def eager_attention_forward(
  return attn_output, attn_weights
 
 
+ @use_kernelized_func(apply_rotary_pos_emb)
  class DiaSelfAttention(nn.Module):
  """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -523,7 +525,6 @@ class DiaDecoderLayer(GradientCheckpointingLayer):
  encoder_attention_mask: Optional[torch.Tensor] = None,
  past_key_values: Optional[EncoderDecoderCache] = None,
  cache_position: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
  **kwargs,
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
  self_attn_cache = past_key_values
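In the hunks above, `DiaSelfAttention` gains an `@use_kernelized_func(apply_rotary_pos_emb)` decorator, with `use_kernelized_func` newly imported from `...integrations` alongside the existing hub-kernel helpers. The decorator below is a hypothetical sketch of the general pattern (attach a reference function to the class so an optimized kernel can later be swapped in); it is not the actual transformers implementation, whose signature and behavior may differ.

```python
from typing import Callable


def use_kernelized_func(reference_func: Callable):
    # Hypothetical sketch: record a reference implementation on the class so a
    # hub-provided kernel could replace it when one is available.
    def decorator(cls):
        cls._kernelized_func = staticmethod(reference_func)
        return cls

    return decorator


def apply_rotary_pos_emb_reference(q, k, cos, sin):
    # Placeholder reference function for the sketch; a real implementation
    # would rotate q and k by the cos/sin tables.
    return q, k


@use_kernelized_func(apply_rotary_pos_emb_reference)
class DemoSelfAttention:
    def rotate(self, q, k, cos, sin):
        # Calls whichever implementation is currently attached to the class.
        return self._kernelized_func(q, k, cos, sin)
```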
@@ -314,7 +314,6 @@ class DiaDecoderLayer(GradientCheckpointingLayer):
  encoder_attention_mask: Optional[torch.Tensor] = None,
  past_key_values: Optional[EncoderDecoderCache] = None,
  cache_position: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
  **kwargs,
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
  self_attn_cache = past_key_values
@@ -46,7 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import check_model_inputs
+ from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_diffllama import DiffLlamaConfig
 
 
@@ -125,7 +125,7 @@ class DiffLlamaRotaryEmbedding(nn.Module):
  position_ids_expanded = position_ids[:, None, :].float()
 
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = torch.cat((freqs, freqs), dim=-1)
  cos = emb.cos() * self.attention_scaling
@@ -596,6 +596,7 @@ class DinatModel(DinatPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, DinatModelOutput]:
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
@@ -668,6 +669,7 @@ class DinatForImageClassification(DinatPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, DinatImageClassifierOutput]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -740,6 +742,7 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
  output_hidden_states: Optional[bool] = None,
  output_attentions: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> BackboneOutput:
  r"""
  Examples:
@@ -214,7 +214,7 @@ class DINOv3ConvNextModel(DINOv3ConvNextPreTrainedModel):
  @can_return_tuple
  @auto_docstring
  def forward(
- self, pixel_values: torch.FloatTensor, output_hidden_states: Optional[bool] = None
+ self, pixel_values: torch.FloatTensor, output_hidden_states: Optional[bool] = None, **kwargs
  ) -> BaseModelOutputWithPoolingAndNoAttention:
  hidden_states = pixel_values