transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/models/cohere/tokenization_cohere.py

@@ -100,21 +100,23 @@ class CohereTokenizer(TokenizersBackend):
             Whether or not the default system prompt for Cohere tokenizer should be used.
         add_prefix_space (`bool`, *optional*, defaults to `False`):
             Whether or not the tokenizer should automatically add a prefix space
-        vocab (`dict`, *optional*):
+        vocab (`str`, `dict` or `list`, *optional*):
             Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
-        merges (`list`, *optional*):
-            Custom merges list. If not provided, merges are loaded from merges_file.
+        merges (`str` or `list[str]`, *optional*):
+            Custom merges list. If not provided, merges are loaded from `merges_file`.
     """

     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     padding_side = "left"
     model_input_names = ["input_ids", "attention_mask"]
-    slow_tokenizer_class = None
+    model = BPE
     # No `max_model_input_sizes`

     def __init__(
         self,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         errors: str = "replace",
         unk_token: str = "<UNK>",
         bos_token: str = "<BOS_TOKEN>",
@@ -123,27 +125,19 @@ class CohereTokenizer(TokenizersBackend):
         cls_token: str = "<CLS>",
         sep_token: str = "<SEP>",
         mask_token: str = "<MASK_TOKEN>",
-        add_bos_token: bool = True,
-        add_eos_token: bool = False,
         use_default_system_prompt: bool = False,
         add_prefix_space: bool = False,
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
         **kwargs,
     ):
-        self._add_bos_token = add_bos_token
-        self._add_eos_token = add_eos_token
         self.use_default_system_prompt = use_default_system_prompt
         self.add_prefix_space = add_prefix_space
         self.grounded_generation_template = kwargs.pop("grounded_generation_template", None)
         self.tool_use_template = kwargs.pop("tool_use_template", None)

-        if vocab is not None:
-            self._vocab = (
-                {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
-            )
-        else:
-            self._vocab = {
+        self._vocab = (
+            vocab
+            if vocab is not None
+            else {
                 str(pad_token): 0,
                 str(unk_token): 1,
                 str(cls_token): 2,
@@ -151,12 +145,9 @@ class CohereTokenizer(TokenizersBackend):
                 str(mask_token): 4,
                 str(bos_token): 5,
             }
+        )

-        if merges is not None:
-            self._merges = merges
-        else:
-            self._merges = []
-
+        self._merges = merges or []
         self._tokenizer = Tokenizer(
             BPE(
                 vocab=self._vocab,
@@ -177,10 +168,7 @@ class CohereTokenizer(TokenizersBackend):
         )
         self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=add_prefix_space, trim_offsets=True)

-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             errors=errors,
             unk_token=unk_token,
             bos_token=bos_token,
@@ -189,8 +177,6 @@ class CohereTokenizer(TokenizersBackend):
             cls_token=cls_token,
             sep_token=sep_token,
             mask_token=mask_token,
-            add_bos_token=add_bos_token,
-            add_eos_token=add_eos_token,
             use_default_system_prompt=use_default_system_prompt,
             add_prefix_space=add_prefix_space,
             **kwargs,
@@ -198,22 +184,6 @@ class CohereTokenizer(TokenizersBackend):

         self._post_init()

-    def _post_init(self):
-        """Post-initialization to ensure add_prefix_space is applied correctly."""
-        # Re-apply add_prefix_space setting to pre_tokenizer and decoder
-        # This is needed because when loading from pretrained, the tokenizer.json
-        # has these settings baked in and we need to override them
-        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.Digits(individual_digits=True),
-                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, trim_offsets=True),
-            ]
-        )
-        self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=self.add_prefix_space, trim_offsets=True)
-
-        # Call parent to handle AddedToken properties
-        super()._post_init()
-
     def apply_tool_use_template(
         self,
         conversation: list[dict[str, str]],
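Taken together, these hunks make `CohereTokenizer` build its byte-level BPE backend directly from in-memory `vocab`/`merges` instead of passing a pre-built `tokenizer_object` to the base class, which is also why the `_post_init` override that re-applied `add_prefix_space` can be dropped. A minimal sketch of that construction with the `tokenizers` library, using a made-up five-token vocabulary for illustration:

```python
from tokenizers import Tokenizer, decoders, pre_tokenizers
from tokenizers.models import BPE

# Toy vocabulary and merge list; real checkpoints ship hundreds of thousands of entries.
vocab = {"<PAD>": 0, "<UNK>": 1, "l": 2, "o": 3, "lo": 4}
merges = [("l", "o")]  # each pair merges two existing tokens into a new one

tokenizer = Tokenizer(BPE(vocab=vocab, merges=merges, unk_token="<UNK>"))
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
    [
        pre_tokenizers.Digits(individual_digits=True),  # Cohere splits digits one by one
        pre_tokenizers.ByteLevel(add_prefix_space=False),
    ]
)
tokenizer.decoder = decoders.ByteLevel()

print(tokenizer.encode("lo").ids)  # [4]: "l" + "o" collapse via the single merge rule
```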
transformers/models/cohere2/modeling_cohere2.py

@@ -28,15 +28,15 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_cohere2 import Cohere2Config


@@ -96,7 +96,7 @@ class Cohere2RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
@@ -198,6 +198,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)


+@use_kernelized_func(apply_rotary_pos_emb)
 class Cohere2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -233,7 +234,7 @@ class Cohere2Attention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -304,7 +305,7 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer):
         past_key_values: Optional[Cache] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -398,7 +399,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None and not self.training:
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
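Two substitutions recur in this file and in the CSM and CWM hunks below: `Unpack[FlashAttentionKwargs]` widens to `Unpack[TransformersKwargs]`, and `torch.autocast` becomes `maybe_autocast`. A minimal sketch of the `maybe_autocast` idea, assuming the real helper in `utils/generic.py` merely guards entry into autocast so float32 can still be forced on backends where `torch.autocast` is unavailable; this is an illustration, not the library's implementation:

```python
from contextlib import nullcontext

import torch


def maybe_autocast_sketch(device_type: str, enabled: bool = True):
    """Enter torch.autocast when the backend supports it, otherwise no-op."""
    if device_type not in ("cuda", "cpu"):  # assumption: skip unsupported devices
        return nullcontext()
    return torch.autocast(device_type=device_type, enabled=enabled)


x = torch.randn(2, 4)
with maybe_autocast_sketch(x.device.type, enabled=False):  # force float32, as the RoPE code does
    freqs = x.float() @ x.float().T
print(freqs.dtype)  # torch.float32
```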
transformers/models/cohere2/modular_cohere2.py

@@ -22,7 +22,6 @@ import torch.nn as nn
 from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast
 from ...modeling_rope_utils import (
     RopeParameters,
@@ -31,6 +30,7 @@ from ...modeling_rope_utils import (
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, logging
+from ...utils.generic import maybe_autocast
 from ..cohere.modeling_cohere import (
     CohereAttention,
     CohereDecoderLayer,
@@ -223,7 +223,7 @@ class Cohere2RotaryEmbedding(CohereRotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
@@ -271,7 +271,7 @@ class Cohere2Attention(CohereAttention):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -322,7 +322,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
         past_key_values: Optional[Cache] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -367,7 +367,7 @@ class Cohere2Model(Gemma2Model):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if use_cache and past_key_values is None and not self.training:
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py

@@ -93,8 +93,9 @@ def get_optimal_tiled_canvas(
     patch_size_height, patch_size_width = target_tile_size  # (height == width)

     candidate_resolutions = np.array(possible_resolutions) * patch_size_height
-    original_size = np.stack([image_height, image_width])
-    required_scales = candidate_resolutions / original_size
+    # tiles following (width, height) order to align with aspect ratio convention
+    tile_size = np.stack([image_width, image_height])
+    required_scales = candidate_resolutions / tile_size
     required_scale = np.min(required_scales, axis=-1, keepdims=True)  # [n_resolutions, 1]
     if np.all(required_scale < 1):
         # We are forced to downscale, so try to minimize the amount of downscaling
@@ -103,7 +104,7 @@ def get_optimal_tiled_canvas(
     # Pick the resolution that required the least upscaling so that it most closely fits the image
     required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
     best_grid = possible_resolutions[np.argmin(required_scale)]
-    return best_grid
+    return best_grid  # (width, height)


 @auto_docstring
transformers/models/cohere2_vision/modular_cohere2_vision.py

@@ -295,8 +295,9 @@ def get_optimal_tiled_canvas(
     patch_size_height, patch_size_width = target_tile_size  # (height == width)

     candidate_resolutions = np.array(possible_resolutions) * patch_size_height
-    original_size = np.stack([image_height, image_width])
-    required_scales = candidate_resolutions / original_size
+    # tiles following (width, height) order to align with aspect ratio convention
+    tile_size = np.stack([image_width, image_height])
+    required_scales = candidate_resolutions / tile_size
     required_scale = np.min(required_scales, axis=-1, keepdims=True)  # [n_resolutions, 1]
     if np.all(required_scale < 1):
         # We are forced to downscale, so try to minimize the amount of downscaling
@@ -305,7 +306,7 @@ def get_optimal_tiled_canvas(
     # Pick the resolution that required the least upscaling so that it most closely fits the image
     required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
     best_grid = possible_resolutions[np.argmin(required_scale)]
-    return best_grid
+    return best_grid  # (width, height)


 class Cohere2VisionFastImageProcessorKwargs(ImagesKwargs, total=False):
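A worked example of the `(width, height)` fix, with made-up numbers. The candidate grids follow the aspect-ratio convention `(width, height)`, so dividing them by a `(height, width)` stack crossed the axes for non-square images and could select a canvas transposed relative to the image:

```python
import numpy as np

possible_resolutions = np.array([[2, 1], [1, 2]])  # (width, height) tile grids
patch_size = 336
image_width, image_height = 1344, 336  # a 4:1 wide image

candidate_resolutions = possible_resolutions * patch_size  # [[672, 336], [336, 672]]

old_scales = candidate_resolutions / np.stack([image_height, image_width])  # crossed axes
new_scales = candidate_resolutions / np.stack([image_width, image_height])  # fixed

# Both variants end up in the downscale branch here; assuming that branch keeps
# the grid with the largest scale (its comment says to minimize downscaling):
print(possible_resolutions[np.argmax(np.min(old_scales, axis=-1))])  # [1 2]: tall grid, wrong
print(possible_resolutions[np.argmax(np.min(new_scales, axis=-1))])  # [2 1]: wide grid, correct
```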
transformers/models/colqwen2/modeling_colqwen2.py

@@ -141,6 +141,7 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
         pixel_values: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> ColQwen2ForRetrievalOutput:
         r"""
         image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
transformers/models/colqwen2/modular_colqwen2.py

@@ -322,6 +322,7 @@ class ColQwen2ForRetrieval(ColPaliForRetrieval):
         pixel_values: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> ColQwen2ForRetrievalOutput:
         r"""
         image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
transformers/models/conditional_detr/modeling_conditional_detr.py

@@ -1032,6 +1032,7 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -1156,6 +1157,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -1344,6 +1346,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], ConditionalDetrModelOutput]:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1529,6 +1532,7 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], ConditionalDetrObjectDetectionOutput]:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1693,6 +1697,7 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], ConditionalDetrSegmentationOutput]:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
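The bare `**kwargs,` added across these `forward` signatures (the same pattern repeats in the ConvBert, ConvNext/V2, CTRL, CvT, and ColQwen2 hunks around it) makes the models tolerant of extra keyword arguments that shared calling code now threads through every forward pass. A small sketch of the failure it avoids, with made-up module names:

```python
import torch
from torch import nn


class WithoutKwargs(nn.Module):
    def forward(self, pixel_values, return_dict=None):
        return pixel_values


class WithKwargs(nn.Module):
    def forward(self, pixel_values, return_dict=None, **kwargs):
        return pixel_values  # unknown keys are accepted and ignored


x = torch.zeros(1)
WithKwargs()(x, output_attentions=None)  # fine
try:
    WithoutKwargs()(x, output_attentions=None)
except TypeError as err:
    print(err)  # forward() got an unexpected keyword argument 'output_attentions'
```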
transformers/models/convbert/modeling_convbert.py

@@ -629,6 +629,7 @@ class ConvBertModel(ConvBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -729,6 +730,7 @@ class ConvBertForMaskedLM(ConvBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -824,6 +826,7 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -906,6 +909,7 @@ class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1013,6 +1017,7 @@ class ConvBertForTokenClassification(ConvBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1078,6 +1083,7 @@ class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformers/models/convnext/modeling_convnext.py

@@ -268,7 +268,7 @@ class ConvNextModel(ConvNextPreTrainedModel):
     @can_return_tuple
     @auto_docstring
     def forward(
-        self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None
+        self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, **kwargs
     ) -> BaseModelOutputWithPoolingAndNoAttention:
         if output_hidden_states is None:
             output_hidden_states = self.config.output_hidden_states
@@ -370,9 +370,7 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
     @can_return_tuple
     @auto_docstring
     def forward(
-        self,
-        pixel_values: torch.Tensor,
-        output_hidden_states: Optional[bool] = None,
+        self, pixel_values: torch.Tensor, output_hidden_states: Optional[bool] = None, **kwargs
     ) -> BackboneOutput:
         r"""
         Examples:
transformers/models/convnextv2/modeling_convnextv2.py

@@ -289,7 +289,7 @@ class ConvNextV2Model(ConvNextV2PreTrainedModel):
     @can_return_tuple
     @auto_docstring
     def forward(
-        self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None
+        self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, **kwargs
     ) -> BaseModelOutputWithPoolingAndNoAttention:
         if output_hidden_states is None:
             output_hidden_states = self.config.output_hidden_states
@@ -393,9 +393,7 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
     @can_return_tuple
     @auto_docstring
     def forward(
-        self,
-        pixel_values: torch.Tensor,
-        output_hidden_states: Optional[bool] = None,
+        self, pixel_values: torch.Tensor, output_hidden_states: Optional[bool] = None, **kwargs
     ) -> BackboneOutput:
         r"""
         Examples:
transformers/models/csm/modeling_csm.py

@@ -32,7 +32,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -40,6 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import maybe_autocast
 from ...utils.import_utils import is_torchdynamo_compiling
 from ..auto import AutoModel
 from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
@@ -174,7 +175,7 @@ class CsmRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -272,6 +273,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class CsmAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -297,7 +299,6 @@ class CsmAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
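The `@use_kernelized_func(apply_rotary_pos_emb)` decorator replaces the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment here and in the CWM hunks below; the same decorator also appears on the Cohere2 attention above. A minimal sketch of the pattern, assuming the real `use_kernelized_func` attaches the function at the class level so that a hub-provided kernel can later be substituted; the body is illustrative only:

```python
def use_kernelized_func_sketch(func):
    """Class decorator attaching `func` as the class-level rotary function (assumed behavior)."""

    def wrap(cls):
        cls.rotary_fn = staticmethod(func)  # replaces the assignment removed from __init__
        return cls

    return wrap


def apply_rotary_pos_emb_stub(q, k, cos, sin):
    return q, k  # placeholder for the real rotary embedding


@use_kernelized_func_sketch(apply_rotary_pos_emb_stub)
class AttentionSketch:
    pass


print(AttentionSketch.rotary_fn)  # every instance shares the (swappable) function
```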
transformers/models/ctrl/modeling_ctrl.py

@@ -534,6 +534,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -523,6 +523,7 @@ class CvtModel(CvtPreTrainedModel):
         pixel_values: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithCLSToken]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -577,6 +578,7 @@ class CvtForImageClassification(CvtPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -28,7 +28,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -37,7 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_cwm import CwmConfig


@@ -97,7 +97,7 @@ class CwmRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -179,6 +179,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class CwmAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -196,7 +197,6 @@ class CwmAttention(nn.Module):
         self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

     def forward(
@@ -681,6 +681,7 @@ class DFineDecoder(DFinePreTrainedModel):
         memory_mask=None,
         output_attentions=None,
         return_dict=None,
+        **kwargs,
     ) -> DFineDecoderOutput:
         r"""
         Args:
@@ -1247,6 +1248,7 @@ class DFineModel(DFinePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], DFineModelOutput]:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -726,6 +726,7 @@ class DFineDecoder(RTDetrDecoder):
         memory_mask=None,
         output_attentions=None,
         return_dict=None,
+        **kwargs,
     ) -> DFineDecoderOutput:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -886,6 +886,7 @@ class DabDetrEncoder(DabDetrPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -1016,6 +1017,7 @@ class DabDetrDecoder(DabDetrPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -1222,6 +1224,7 @@ class DabDetrModel(DabDetrPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], DabDetrModelOutput]:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1469,6 +1472,7 @@ class DabDetrForObjectDetection(DabDetrPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], DabDetrObjectDetectionOutput]:
         r"""
         decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -264,7 +264,7 @@ class DacDecoderBlock(nn.Module):
         return hidden_state


-class DacResidualVectorQuantize(nn.Module):
+class DacResidualVectorQuantizer(nn.Module):
     """
     ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://huggingface.co/papers/2107.03312)
     """
@@ -568,7 +568,7 @@ class DacModel(DacPreTrainedModel):
         self.encoder = DacEncoder(config)
         self.decoder = DacDecoder(config)

-        self.quantizer = DacResidualVectorQuantize(config)
+        self.quantizer = DacResidualVectorQuantizer(config)

         self.bits_per_codebook = int(math.log2(self.config.codebook_size))
         if 2**self.bits_per_codebook != self.config.codebook_size:
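
`DacResidualVectorQuantize` is renamed to `DacResidualVectorQuantizer`, and the diff shows no alias left behind, so downstream code importing the old name would break on upgrade. A version-tolerant import shim one might use (an assumption, not part of this diff):

```python
# Works on both 5.0.0rc0 (old name) and 5.0.0rc1 (new name).
try:
    from transformers.models.dac.modeling_dac import DacResidualVectorQuantize
except ImportError:
    from transformers.models.dac.modeling_dac import (
        DacResidualVectorQuantizer as DacResidualVectorQuantize,
    )
```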
@@ -754,6 +754,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Data2VecAudioBaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -856,6 +857,7 @@ class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
@@ -967,6 +969,7 @@ class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1070,6 +1073,7 @@ class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1241,6 +1245,7 @@ class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, XVectorOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):