transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835) hide show
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -42,7 +42,7 @@ from .image_processing_mobilevit import MobileVitImageProcessorKwargs
42
42
 
43
43
  @auto_docstring
44
44
  class MobileViTImageProcessorFast(BaseImageProcessorFast):
45
- resample = PILImageResampling.BILINEAR
45
+ resample = PILImageResampling.BICUBIC
46
46
  size = {"shortest_edge": 224}
47
47
  default_to_square = False
48
48
  crop_size = {"height": 256, "width": 256}
@@ -182,7 +182,6 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
182
182
  processed_images = reorder_images(processed_images_grouped, grouped_images_index)
183
183
 
184
184
  # Stack all processed images if return_tensors is specified
185
- processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
186
185
 
187
186
  return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
188
187
 
@@ -615,6 +615,10 @@ class MobileViTPreTrainedModel(PreTrainedModel):
615
615
  init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
616
616
  if module.bias is not None:
617
617
  init.zeros_(module.bias)
618
+ if getattr(module, "running_mean", None) is not None:
619
+ init.zeros_(module.running_mean)
620
+ init.ones_(module.running_var)
621
+ init.zeros_(module.num_batches_tracked)
618
622
  elif isinstance(module, nn.LayerNorm):
619
623
  init.zeros_(module.bias)
620
624
  init.ones_(module.weight)
@@ -659,6 +663,7 @@ class MobileViTModel(MobileViTPreTrainedModel):
         pixel_values: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -725,6 +730,7 @@ class MobileViTForImageClassification(MobileViTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -889,6 +895,7 @@ class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SemanticSegmenterOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -582,6 +582,10 @@ class MobileViTV2PreTrainedModel(PreTrainedModel):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.GroupNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
@@ -623,6 +627,7 @@ class MobileViTV2Model(MobileViTV2PreTrainedModel):
         pixel_values: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -691,6 +696,7 @@ class MobileViTV2ForImageClassification(MobileViTV2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -858,6 +864,7 @@ class MobileViTV2ForSemanticSegmentation(MobileViTV2PreTrainedModel):
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SemanticSegmenterOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
@@ -45,6 +45,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, is_flash_attn_2_available, logging
+from ...utils.generic import maybe_autocast
 from ...utils.import_utils import is_triton_available
 from .configuration_modernbert import ModernBertConfig
 
@@ -267,7 +268,7 @@ class ModernBertRotaryEmbedding(nn.Module):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-            setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq)
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
 
     @staticmethod
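Storing original_inv_freq as a non-persistent buffer on a cloned tensor matters because plain attributes are invisible to Module.to() and friends, while buffers travel with the module (and the clone avoids aliasing inv_freq). A minimal illustration:

    import torch
    from torch import nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            inv_freq = torch.arange(4.0)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.plain_attr = inv_freq  # old style: ignored by .to()
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)  # new style

    m = Demo().to(torch.float16)
    # Buffers follow the module's dtype; the plain attribute stays float32.
    print(m.inv_freq.dtype, m.original_inv_freq.dtype, m.plain_attr.dtype)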
@@ -316,7 +317,7 @@ class ModernBertRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * attention_scaling
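maybe_autocast itself is not shown in this diff; a plausible reading, offered purely as an assumption, is a thin wrapper that degrades to a no-op context on device types without autocast support instead of raising:

    import contextlib
    import torch

    def maybe_autocast(device_type: str, enabled: bool = True, **kwargs):
        # Hypothetical sketch of transformers.utils.generic.maybe_autocast.
        try:
            return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
        except RuntimeError:  # unsupported device type
            return contextlib.nullcontext()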
@@ -676,6 +677,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        elif isinstance(module, ModernBertRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
+        elif isinstance(module, ModernBertUnpaddedRotaryEmbedding):
+            inv_freq = module._compute_inv_freq()
+            init.copy_(module.inv_freq, inv_freq)
 
     def _check_and_adjust_attn_implementation(
         self, attn_implementation: Optional[str], is_init_check: bool = False
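These branches work because inv_freq is a pure function of the config, so _init_weights can rebuild it deterministically rather than relying on (non-persistent) checkpoint state. For reference, the standard rotary-embedding frequency formula:

    import torch

    def default_inv_freq(base: float, dim: int) -> torch.Tensor:
        # inv_freq[i] = 1 / base**(2i / dim), the usual RoPE frequencies
        return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

    print(default_inv_freq(base=10000.0, dim=8))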
@@ -852,6 +864,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
         r"""
         sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1345,6 +1358,7 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -35,7 +35,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_rope_utils import RopeParameters
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, is_flash_attn_2_available, logging
 from ...utils.import_utils import is_triton_available
@@ -871,6 +871,17 @@ class ModernBertPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        elif isinstance(module, ModernBertRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
+        elif isinstance(module, ModernBertUnpaddedRotaryEmbedding):
+            inv_freq = module._compute_inv_freq()
+            init.copy_(module.inv_freq, inv_freq)
 
     def _check_and_adjust_attn_implementation(
         self, attn_implementation: Optional[str], is_init_check: bool = False
@@ -975,6 +986,7 @@ class ModernBertModel(ModernBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
         r"""
         sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1468,6 +1480,7 @@ class ModernBertForTokenClassification(ModernBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -39,7 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_modernbert_decoder import ModernBertDecoderConfig
 
 
@@ -119,7 +119,7 @@ class ModernBertDecoderRotaryEmbedding(nn.Module):
             rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
             curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
             self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
-            setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq)
+            self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
             setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
 
     @staticmethod
@@ -168,7 +168,7 @@ class ModernBertDecoderRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * attention_scaling
@@ -342,7 +342,7 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.attn_norm(hidden_states)
@@ -443,6 +443,14 @@ class ModernBertDecoderPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        elif isinstance(module, ModernBertDecoderRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
 
 
 @auto_docstring
@@ -477,7 +485,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
         if (input_ids is None) == (inputs_embeds is None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -489,7 +497,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
             batch_size, seq_length = inputs_embeds.shape[:2]
 
         # Handle past_key_values and cache setup
-        if use_cache and past_key_values is None and not self.training:
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
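Dropping `not self.training` means a cache is now allocated whenever use_cache is requested, in train and eval mode alike. Sketch of the relaxed guard (cache construction details vary by release):

    from transformers import DynamicCache

    use_cache, past_key_values = True, None
    if use_cache and past_key_values is None:  # previously also required eval mode
        past_key_values = DynamicCache()  # rc2 passes config=self.config here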
@@ -527,13 +535,12 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
-        for idx, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
             hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 past_key_values=past_key_values,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 position_ids=position_ids,
                 **kwargs,
@@ -583,7 +590,7 @@ class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin):
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -686,7 +693,7 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -28,7 +28,7 @@ from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
-from ...modeling_rope_utils import RopeParameters
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, RopeParameters
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
@@ -394,7 +394,7 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.attn_norm(hidden_states)
@@ -482,6 +482,14 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
             init.ones_(module.weight)
             if module.bias is not None:
                 init.zeros_(module.bias)
+        elif isinstance(module, ModernBertDecoderRotaryEmbedding):
+            for layer_type in module.layer_types:
+                rope_init_fn = module.compute_default_rope_parameters
+                if module.rope_type[layer_type] != "default":
+                    rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
+                curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
+                init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
+                init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
 
     def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check):
         raise AttributeError("No need to inherit!")
@@ -525,7 +533,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
         if (input_ids is None) == (inputs_embeds is None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -537,7 +545,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
             batch_size, seq_length = inputs_embeds.shape[:2]
 
         # Handle past_key_values and cache setup
-        if use_cache and past_key_values is None and not self.training:
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
@@ -575,13 +583,12 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
-        for idx, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
             hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                 position_embeddings=position_embeddings[decoder_layer.attention_type],
                 past_key_values=past_key_values,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 position_ids=position_ids,
                 **kwargs,
@@ -631,7 +638,7 @@ class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin):
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -734,7 +741,7 @@ class ModernBertDecoderForSequenceClassification(ModernBertDecoderPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -30,6 +30,7 @@ from transformers.utils.generic import OutputRecorder, check_model_inputs
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -45,6 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import maybe_autocast
 from .configuration_moonshine import MoonshineConfig
 
 
@@ -96,7 +98,7 @@ class MoonshineRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -137,7 +139,7 @@ class MoonshineRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -233,6 +235,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class MoonshineAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -264,7 +267,6 @@ class MoonshineAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb
 
         # Pad head dimension to the next specified multiple.
         if self.config.pad_head_dim_to_multiple_of is not None:
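use_kernelized_func comes from transformers.integrations and supersedes the manual `self.rotary_fn = apply_rotary_pos_emb` assignment removed above. Its exact behavior is not visible in this diff; a hedged sketch of the general pattern it suggests, attaching a reference function to the class so an optimized kernel can later be swapped in:

    def use_kernelized_func_sketch(reference_fn):
        # Hypothetical stand-in for transformers.integrations.use_kernelized_func.
        def decorator(cls):
            # Assumption: the reference implementation is stored on the class
            # (e.g. as rotary_fn) and may be replaced by a faster hub kernel.
            cls.rotary_fn = staticmethod(reference_fn)
            return cls
        return decorator

    @use_kernelized_func_sketch(lambda q, k, cos, sin: (q, k))  # dummy rotary fn
    class DummyAttention:
        pass

    print(DummyAttention.rotary_fn("q", "k", None, None))  # ('q', 'k')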
@@ -34,6 +34,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast,
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.generic import maybe_autocast
 from ..auto.modeling_auto import AutoModel
 from .configuration_moshi import MoshiConfig, MoshiDepthConfig
 
@@ -288,7 +289,7 @@ class MoshiRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -327,7 +328,7 @@ class MoshiRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -608,8 +609,8 @@ class MoshiFlashAttention2(MoshiAttention):
                     else torch.get_autocast_gpu_dtype()
                 )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
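The dtype fallback for flash-attention inputs no longer reads the private _pre_quantization_dtype attribute. A sketch of the revised selection logic:

    import torch

    def pick_target_dtype(config, weight_dtype: torch.dtype) -> torch.dtype:
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()
        if hasattr(config, "quantization_config"):
            return config.dtype  # dtype recorded on the config of a quantized model
        return weight_dtype      # unquantized: follow the projection weights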
 
@@ -868,6 +869,8 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
         self.gradient_checkpointing = False
         self.config = config
 
+        self.post_init()
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
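MoshiDepthDecoder previously never called post_init(), so its submodules missed the _init_weights pass. The expected constructor shape for any PreTrainedModel subclass, as a minimal sketch:

    from torch import nn
    from transformers import PretrainedConfig, PreTrainedModel

    class TinyModel(PreTrainedModel):
        config_class = PretrainedConfig

        def __init__(self, config):
            super().__init__(config)
            self.proj = nn.Linear(4, 4)
            self.post_init()  # weight init and final wiring over all submodules

    model = TinyModel(PretrainedConfig())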
@@ -882,6 +885,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
         position_ids: Optional[torch.LongTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         """
         Args:
@@ -957,7 +961,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
             )
             use_cache = False
 
-        if use_cache and past_key_values is None and not self.training:
+        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)
 
         past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
@@ -1228,6 +1232,7 @@ class MoshiModel(MoshiPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -2175,6 +2180,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
         user_delay_pattern_mask=None,
         moshi_delay_pattern_mask=None,
         kwargs_depth_decoder=None,
+        is_first_iteration=False,
         blank_user_audio_codes: Optional[torch.FloatTensor] = None,
         **kwargs,
     ):
@@ -2186,49 +2192,21 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
         # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
         # (we can't check exception 3 while compiling)
 
-        if past_key_values is not None:
-            if (
-                inputs_embeds is not None  # Exception 1
-                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
-            ):
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
-        else:
-            model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
-
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if model_inputs["inputs_embeds"] is not None:
-                batch_size, sequence_length, _ = inputs_embeds.shape
-                device = inputs_embeds.device
-            else:
-                batch_size, sequence_length = input_ids.shape
-                device = input_ids.device
-
-            attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position(
-                attention_mask,
-                sequence_length=sequence_length,
-                target_length=past_key_values.get_max_cache_shape(),
-                dtype=self.decoder.lm_head.weight.dtype,
-                device=device,
-                cache_position=cache_position,
-                batch_size=batch_size,
-                config=self.config,
-                past_key_values=past_key_values,
-            )
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-            }
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            logits_to_keep=logits_to_keep,
+            user_delay_pattern_mask=user_delay_pattern_mask,
+            moshi_delay_pattern_mask=moshi_delay_pattern_mask,
+            kwargs_depth_decoder=kwargs_depth_decoder,
+            is_first_iteration=is_first_iteration,
+            blank_user_audio_codes=blank_user_audio_codes,
+            **kwargs,
         )
 
         # 2. Now that everything is prepared, generate audio_codes using the depth decoder
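The rewritten override delegates the generic bookkeeping (input-id slicing, inputs_embeds handling, cache positions, static-cache masks) to GenerationMixin and only threads the Moshi-specific arguments through. The pattern, sketched with a stand-in base class:

    class _GenerationBase:
        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            # Stand-in for GenerationMixin's generic implementation.
            return {"input_ids": input_ids, **kwargs}

    class MyModel(_GenerationBase):
        def prepare_inputs_for_generation(self, input_ids, my_mask=None, **kwargs):
            model_inputs = super().prepare_inputs_for_generation(input_ids, **kwargs)
            model_inputs["my_mask"] = my_mask  # only model-specific extras remain here
            return model_inputs

    print(MyModel().prepare_inputs_for_generation([1, 2, 3], my_mask="m"))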
@@ -2267,11 +2245,6 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
             model_inputs["input_ids"] = None
             model_inputs["inputs_embeds"] = inputs_embeds
 
-        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
-        for key, value in kwargs.items():
-            if key not in model_inputs:
-                model_inputs[key] = value
-
         return model_inputs
 
     def _update_model_kwargs_for_generation(
@@ -52,6 +52,8 @@ class MPNetPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, MPNetLMHead):
             init.zeros_(module.bias)
+        elif isinstance(module, MPNetEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 class MPNetEmbeddings(nn.Module):
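The added branch re-materializes the position_ids buffer kept by MPNetEmbeddings instead of depending on checkpoint state. What the copied tensor looks like, with an illustrative size:

    import torch

    max_position_embeddings = 512  # illustrative value
    position_ids = torch.arange(max_position_embeddings).expand((1, -1))
    print(position_ids.shape)  # torch.Size([1, 512])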
@@ -488,6 +490,7 @@ class MPNetForMaskedLM(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -577,6 +580,7 @@ class MPNetForSequenceClassification(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -656,6 +660,7 @@ class MPNetForMultipleChoice(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -748,6 +753,7 @@ class MPNetForTokenClassification(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -831,6 +837,7 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict