transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/video_llama_3/modeling_video_llama_3.py

@@ -25,6 +25,7 @@ import torch
 import torch.nn as nn
 from torch.nn import LayerNorm
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
@@ -43,6 +44,8 @@ class VideoLlama3VisionRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
+        self.dim = dim
+        self.theta = theta
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
@@ -380,6 +383,12 @@ class VideoLlama3PreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = True
     _supports_attention_backend = True
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, VideoLlama3VisionRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
 
 class VideoLlama3VisionModel(VideoLlama3PreTrainedModel):
     config: VideoLlama3VisionConfig
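The two hunks above go together: inv_freq is registered as a non-persistent buffer, so it never appears in checkpoints, and the new _init_weights hook has to rebuild it from the module's own hyperparameters, which is why dim and theta are now stored on the instance. A minimal sketch in plain PyTorch (no transformers internals; reinit_rotary is a hypothetical stand-in for the hook):

import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim      # kept so the buffer can be rebuilt after loading
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        # Non-persistent: excluded from state_dict, hence from checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

def reinit_rotary(module: nn.Module) -> None:
    # Mirrors the _init_weights hook: recompute the buffer from the stored
    # hyperparameters instead of expecting it in the checkpoint.
    if isinstance(module, RotaryEmbedding):
        inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
        module.inv_freq.copy_(inv_freq)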
@@ -855,6 +864,7 @@ class VideoLlama3ForConditionalGeneration(VideoLlama3PreTrainedModel, GenerationMixin):
         video_grid_thw: Optional[torch.LongTensor] = None,
         video_merge_sizes: Optional[torch.LongTensor] = None,
         video_compression_mask: Optional[torch.BoolTensor] = None,
+        is_first_iteration: Optional[bool] = False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -874,10 +884,11 @@ class VideoLlama3ForConditionalGeneration(VideoLlama3PreTrainedModel, GenerationMixin):
             video_merge_sizes=video_merge_sizes,
             video_compression_mask=video_compression_mask,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if model_inputs["cache_position"][0] != 0:
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None
 
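The is_first_iteration flag replaces the cache_position-based check: vision inputs are encoded once, on the first forward pass of a generate() call, after which their features live in the KV cache and later decode steps must not re-encode them. A self-contained sketch of that gating (hypothetical helper, assuming the same semantics as the hunk above):

from typing import Optional
import torch

def gate_vision_inputs(
    model_inputs: dict,
    pixel_values: Optional[torch.Tensor],
    is_first_iteration: bool,
    use_cache: bool = True,
) -> dict:
    # Same condition as the diff: once the first pass has merged image
    # features into the KV cache, later steps drop the pixel tensors.
    if not is_first_iteration and use_cache:
        model_inputs["pixel_values"] = None
    else:
        model_inputs["pixel_values"] = pixel_values
    return model_inputs

pixels = torch.zeros(1, 3, 224, 224)
assert gate_vision_inputs({}, pixels, is_first_iteration=True)["pixel_values"] is not None
assert gate_vision_inputs({}, pixels, is_first_iteration=False)["pixel_values"] is None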
@@ -21,6 +21,7 @@ import torch.nn as nn
21
21
  import torch.nn.functional as F
22
22
  from torch.nn import LayerNorm
23
23
 
24
+ from ... import initialization as init
24
25
  from ...cache_utils import Cache
25
26
  from ...configuration_utils import PreTrainedConfig
26
27
  from ...feature_extraction_utils import BatchFeature
@@ -433,6 +434,12 @@ class VideoLlama3PreTrainedModel(Qwen2VLPreTrainedModel):
     config: VideoLlama3Config
    _no_split_modules = ["VideoLlama3VisionEncoderLayer"]
 
+    def _init_weights(self, module):
+        PreTrainedModel._init_weights(self, module)
+        if isinstance(module, VideoLlama3VisionRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
 
 class VideoLlama3VisionModel(VideoLlama3PreTrainedModel):
     config: VideoLlama3VisionConfig
@@ -842,6 +849,7 @@ class VideoLlama3ForConditionalGeneration(Qwen2VLForConditionalGeneration):
         video_grid_thw: Optional[torch.LongTensor] = None,
         video_merge_sizes: Optional[torch.LongTensor] = None,
         video_compression_mask: Optional[torch.BoolTensor] = None,
+        is_first_iteration: Optional[bool] = False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -861,10 +869,11 @@ class VideoLlama3ForConditionalGeneration(Qwen2VLForConditionalGeneration):
             video_merge_sizes=video_merge_sizes,
             video_compression_mask=video_compression_mask,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if model_inputs["cache_position"][0] != 0:
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None
 
@@ -599,6 +599,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -610,12 +611,15 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if cache_position[0] == 0:
-            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-            # Otherwise we need pixel values to be passed to model
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration, if available.
+            # In subsequent iterations they are already merged with the text and cached.
+            # NOTE: the first iteration doesn't have to be prefill; it can be the first
+            # iteration with a question and a cached system prompt (continuing generation from cache)
             model_inputs["pixel_values_images"] = pixel_values_images
             model_inputs["pixel_values_videos"] = pixel_values_videos
 
@@ -115,7 +115,7 @@ class ViltConfig(PreTrainedConfig):
         num_channels=3,
         qkv_bias=True,
         max_image_length=-1,
-        tie_word_embeddings=False,
+        tie_word_embeddings=True,
         num_images=-1,
         **kwargs,
     ):
@@ -142,7 +142,7 @@ class ViltConfig(PreTrainedConfig):
         self.qkv_bias = qkv_bias
         self.max_image_length = max_image_length
         self.num_images = num_images
-        self.tie_encoder_decoder = True
+        self.tie_word_embeddings = True  # force it
 
 
 __all__ = ["ViltConfig"]
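
The two ViltConfig hunks fix a mix-up: rc0 set `tie_encoder_decoder` where tied *word embeddings* were intended, and rc2 both defaults and forces `tie_word_embeddings=True`. Tying means the output projection reuses the input embedding matrix, so the checkpoint stores the weight once and the two can never diverge. A minimal sketch of the mechanism, independent of ViLT:

```python
import torch.nn as nn

vocab_size, hidden = 30522, 768
embed = nn.Embedding(vocab_size, hidden)             # input token embeddings
lm_head = nn.Linear(hidden, vocab_size, bias=False)  # output projection

# Tying: both modules point at the same parameter tensor.
lm_head.weight = embed.weight
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()
```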
@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
@@ -516,6 +517,12 @@ class ViltPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["ViltEmbeddings", "ViltSelfAttention"]
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, TextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
+
 
 @auto_docstring
 class ViltModel(ViltPreTrainedModel):
@@ -556,6 +563,7 @@ class ViltModel(ViltPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[BaseModelOutputWithPooling, tuple[torch.FloatTensor]]:
         r"""
         image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
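
A recurring change throughout the rest of this diff: `forward` signatures gain a trailing `**kwargs`, so shared infrastructure can pass newer arguments uniformly without breaking models that do not consume them. A hedged sketch of why that keeps call sites uniform (the classes and the extra argument here are hypothetical):

```python
import torch.nn as nn

class OldStyle(nn.Module):
    def forward(self, x, return_dict=None):            # rejects unknown kwargs
        return x

class NewStyle(nn.Module):
    def forward(self, x, return_dict=None, **kwargs):  # silently absorbs them
        return x

x = object()
NewStyle()(x, return_dict=True, is_first_iteration=True)  # fine
# OldStyle()(x, is_first_iteration=True)  # would raise TypeError
```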
@@ -708,6 +716,7 @@ class ViltForMaskedLM(ViltPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]:
         r"""
         image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
@@ -875,6 +884,7 @@ class ViltForQuestionAnswering(ViltPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
         r"""
         image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
@@ -979,6 +989,7 @@ class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
         r"""
         image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
@@ -1082,6 +1093,7 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[ViltForImagesAndTextClassificationOutput, tuple[torch.FloatTensor]]:
         r"""
         image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
@@ -1210,6 +1222,7 @@ class ViltForTokenClassification(ViltPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[TokenClassifierOutput, tuple[torch.FloatTensor]]:
         r"""
         image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
@@ -415,6 +415,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -426,12 +427,15 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if cache_position[0] == 0:
-            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-            # Otherwise we need pixel values to be passed to model
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration, if available.
+            # In subsequent iterations they are already merged with the text and cached.
+            # NOTE: the first iteration doesn't have to be prefill; it can be the first
+            # iteration with a question and a cached system prompt (continuing generation from cache)
             model_inputs["pixel_values"] = pixel_values
 
         return model_inputs
@@ -184,6 +184,7 @@ class VisionTextDualEncoderModel(PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], CLIPOutput]:
         r"""
         return_loss (`bool`, *optional*):
@@ -473,6 +473,8 @@ class VisualBertPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, VisualBertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, VisualBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 @dataclass
@@ -550,6 +552,7 @@ class VisualBertModel(VisualBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
         r"""
         visual_embeds (`torch.FloatTensor` of shape `(batch_size, visual_seq_length, visual_embedding_dim)`, *optional*):
@@ -735,6 +738,7 @@ class VisualBertForPreTraining(VisualBertPreTrainedModel):
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
         sentence_image_labels: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], VisualBertForPreTrainingOutput]:
         r"""
         visual_embeds (`torch.FloatTensor` of shape `(batch_size, visual_seq_length, visual_embedding_dim)`, *optional*):
@@ -877,6 +881,7 @@ class VisualBertForMultipleChoice(VisualBertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1063,6 +1068,7 @@ class VisualBertForQuestionAnswering(VisualBertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
         r"""
         visual_embeds (`torch.FloatTensor` of shape `(batch_size, visual_seq_length, visual_embedding_dim)`, *optional*):
@@ -1199,6 +1205,7 @@ class VisualBertForVisualReasoning(VisualBertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
         r"""
         visual_embeds (`torch.FloatTensor` of shape `(batch_size, visual_seq_length, visual_embedding_dim)`, *optional*):
@@ -1372,6 +1379,7 @@ class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
         return_dict: Optional[bool] = None,
         region_to_phrase_position: Optional[torch.LongTensor] = None,
         labels: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
         r"""
         visual_embeds (`torch.FloatTensor` of shape `(batch_size, visual_seq_length, visual_embedding_dim)`, *optional*):
@@ -630,6 +630,7 @@ class VitDetModel(VitDetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Examples:
@@ -706,6 +707,7 @@ class VitDetBackbone(VitDetPreTrainedModel, BackboneMixin):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BackboneOutput:
         r"""
         Examples:
@@ -36,7 +36,7 @@ class VitMatteConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`PreTrainedConfig` or `dict`, *optional*, defaults to `VitDetConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `VitDetConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -152,7 +152,6 @@ class VitMatteImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
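
The manual `torch.stack` is dropped here (and again in VitPose further down) because `BatchFeature` already converts its contents when `tensor_type` is set, making the explicit stack redundant. A hedged sketch of that behavior; `BatchFeature` is the real class, but the exact conversion internals for lists of tensors are assumed from this change rather than verified:

```python
import torch
from transformers.feature_extraction_utils import BatchFeature

# A list of same-shaped image tensors, as produced per group by the processor.
images = [torch.zeros(3, 4, 4), torch.ones(3, 4, 4)]

# With tensor_type set, BatchFeature takes care of producing a batched
# tensor, so an explicit torch.stack beforehand is unnecessary.
feat = BatchFeature(data={"pixel_values": images}, tensor_type="pt")
print(feat["pixel_values"].shape)  # expected: torch.Size([2, 3, 4, 4])
```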
@@ -65,6 +65,10 @@ class VitMattePreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        if getattr(module, "running_mean", None) is not None:
+            init.zeros_(module.running_mean)
+            init.ones_(module.running_var)
+            init.zeros_(module.num_batches_tracked)
 
 
 class VitMatteBasicConv3x3(nn.Module):
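
The `running_mean` check above targets batch-norm layers: their running statistics are buffers, not parameters, so a parameter-only init pass would miss them. The chosen values (zero mean, unit variance, zero step count) match PyTorch's own defaults, as a quick check shows:

```python
import torch.nn as nn

bn = nn.BatchNorm2d(8)
print(bn.running_mean.unique())       # tensor([0.])  -> init.zeros_
print(bn.running_var.unique())        # tensor([1.])  -> init.ones_
print(bn.num_batches_tracked.item())  # 0             -> init.zeros_
```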
@@ -234,6 +238,7 @@ class VitMatteForImageMatting(VitMattePreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -36,7 +36,7 @@ class VitPoseConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`PreTrainedConfig` or `dict`, *optional*, defaults to `VitPoseBackboneConfig()`):
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `VitPoseBackboneConfig()`):
             The configuration of the backbone model. Currently, only `backbone_config` with `vitpose_backbone` as `model_type` is supported.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -156,7 +156,6 @@ class VitPoseImageProcessorFast(BaseImageProcessorFast):
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
 
         # Stack into batch tensor
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
@@ -1275,6 +1275,7 @@ class VitsModel(VitsPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.FloatTensor] = None,
+        **kwargs,
     ) -> Union[tuple[Any], VitsModelOutput]:
         r"""
         speaker_id (`int`, *optional*):
@@ -1088,6 +1088,7 @@ class VJEPA2ForVideoClassification(VJEPA2PreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -505,11 +505,11 @@ class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
         # Overwritten -- we should not pass input_features when we are in cached decoding stage
 
         input_features = kwargs.pop("input_features", None)
-        cache_position = kwargs.get("cache_position")
+        is_first_iteration = kwargs.get("is_first_iteration", False)
 
         model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
 
-        if cache_position is not None and cache_position[0] == 0:
+        if is_first_iteration or not kwargs.get("use_cache", True):
             # input_features should only be passed when we are not in cached decoding stage
             model_inputs["input_features"] = input_features
 
@@ -267,11 +267,11 @@ class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
         # Overwritten -- we should not pass input_features when we are in cached decoding stage
 
         input_features = kwargs.pop("input_features", None)
-        cache_position = kwargs.get("cache_position")
+        is_first_iteration = kwargs.get("is_first_iteration", False)
 
         model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
 
-        if cache_position is not None and cache_position[0] == 0:
+        if is_first_iteration or not kwargs.get("use_cache", True):
             # input_features should only be passed when we are not in cached decoding stage
             model_inputs["input_features"] = input_features
 
@@ -1340,6 +1340,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Wav2Vec2BaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1453,6 +1454,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Wav2Vec2ForPreTrainingOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1628,6 +1630,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -1729,6 +1732,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
@@ -1840,6 +1844,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -1943,6 +1948,7 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -2114,6 +2120,7 @@ class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, XVectorOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -74,18 +74,17 @@ class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
         super().__init__()
         self.max_len = config.max_source_positions
         self.d_model = config.hidden_size
-        self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer("pe", self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)), persistent=False)
 
-    def extend_pe(self, x):
+    def extend_pe(self, x, pe=None):
         # Reset the positional encodings
-        if self.pe is not None:
+        if pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
-                if self.pe.dtype != x.dtype or self.pe.device != x.device:
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
+            if pe.size(1) >= x.size(1) * 2 - 1:
+                if pe.dtype != x.dtype or pe.device != x.device:
+                    pe = pe.to(dtype=x.dtype, device=x.device)
+                return pe
         # Suppose `i` is the position of query vector and `j` is the
         # position of key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -106,10 +105,10 @@ class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
+        return pe.to(device=x.device, dtype=x.dtype)
 
     def forward(self, hidden_states: torch.Tensor):
-        self.extend_pe(hidden_states)
+        self.pe = self.extend_pe(hidden_states, self.pe)
         start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
         end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
         relative_position_embeddings = self.pe[:, start_idx:end_idx]
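
The refactor above turns `extend_pe` functional: it now returns the encoding table instead of mutating `self.pe`, so the same function can build the buffer in `__init__` (via `register_buffer`) and re-materialize it in `_init_weights` below. As a minimal illustration of the style (a plain absolute sinusoidal encoding, deliberately simpler than the relative variant above):

```python
import math
import torch

def sinusoidal_pe(length: int, d_model: int) -> torch.Tensor:
    """Absolute sinusoidal encoding, shape (1, length, d_model).

    The point is the functional style: compute and *return* the table
    instead of writing to self.pe, so callers decide where it is stored.
    """
    position = torch.arange(length, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                         * -(math.log(10000.0) / d_model))
    pe = torch.zeros(length, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)

print(sinusoidal_pe(16, 8).shape)  # torch.Size([1, 16, 8])
```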
@@ -749,6 +748,13 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
             init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1))
         elif isinstance(module, AMSoftmaxLoss):  # noqa: F821
             init.normal_(module.weight)
+        elif isinstance(module, Wav2Vec2BertRotaryPositionalEmbedding):
+            dim = self.config.hidden_size // self.config.num_attention_heads
+            base = self.config.rotary_embedding_base
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+            init.copy_(module.inv_freq, inv_freq)
+        elif isinstance(module, Wav2Vec2BertRelPositionalEmbedding):
+            init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, module.max_len)))
 
     # Ignore copy
     def _get_feat_extract_output_lengths(
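
The common thread in all the `_init_weights` additions in this release: buffers registered with `persistent=False` are excluded from `state_dict`, so they are absent from checkpoints and must be recomputed whenever the model is materialized (for example after the weights are loaded onto real devices). A quick demonstration of the persistence rule in plain PyTorch:

```python
import torch
import torch.nn as nn

class WithBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("kept", torch.ones(2))                           # saved in checkpoints
        self.register_buffer("recomputed", torch.ones(2), persistent=False)  # not saved

m = WithBuffers()
print(sorted(m.state_dict().keys()))  # ['kept'] -- 'recomputed' must be rebuilt at init time
```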
@@ -994,6 +1000,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Wav2Vec2BertBaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1086,6 +1093,7 @@ class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
@@ -1192,6 +1200,7 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1282,6 +1291,7 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2BertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1440,6 +1450,7 @@ class Wav2Vec2BertForXVector(Wav2Vec2BertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, XVectorOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -621,6 +621,13 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel):
             init.constant_(module.layer_weights, 1.0 / (self.config.num_hidden_layers + 1))
         elif isinstance(module, AMSoftmaxLoss):  # noqa: F821
             init.normal_(module.weight)
+        elif isinstance(module, Wav2Vec2BertRotaryPositionalEmbedding):
+            dim = self.config.hidden_size // self.config.num_attention_heads
+            base = self.config.rotary_embedding_base
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+            init.copy_(module.inv_freq, inv_freq)
+        elif isinstance(module, Wav2Vec2BertRelPositionalEmbedding):
+            init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, module.max_len)))
 
     # Ignore copy
     def _get_feat_extract_output_lengths(
@@ -702,6 +709,7 @@ class Wav2Vec2BertModel(Wav2Vec2Model, Wav2Vec2BertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Wav2Vec2BertBaseModelOutput]:
         r"""
         mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -768,6 +776,7 @@ class Wav2Vec2BertForCTC(Wav2Vec2ConformerForCTC):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
@@ -856,6 +865,7 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2ForSequenceClassification):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -926,6 +936,7 @@ class Wav2Vec2BertForAudioFrameClassification(Wav2Vec2ConformerForAudioFrameClas
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -987,6 +998,7 @@ class Wav2Vec2BertForXVector(Wav2Vec2ConformerForXVector):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, XVectorOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):