transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/models/t5/modeling_t5.py
@@ -673,6 +673,7 @@ class T5Stack(T5PreTrainedModel):
  output_hidden_states=None,
  return_dict=None,
  cache_position=None,
+ **kwargs,
  ):
  use_cache = use_cache if use_cache is not None else self.config.use_cache
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -879,6 +880,7 @@ class T5Model(T5PreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1044,6 +1046,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1209,6 +1212,7 @@ class T5EncoderModel(T5PreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1279,6 +1283,7 @@ class T5ForSequenceClassification(T5PreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1417,6 +1422,7 @@ class T5ForTokenClassification(T5PreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1520,6 +1526,7 @@ class T5ForQuestionAnswering(T5PreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

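Note on the `**kwargs` additions in this file (and in the modeling files below): forward() signatures now accept arbitrary extra keyword arguments, so flags threaded down from attention backends or generation utilities no longer raise TypeError in modules that do not consume them. A minimal sketch of the pattern with toy modules, not the actual transformers classes:

    import torch
    from torch import nn

    class ToyEncoder(nn.Module):
        def forward(self, hidden_states, output_attentions=None, **kwargs):
            # Backend-specific flags land in kwargs and are simply ignored here.
            return hidden_states

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = ToyEncoder()

        def forward(self, hidden_states, **kwargs):
            # Without **kwargs, any unexpected keyword would raise TypeError
            # before ever reaching a submodule that might consume it.
            return self.encoder(hidden_states, **kwargs)

    out = ToyModel()(torch.ones(1, 4), output_attentions=False)
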
transformers/models/t5/tokenization_t5.py
@@ -15,6 +15,7 @@
  """Tokenization class for model T5."""

  import re
+ from typing import Optional, Union

  from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
  from tokenizers.models import Unigram
@@ -61,26 +62,24 @@ class T5Tokenizer(TokenizersBackend):
  calling get_sentinel_tokens method and token ids can be by calling get_sentinel_token_ids method
  additional_special_tokens (`list[str]`, *optional*):
  Additional special tokens used by the tokenizer.
- vocab (`dict`, *optional*):
+ vocab (`str`, `dict` or `list`, *optional*):
  Custom vocabulary dict. If not provided, a minimal vocabulary is created using the special tokens.
  """

  vocab_files_names = VOCAB_FILES_NAMES
  model_input_names = ["input_ids", "attention_mask"]
- slow_tokenizer_class = None
+ model = Unigram

  def __init__(
  self,
+ vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
  eos_token="</s>",
  unk_token="<unk>",
  pad_token="<pad>",
  extra_ids=100,
  additional_special_tokens=None,
- vocab=None,
- vocab_file=None,
  **kwargs,
  ):
- self.vocab_file = vocab_file
  self._extra_ids = extra_ids

  # Handle extra_ids and additional_special_tokens
@@ -130,10 +129,7 @@ class T5Tokenizer(TokenizersBackend):

  self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

- tokenizer_object = self._tokenizer
-
  super().__init__(
- tokenizer_object=tokenizer_object,
  eos_token=eos_token,
  unk_token=unk_token,
  pad_token=pad_token,

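With the signature above, `vocab` becomes the first positional argument and `vocab_file` is dropped. A usage sketch, assuming the rc1 constructor accepts an in-memory Unigram vocabulary of (token, score) pairs as shown:

    from transformers import T5Tokenizer

    # Scores are Unigram log-probabilities; this tiny vocab is illustrative only.
    vocab = [("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0), ("▁", -2.0)]
    tokenizer = T5Tokenizer(vocab=vocab)
    print(tokenizer.tokenize("hello"))
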
transformers/models/t5gemma/modeling_t5gemma.py
@@ -29,7 +29,7 @@ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_func_from_hub
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import GradientCheckpointingLayer
@@ -45,7 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import OutputRecorder, check_model_inputs
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
  from .configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig


@@ -147,7 +147,7 @@ class T5GemmaRotaryEmbedding(nn.Module):
  position_ids_expanded = position_ids[:, None, :].float()

  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = torch.cat((freqs, freqs), dim=-1)
  cos = emb.cos() * self.attention_scaling
@@ -238,6 +238,7 @@ def eager_attention_forward(
  return attn_output, attn_weights


+ @use_kernelized_func(apply_rotary_pos_emb)
  class T5GemmaSelfAttention(nn.Module):
  """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -265,7 +266,6 @@ class T5GemmaSelfAttention(nn.Module):
  self.o_proj = nn.Linear(
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  )
- self.rotary_fn = apply_rotary_pos_emb
  self.attn_logit_softcapping = self.config.attn_logit_softcapping
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

@@ -315,6 +315,7 @@ class T5GemmaSelfAttention(nn.Module):
  return attn_output, attn_weights


+ @use_kernelized_func(apply_rotary_pos_emb)
  class T5GemmaCrossAttention(nn.Module):
  """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -341,7 +342,6 @@ class T5GemmaCrossAttention(nn.Module):
  self.o_proj = nn.Linear(
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  )
- self.rotary_fn = apply_rotary_pos_emb
  self.attn_logit_softcapping = self.config.attn_logit_softcapping

  if config.cross_attention_hidden_size is None:

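`maybe_autocast` is imported from `transformers.utils.generic`; its body is not part of this diff. A hedged sketch of what such a wrapper could look like (an assumption, not the library's implementation): behave like `torch.autocast` in eager mode and degrade to a no-op where the context manager is unsupported, e.g. under `torch.compile` tracing:

    import contextlib
    import torch

    def maybe_autocast(device_type="cpu", enabled=True, **kwargs):
        # Assumed behavior only: skip autocast when it cannot apply cleanly,
        # otherwise defer to the real torch.autocast context manager.
        if torch.compiler.is_compiling():
            return contextlib.nullcontext()
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
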
transformers/models/t5gemma2/modeling_t5gemma2.py
@@ -30,7 +30,7 @@ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
  from ...generation import GenerationConfig, GenerationMixin, GenerationMode
- from ...integrations import use_kernel_func_from_hub
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  from ...masking_utils import create_bidirectional_mask, create_causal_mask, create_sliding_window_causal_mask
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import GradientCheckpointingLayer
@@ -46,7 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.generic import OutputRecorder, check_model_inputs
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
  from ..auto import AutoModel
  from .configuration_t5gemma2 import T5Gemma2Config, T5Gemma2DecoderConfig, T5Gemma2EncoderConfig, T5Gemma2TextConfig

@@ -162,7 +162,7 @@ class T5Gemma2RotaryEmbedding(nn.Module):
  position_ids_expanded = position_ids[:, None, :].float()

  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = torch.cat((freqs, freqs), dim=-1)
  cos = emb.cos() * attention_scaling
@@ -253,6 +253,7 @@ def eager_attention_forward(
  return attn_output, attn_weights


+ @use_kernelized_func(apply_rotary_pos_emb)
  class T5Gemma2SelfAttention(nn.Module):
  """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -279,7 +280,6 @@ class T5Gemma2SelfAttention(nn.Module):
  self.o_proj = nn.Linear(
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  )
- self.rotary_fn = apply_rotary_pos_emb
  self.attn_logit_softcapping = self.config.attn_logit_softcapping
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  self.is_sliding = self.layer_type == "sliding_attention"
@@ -294,7 +294,7 @@ class T5Gemma2SelfAttention(nn.Module):
  attention_mask: Optional[torch.Tensor] = None,
  past_key_values: Optional[Cache] = None,
  cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
+ **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  input_shape = hidden_states.shape[:-1]
  hidden_shape = (*input_shape, -1, self.head_dim)
@@ -335,6 +335,7 @@ class T5Gemma2SelfAttention(nn.Module):
  return attn_output, attn_weights


+ @use_kernelized_func(apply_rotary_pos_emb)
  class T5Gemma2MergedAttention(nn.Module):
  """Merged self-attention and cross-attention for decoder."""

@@ -361,7 +362,6 @@ class T5Gemma2MergedAttention(nn.Module):
  self.o_proj = nn.Linear(
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  )
- self.rotary_fn = apply_rotary_pos_emb
  self.attn_logit_softcapping = self.config.attn_logit_softcapping
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  self.is_sliding = self.layer_type == "sliding_attention"

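The `@use_kernelized_func(apply_rotary_pos_emb)` decorator replaces the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignments deleted above (the typing of `**kwargs` also widens from `FlashAttentionKwargs` to `TransformersKwargs`). A rough sketch of the decorator idea; the real `use_kernelized_func` lives in `transformers.integrations` and may work differently:

    def use_kernelized_func(func):
        # Assumed mechanics: attach the function once at the class level,
        # giving hub kernels a single hook to swap in an optimized version
        # instead of rebinding it on every instance.
        def decorator(cls):
            cls.rotary_fn = staticmethod(func)
            return cls
        return decorator
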
transformers/models/table_transformer/modeling_table_transformer.py
@@ -749,6 +749,7 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
  output_attentions=None,
  output_hidden_states=None,
  return_dict=None,
+ **kwargs,
  ):
  r"""
  Args:
@@ -869,6 +870,7 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
  output_attentions=None,
  output_hidden_states=None,
  return_dict=None,
+ **kwargs,
  ):
  r"""
  Args:
@@ -1043,6 +1045,7 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], TableTransformerModelOutput]:
  r"""
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
@@ -1202,6 +1205,7 @@ class TableTransformerForObjectDetection(TableTransformerPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], TableTransformerObjectDetectionOutput]:
  r"""
  decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):

transformers/models/tapas/modeling_tapas.py
@@ -563,6 +563,7 @@ class TapasModel(TapasPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, BaseModelOutputWithPooling]:
  r"""
  token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
@@ -843,6 +844,7 @@ class TapasForQuestionAnswering(TapasPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, TableQuestionAnsweringOutput]:
  r"""
  token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):
@@ -1164,6 +1166,7 @@ class TapasForSequenceClassification(TapasPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  r"""
  token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 7)`, *optional*):

transformers/models/textnet/modeling_textnet.py
@@ -233,7 +233,11 @@ class TextNetModel(TextNetPreTrainedModel):

  @auto_docstring
  def forward(
- self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+ self,
+ pixel_values: Tensor,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[Any, list[Any]], tuple[Any], BaseModelOutputWithPoolingAndNoAttention]:
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  output_hidden_states = (
@@ -288,6 +292,7 @@ class TextNetForImageClassification(TextNetPreTrainedModel):
  labels: Optional[torch.LongTensor] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> ImageClassifierOutputWithNoAttention:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -353,7 +358,11 @@ class TextNetBackbone(TextNetPreTrainedModel, BackboneMixin):

  @auto_docstring
  def forward(
- self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+ self,
+ pixel_values: Tensor,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[tuple], BackboneOutput]:
  r"""
  Examples:

transformers/models/time_series_transformer/modeling_time_series_transformer.py
@@ -658,6 +658,7 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, BaseModelOutput]:
  r"""
  Args:
@@ -777,6 +778,7 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
  ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  r"""
  Args:
@@ -1075,6 +1077,7 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
  use_cache: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
  ) -> Union[Seq2SeqTSModelOutput, tuple]:
  r"""
  past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
@@ -1320,6 +1323,7 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
  use_cache: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
  ) -> Union[Seq2SeqTSModelOutput, tuple]:
  r"""
  past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):

transformers/models/timesfm/modeling_timesfm.py
@@ -361,6 +361,7 @@ class TimesFmModel(TimesFmPreTrainedModel):
  freq: torch.Tensor,
  output_attentions: bool = False,
  output_hidden_states: bool = False,
+ **kwargs,
  ) -> TimesFmOutput:
  r"""
  past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -668,6 +669,7 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
  truncate_negative: bool = False,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
+ **kwargs,
  ) -> TimesFmOutputForPrediction:
  r"""
  past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):

transformers/models/timesfm/modular_timesfm.py
@@ -317,6 +317,7 @@ class TimesFmModel(TimesFmPreTrainedModel):
  freq: torch.Tensor,
  output_attentions: bool = False,
  output_hidden_states: bool = False,
+ **kwargs,
  ) -> TimesFmOutput:
  r"""
  past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -624,6 +625,7 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
  truncate_negative: bool = False,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
+ **kwargs,
  ) -> TimesFmOutputForPrediction:
  r"""
  past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):

transformers/models/timesformer/modeling_timesformer.py
@@ -494,6 +494,7 @@ class TimesformerModel(TimesformerPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
  r"""
  Examples:
@@ -624,6 +625,7 @@ class TimesformerForVideoClassification(TimesformerPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, ImageClassifierOutput]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -225,7 +225,7 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
  "different architecture or updating the timm package to a compatible version."
  )

- pixel_values = pixel_values.to(self.device, self.dtype)
+ pixel_values = pixel_values.to(self.device)

  if self.features_only:
  last_hidden_state = self.timm_model.forward(pixel_values, **kwargs)

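Dropping `self.dtype` from the `.to()` call means inputs keep their own dtype instead of being silently cast to the model's. The difference in plain torch:

    import torch

    x = torch.randn(2, 3)                    # float32 pixel values
    print(x.to("cpu", torch.float16).dtype)  # torch.float16: the old call downcast
    print(x.to("cpu").dtype)                 # torch.float32: the new call only moves device
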
transformers/models/trocr/modeling_trocr.py
@@ -459,6 +459,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
  output_hidden_states=None,
  return_dict=None,
  cache_position=None,
+ **kwargs,
  ):
  r"""
  Args:
@@ -686,6 +687,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
  ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

transformers/models/tvp/modeling_tvp.py
@@ -721,6 +721,7 @@ class TvpModel(TvpPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  interpolate_pos_encoding: bool = False,
+ **kwargs,
  ):
  r"""
  Examples:
@@ -822,6 +823,7 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  interpolate_pos_encoding: bool = False,
+ **kwargs,
  ):
  r"""
  labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):

transformers/models/udop/modeling_udop.py
@@ -1105,6 +1105,7 @@ class UdopStack(UdopPreTrainedModel):
  output_hidden_states=None,
  return_dict=None,
  cache_position=None,
+ **kwargs,
  ):
  use_cache = use_cache if use_cache is not None else self.config.use_cache
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1474,6 +1475,7 @@ class UdopModel(UdopPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
  ) -> tuple[Tensor, ...]:
  r"""
  bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
@@ -1652,6 +1654,7 @@ class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
  return_dict: Optional[bool] = None,
  labels: Optional[Tensor] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
  ) -> tuple[Tensor, ...]:
  r"""
  bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
@@ -1821,6 +1824,7 @@ class UdopEncoderModel(UdopPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.FloatTensor], BaseModelOutputWithAttentionMask]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

@@ -183,10 +183,11 @@ class UdopTokenizer(TokenizersBackend):

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-    slow_tokenizer_class = None
+    model = Unigram

     def __init__(
         self,
+        vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
         eos_token="</s>",
         sep_token="</s>",
         unk_token="<unk>",
@@ -196,7 +197,6 @@ class UdopTokenizer(TokenizersBackend):
         pad_token_label=-100,
         only_label_first_subword=True,
         extra_special_tokens=None,
-        vocab=None,
         **kwargs,
     ):
         if "additional_special_tokens" in kwargs and "extra_special_tokens" not in kwargs:
@@ -205,24 +205,17 @@ class UdopTokenizer(TokenizersBackend):
         kwargs["extra_special_tokens"] = extra_special_tokens

         if vocab is None:
-            vocab_scores = [(str(pad_token), 0.0), (str(eos_token), 0.0), (str(unk_token), 0.0), ("▁", -2.0)]
-        elif isinstance(vocab, dict):
-            vocab_scores = [(str(token), float(score)) for token, score in vocab.items()]
-        elif isinstance(vocab, list) and len(vocab) > 0:
-            if isinstance(vocab[0], (tuple, list)):
-                vocab_scores = [(str(token), float(score)) for token, score in vocab]
-            else:
-                vocab_scores = [(str(token), 0.0) for token in vocab]
+            vocab = [(str(pad_token), 0.0), (str(eos_token), 0.0), (str(unk_token), 0.0), ("▁", -2.0)]

         unk_id = 2
-        for idx, (token, _) in enumerate(vocab_scores):
+        for idx, (token, _) in enumerate(vocab):
             if token == str(unk_token):
                 unk_id = idx
                 break

         self._tokenizer = Tokenizer(
             Unigram(
-                vocab_scores,
+                vocab,
                 unk_id=unk_id,
                 byte_fallback=False,
             )
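The refactor narrows what `UdopTokenizer` accepts: `vocab` is now taken directly as a list of `(token, score)` pairs (or a path string) instead of being normalized from dicts or bare token lists internally, and the class declares `model = Unigram` for the tokenizers backend. A minimal sketch of building such a Unigram tokenizer with the `tokenizers` library (toy vocabulary and scores, illustrative only):

    from tokenizers import Tokenizer, pre_tokenizers
    from tokenizers.models import Unigram

    # (token, log-probability) pairs; the index of "<unk>" is passed as unk_id
    vocab = [("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0), ("▁", -2.0), ("▁hello", -3.5)]
    unk_id = next(i for i, (tok, _) in enumerate(vocab) if tok == "<unk>")

    tok = Tokenizer(Unigram(vocab, unk_id=unk_id, byte_fallback=False))
    tok.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
    print(tok.encode("hello").tokens)  # ['▁hello']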
@@ -240,7 +233,6 @@ class UdopTokenizer(TokenizersBackend):
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

         super().__init__(
-            tokenizer_object=self._tokenizer,
             eos_token=eos_token,
             sep_token=sep_token,
             unk_token=unk_token,
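With `tokenizer_object=self._tokenizer` removed from the `super().__init__` call, the `TokenizersBackend` base class presumably picks the backend tokenizer up from `self._tokenizer` directly (an assumption; the base-class internals are not shown in this diff). Call sites would now pass `vocab` as the first parameter; a hypothetical usage consistent with the signature above, with made-up scores:

    from transformers import UdopTokenizer

    # Toy vocabulary, not a real UDOP checkpoint's vocab.
    tokenizer = UdopTokenizer(vocab=[("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0), ("▁", -2.0)])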
@@ -621,6 +621,7 @@ class UMT5Stack(UMT5PreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         cache_position=None,
+        **kwargs,
     ):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -966,6 +967,7 @@ class UMT5Model(UMT5PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1147,6 +1149,7 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1332,6 +1335,7 @@ class UMT5EncoderModel(UMT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1403,6 +1407,7 @@ class UMT5ForSequenceClassification(UMT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1545,6 +1550,7 @@ class UMT5ForTokenClassification(UMT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1649,6 +1655,7 @@ class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
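Across the UMT5 family, `**kwargs` is added both to the user-facing models and to the inner `UMT5Stack`, which suggests extra keyword arguments are meant to flow from the wrapper's `forward` down into the encoder/decoder stacks. A schematic sketch of that propagation pattern (not the actual UMT5 code):

    import torch
    from torch import nn


    class Stack(nn.Module):
        def forward(self, hidden, **kwargs):
            # The inner module picks out what it understands and ignores the rest.
            if kwargs.get("debug"):
                print("stack got:", sorted(kwargs))
            return hidden


    class Wrapper(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = Stack()
            self.decoder = Stack()

        def forward(self, hidden, **kwargs):
            hidden = self.encoder(hidden, **kwargs)  # pass-through, no re-declaration
            return self.decoder(hidden, **kwargs)


    out = Wrapper()(torch.zeros(1, 4), debug=True)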
@@ -1001,6 +1001,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechBaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1108,6 +1109,7 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechForPreTrainingOutput]:
         r"""
         Example:
@@ -1255,6 +1257,7 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
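For the UniSpeech heads the change is signature-only: `labels` keeps its position and extra keywords are now tolerated after it. A shape-level sketch of calling `UniSpeechForCTC` with labels, using a randomly initialized model rather than a released checkpoint (so the loss value itself is meaningless):

    import torch
    from transformers import UniSpeechConfig, UniSpeechForCTC

    config = UniSpeechConfig(vocab_size=32)   # default-sized model, random weights
    model = UniSpeechForCTC(config)

    input_values = torch.randn(1, 16000)      # 1 second of 16 kHz audio
    labels = torch.tensor([[5, 9, 11]])       # token ids < vocab_size
    out = model(input_values=input_values, labels=labels)
    print(out.loss)                           # CTC loss over the random logits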
@@ -1366,6 +1369,7 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -244,6 +244,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel, Wav2Vec2Model):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechBaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -351,6 +352,7 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, UniSpeechForPreTrainingOutput]:
         r"""
         Example: