transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
@@ -31,7 +31,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -39,8 +39,8 @@ from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig


@@ -226,6 +226,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3VLMoeTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -252,7 +253,6 @@ class Qwen3VLMoeTextAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3VLMoeTextRMSNorm(
             self.head_dim, eps=config.rms_norm_eps
         )  # unlike olmo, only on the head dim!
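The two hunks above are one refactor: the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment moves to a class-level `@use_kernelized_func(apply_rotary_pos_emb)` decorator imported from `transformers.integrations` (whose `hub_kernels.py` also changes in this release). The decorator's body is not shown in this diff, so the following is only a hypothetical sketch of the pattern, assuming it attaches the given default as `rotary_fn` and prefers a hub-registered kernel when one exists; the registry and attribute names are invented for illustration:

# Hypothetical sketch; the real use_kernelized_func lives in
# transformers.integrations and may differ substantially.
from typing import Callable

_KERNEL_REGISTRY: dict[str, Callable] = {}  # imagined: filled from the kernels hub


def use_kernelized_func(default_func: Callable):
    def decorator(cls):
        # Prefer a hub kernel registered under the function's name, else the default.
        kernel = _KERNEL_REGISTRY.get(default_func.__name__, default_func)
        cls.rotary_fn = staticmethod(kernel)  # replaces the old per-__init__ assignment
        return cls

    return decorator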
@@ -860,7 +860,7 @@ class Qwen3VLMoeTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
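`torch.autocast` raises a RuntimeError when handed a device type it does not support, so rc1 routes these rotary-embedding call sites through `maybe_autocast` from `transformers.utils.generic` (the same swap appears in recurrent_gemma below). The helper's real body is not part of this diff; a minimal sketch, assuming it simply degrades to a no-op context manager:

# Sketch only; the actual helper is transformers.utils.generic.maybe_autocast.
from contextlib import nullcontext

import torch


def maybe_autocast(device_type=None, **kwargs):
    """Return torch.autocast when the device type supports it, else a no-op."""
    try:
        return torch.autocast(device_type=device_type, **kwargs)
    except RuntimeError:  # unsupported device type
        return nullcontext()


with maybe_autocast(device_type="cpu", enabled=False):  # force float32 math
    freqs = torch.randn(2, 4) @ torch.randn(4, 3)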
@@ -1358,44 +1358,19 @@ class Qwen3VLMoeModel(Qwen3VLMoePreTrainedModel):
             deepstack_visual_embeds = deepstack_video_embeds

         if position_ids is None:
-            attention_mask_tensor = (
-                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
-            )
-            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
-                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
-                # Only apply conversion for floating point tensors (inverted masks)
-                if attention_mask_tensor.dtype.is_floating_point:
-                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
-                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()
-
-            # Calculate RoPE index once per generation in the pre-fill stage only.
-            # When compiling, we can't check tensor values thus we check only input length
-            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
-            # models currently cannot do asssisted decoding
-            prefill_compiled_stage = is_torchdynamo_compiling() and (
-                (input_ids is not None and input_ids.shape[1] != 1)
-                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
-            )
-            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
-                (cache_position is not None and cache_position[0] == 0)
-                or (past_key_values is None or past_key_values.get_seq_length() == 0)
-            )
-            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if self.rope_deltas is None or past_key_values_length == 0:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
                     video_grid_thw,
-                    attention_mask=attention_mask_tensor,
+                    attention_mask=attention_mask,
                 )
                 self.rope_deltas = rope_deltas
             # then use the prev pre-calculated rope-deltas to get the correct position ids
             else:
                 batch_size, seq_length, _ = inputs_embeds.shape
-                delta = (
-                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
-                    if cache_position is not None
-                    else 0
-                )
+                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                 position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 if cache_position is not None:  # otherwise `deltas` is an int `0`
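The rewrite drops all `is_torchdynamo_compiling()` branching: an empty KV cache (`past_key_values.get_seq_length() == 0`) now signals prefill, and later decode steps offset fresh `arange` positions by the cached `rope_deltas`. With hypothetical toy values, the decode-step arithmetic works out as:

import torch

# Toy decode step: 5 tokens already in the cache, per-sample rope delta of -2.
past_key_values_length = 5
rope_deltas = torch.tensor([[-2]])  # shape (batch, 1), cached at prefill
batch_size, seq_length = 1, 1       # one new token

delta = past_key_values_length + rope_deltas                         # tensor([[3]])
position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
position_ids = position_ids + delta                                  # tensor([[3]])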
@@ -1532,7 +1507,7 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin):
     def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
         return self.model.get_image_features(pixel_values, image_grid_thw)

-    @check_model_inputs
+    @can_return_tuple
     def forward(
         self,
         input_ids: torch.LongTensor = None,
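`check_model_inputs` (which records sub-module outputs such as attentions via hooks) gives way to the lighter `can_return_tuple`, and the next hunk compensates by forwarding `hidden_states` and `attentions` explicitly. `can_return_tuple` preserves the long-standing contract that callers may opt out of `ModelOutput`; a self-contained sketch of that contract (illustrative, not the library's exact code):

import functools
from dataclasses import astuple, dataclass


def can_return_tuple(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict: bool = True, **kwargs):
        output = forward(self, *args, **kwargs)
        return output if return_dict else output.to_tuple()

    return wrapper


@dataclass
class DummyOutput:
    logits: int

    def to_tuple(self):
        return astuple(self)


class DummyModel:
    @can_return_tuple
    def forward(self):
        return DummyOutput(logits=1)


assert DummyModel().forward(return_dict=False) == (1,)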
@@ -1642,6 +1617,8 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin):
             aux_loss=aux_loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )

@@ -1677,8 +1654,33 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin):
             **kwargs,
         )

-        # Qwen3VLMoe position_ids are prepareed with rope_deltas in forward
-        model_inputs["position_ids"] = None
+        # Qwen3VLMoe position_ids are prepared with rope_deltas
+        if position_ids is None:
+            # Calculate RoPE index once per generation in the pre-fill stage only.
+            # When compiling, we can't check tensor values thus we check only input length
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do asssisted decoding
+            if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
+                vision_positions, rope_deltas = self.model.get_rope_index(
+                    model_inputs.get("input_ids", None),
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    attention_mask=attention_mask,
+                )
+                self.model.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            elif "position_ids" in model_inputs:
+                batch_size, seq_length = model_inputs["position_ids"].shape
+                device = model_inputs["position_ids"].device
+                position_ids = torch.arange(seq_length, device=device)
+                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                delta = cache_position[0] + self.model.rope_deltas
+                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                vision_positions = position_ids + delta.expand_as(position_ids)
+
+            # Concatenate "text + vision" positions into [4, bs, seq-len]
+            text_positions = model_inputs["position_ids"][None, ...]
+            model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)

         if cache_position[0] != 0:
             model_inputs["pixel_values"] = None
transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py
@@ -26,7 +26,7 @@ from ...configuration_utils import PreTrainedConfig
 from ...modeling_rope_utils import RopeParameters
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, logging
+from ...utils import TransformersKwargs, can_return_tuple, logging
 from ..qwen3_moe.modeling_qwen3_moe import (
     Qwen3MoeDecoderLayer,
     Qwen3MoePreTrainedModel,
@@ -387,6 +387,7 @@ class Qwen3VLMoeModel(Qwen3VLModel):


 class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+    @can_return_tuple
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -496,6 +497,8 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
             aux_loss=aux_loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )

transformers/models/rag/modeling_rag.py
@@ -439,6 +439,7 @@ class RagModel(RagPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_retrieved: Optional[bool] = None,
         n_docs: Optional[int] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], RetrievAugLMOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
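This one-line `**kwargs` addition recurs throughout the release (RecurrentGemma, Reformer, RegNet, RemBert, and ResNet below all receive the same treatment): `forward` signatures now tolerate unknown keyword arguments instead of raising `TypeError`, so shared call sites can pass one uniform argument set to every architecture. The plain-Python effect:

def forward_strict(input_ids, n_docs=None):
    return input_ids


def forward_lenient(input_ids, n_docs=None, **kwargs):
    return input_ids  # unrecognized extras are simply ignored


common_args = {"input_ids": [1, 2, 3], "output_attentions": False}
forward_lenient(**common_args)    # fine
# forward_strict(**common_args)   # TypeError: unexpected keyword 'output_attentions'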
transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -31,6 +31,7 @@ from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
 from ...modeling_rope_utils import dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import auto_docstring, logging
+from ...utils.generic import maybe_autocast
 from ...utils.import_utils import is_tracing
 from .configuration_recurrent_gemma import RecurrentGemmaConfig

@@ -121,7 +122,7 @@ class RecurrentGemmaRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
@@ -460,6 +461,7 @@ class RecurrentGemmaRecurrentBlock(nn.Module):
         use_cache: bool = True,
     ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
         _, seq_len, _ = input_states.shape
+        batch_size = input_states.shape[0]

         y_branch = self.linear_y(input_states)
         y_branch = self.act_fn(y_branch)
@@ -468,6 +470,17 @@ class RecurrentGemmaRecurrentBlock(nn.Module):
         x_branch = x_branch.transpose(1, 2)

         if use_cache:
+            # Check if cache needs initialization (None or batch size mismatch)
+            if self.conv1d_state is None or self.conv1d_state.shape[0] != batch_size:
+                self.conv1d_state = torch.zeros(
+                    (batch_size, self.hidden_size, self.conv1d_width - 1),
+                    device=input_states.device,
+                    dtype=input_states.dtype,
+                )
+                self.rg_lru.recurrent_states = torch.zeros(
+                    (batch_size, self.lru_width), device=input_states.device, dtype=torch.float32
+                )
+
             if cache_position.shape[0] != 1:  # prefill
                 self.conv1d_state = nn.functional.pad(x_branch, (self.conv1d_width - x_branch.shape[-1] - 1, 0))
                 x_branch = self.conv_1d(x_branch)[..., :seq_len]
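The added block makes the recurrent block's state self-initializing: rather than assuming pre-allocated buffers, it reallocates the rolling convolution cache and the RG-LRU state whenever they are missing or shaped for a different batch. A stripped-down sketch of the pattern, with made-up sizes:

import torch


def ensure_conv_state(state, batch_size, hidden_size, conv1d_width, device, dtype):
    """Reallocate the rolling conv cache when it is absent or the batch size changed."""
    if state is None or state.shape[0] != batch_size:
        state = torch.zeros((batch_size, hidden_size, conv1d_width - 1), device=device, dtype=dtype)
    return state


state = ensure_conv_state(None, batch_size=2, hidden_size=8, conv1d_width=4, device="cpu", dtype=torch.float32)
assert state.shape == (2, 8, 3)
state = ensure_conv_state(state, 2, 8, 4, "cpu", torch.float32)  # same batch: reused as-is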
@@ -643,6 +656,7 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
transformers/models/reformer/modeling_reformer.py
@@ -1946,6 +1946,7 @@ class ReformerModel(ReformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, ReformerModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -2297,6 +2298,7 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -2428,6 +2430,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -2577,6 +2580,7 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
transformers/models/reformer/tokenization_reformer.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization class for model Reformer."""

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import BPE
@@ -60,38 +60,27 @@ class ReformerTokenizer(TokenizersBackend):
             The token used for padding, for example when batching sequences of different lengths.
         additional_special_tokens (`list[str]`, *optional*):
             Additional special tokens used by the tokenizer.
-        vocab (`dict`, *optional*):
-            Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
-        merges (`list`, *optional*):
-            Custom merges list. If not provided, merges are loaded from vocab_file.
+        vocab (`str` or `dict[str, int]`, *optional*):
+            Custom vocabulary dictionary. If not provided, vocabulary is loaded from `vocab_file`.
+        merges (`str` or `list[str]`, *optional*):
+            Custom merges list. If not provided, merges are loaded from `vocab_file`.
     """

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-    slow_tokenizer_class = None
+    model = BPE

     def __init__(
         self,
-        vocab_file: Optional[str] = None,
+        vocab: Optional[Union[str, dict[str, int]]] = None,
+        merges: Optional[Union[str, list[str]]] = None,
         eos_token: str = "</s>",
         unk_token: str = "<unk>",
         additional_special_tokens: Optional[list] = None,
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
         **kwargs,
     ):
-        self.vocab_file = vocab_file
-
-        if vocab is not None:
-            self._vocab = vocab
-        else:
-            self._vocab = {}
-
-        if merges is not None:
-            # Convert lists to tuples if necessary (happens when loading from JSON)
-            self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
-        else:
-            self._merges = []
+        self._vocab = vocab or {}
+        self._merges = merges or []

         self._tokenizer = Tokenizer(
             BPE(
@@ -106,10 +95,7 @@ class ReformerTokenizer(TokenizersBackend):

         self._tokenizer.normalizer = normalizers.Sequence(
             [
-                normalizers.Replace("\n", " "),
-                normalizers.Replace("\r", " "),
-                normalizers.Replace("\t", " "),
-                normalizers.Replace(Regex(r" {2,}"), " "),
+                normalizers.Replace(Regex(r"\s{2,}|[\n\r\t]"), " "),
                 normalizers.NFC(),
                 normalizers.Strip(left=False, right=True),
             ]
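The four chained `Replace` normalizers collapse into a single regex, `\s{2,}|[\n\r\t]`, which maps any newline, carriage return, tab, or run of two-plus whitespace characters to one space. The consolidated pipeline can be sanity-checked directly with the `tokenizers` primitives:

from tokenizers import Regex, normalizers

norm = normalizers.Sequence(
    [
        normalizers.Replace(Regex(r"\s{2,}|[\n\r\t]"), " "),
        normalizers.NFC(),
        normalizers.Strip(left=False, right=True),
    ]
)
print(norm.normalize_str("hello\t\tworld\n"))  # -> "hello world"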
@@ -118,10 +104,7 @@ class ReformerTokenizer(TokenizersBackend):
         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always")
         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always")

-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             eos_token=eos_token,
             unk_token=unk_token,
             additional_special_tokens=additional_special_tokens or [],
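Together these tokenizer hunks trace the v5 backend shift: the classes accept in-memory `vocab`/`merges` instead of a `vocab_file`, advertise their `tokenizers` model class through the new `model` attribute (`BPE` here, `Unigram` for RemBert below), and no longer thread a `tokenizer_object` through `super().__init__`. The pipeline that `ReformerTokenizer.__init__` assembles can be reproduced with the `tokenizers` library alone; here is a self-contained toy with a made-up seven-token vocabulary:

from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers
from tokenizers.models import BPE

vocab = {"<unk>": 0, "</s>": 1, "▁": 2, "h": 3, "i": 4, "hi": 5, "▁hi": 6}
merges = [("h", "i"), ("▁", "hi")]

tok = Tokenizer(BPE(vocab=vocab, merges=merges, unk_token="<unk>"))
tok.normalizer = normalizers.Sequence(
    [
        normalizers.Replace(Regex(r"\s{2,}|[\n\r\t]"), " "),
        normalizers.NFC(),
        normalizers.Strip(left=False, right=True),
    ]
)
tok.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always")
tok.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always")

print(tok.encode("hi").tokens)  # ['▁hi']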
transformers/models/regnet/modeling_regnet.py
@@ -294,7 +294,11 @@ class RegNetModel(RegNetPreTrainedModel):

     @auto_docstring
     def forward(
-        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+        self,
+        pixel_values: Tensor,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BaseModelOutputWithPoolingAndNoAttention:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -348,6 +352,7 @@ class RegNetForImageClassification(RegNetPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> ImageClassifierOutputWithNoAttention:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
transformers/models/rembert/modeling_rembert.py
@@ -540,6 +540,7 @@ class RemBertModel(RemBertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -659,6 +660,7 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -857,6 +859,7 @@ class RemBertForSequenceClassification(RemBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -940,6 +943,7 @@ class RemBertForMultipleChoice(RemBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1043,6 +1047,7 @@ class RemBertForTokenClassification(RemBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1109,6 +1114,7 @@ class RemBertForQuestionAnswering(RemBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformers/models/rembert/tokenization_rembert.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization classes for RemBert model."""

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import Unigram
@@ -74,11 +74,11 @@ class RemBertTokenizer(TokenizersBackend):

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-    slow_tokenizer_class = None
+    model = Unigram

     def __init__(
         self,
-        vocab_file: Optional[str] = None,
+        vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
         do_lower_case: bool = False,
         keep_accents: bool = False,
         bos_token: str = "[CLS]",
@@ -90,11 +90,8 @@ class RemBertTokenizer(TokenizersBackend):
         mask_token: str = "[MASK]",
         add_prefix_space: bool = True,
         remove_space: bool = True,
-        vocab: Optional[dict] = None,
-        merges: Optional[list] = None,
         **kwargs,
     ):
-        self.vocab_file = vocab_file
         self.remove_space = remove_space
         self.do_lower_case = do_lower_case
         self.keep_accents = keep_accents
@@ -147,11 +144,7 @@ class RemBertTokenizer(TokenizersBackend):
         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme=prepend_scheme)

         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme=prepend_scheme)
-
-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             add_prefix_space=add_prefix_space,
             do_lower_case=do_lower_case,
             keep_accents=keep_accents,
transformers/models/resnet/modeling_resnet.py
@@ -280,7 +280,11 @@ class ResNetModel(ResNetPreTrainedModel):

     @auto_docstring
     def forward(
-        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
+        self,
+        pixel_values: Tensor,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BaseModelOutputWithPoolingAndNoAttention:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -333,6 +337,7 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
333
337
  labels: Optional[torch.LongTensor] = None,
334
338
  output_hidden_states: Optional[bool] = None,
335
339
  return_dict: Optional[bool] = None,
340
+ **kwargs,
336
341
  ) -> ImageClassifierOutputWithNoAttention:
337
342
  r"""
338
343
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -380,7 +385,11 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
380
385
 
381
386
  @auto_docstring
382
387
  def forward(
383
- self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
388
+ self,
389
+ pixel_values: Tensor,
390
+ output_hidden_states: Optional[bool] = None,
391
+ return_dict: Optional[bool] = None,
392
+ **kwargs,
384
393
  ) -> BackboneOutput:
385
394
  r"""
386
395
  Examples:
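The ResNet edits apply the same **kwargs tolerance and reflow the single-line forward signatures into the repository's one-parameter-per-line style. A quick sketch against the public microsoft/resnet-50 checkpoint; the random tensor stands in for a preprocessed image:

    import torch
    from transformers import ResNetForImageClassification

    model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
    pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for processor output
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, output_hidden_states=True)
    print(outputs.logits.argmax(-1))  # predicted ImageNet class id
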
@@ -14,7 +14,7 @@
  # limitations under the License.
  """Tokenization classes for RoBERTa."""

- from typing import Optional
+ from typing import Optional, Union

  from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
  from tokenizers.models import BPE
@@ -59,6 +59,10 @@ class RobertaTokenizer(TokenizersBackend):
  this superclass for more information regarding those methods.

  Args:
+ vocab (`str`, `dict` or `list`, *optional*):
+ Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
+ merges (`str` or `list`, *optional*):
+ Custom merges list. If not provided, merges are loaded from merges_file.
  errors (`str`, *optional*, defaults to `"replace"`):
  Paradigm to follow when decoding bytes to UTF-8. See
  [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
@@ -102,18 +106,16 @@ class RobertaTokenizer(TokenizersBackend):
  other word. (RoBERTa tokenizer detect beginning of words by the preceding space).
  trim_offsets (`bool`, *optional*, defaults to `True`):
  Whether the post processing step should trim offsets to avoid including whitespaces.
- vocab (`dict`, *optional*):
- Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
- merges (`list`, *optional*):
- Custom merges list. If not provided, merges are loaded from merges_file.
  """

  vocab_files_names = VOCAB_FILES_NAMES
  model_input_names = ["input_ids", "attention_mask"]
- slow_tokenizer_class = None
+ model = BPE

  def __init__(
  self,
+ vocab: Optional[Union[str, dict[str, int]]] = None,
+ merges: Optional[Union[str, list[str]]] = None,
  errors: str = "replace",
  bos_token: str = "<s>",
  eos_token: str = "</s>",
@@ -124,30 +126,22 @@ class RobertaTokenizer(TokenizersBackend):
  mask_token: str = "<mask>",
  add_prefix_space: bool = False,
  trim_offsets: bool = True,
- vocab: Optional[dict] = None,
- merges: Optional[list] = None,
  **kwargs,
  ):
  self.add_prefix_space = add_prefix_space
  self.trim_offsets = trim_offsets

- if vocab is not None:
- self._vocab = (
- {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
- )
- else:
- self._vocab = {
+ if vocab is None:
+ vocab = {
  str(pad_token): 0,
  str(unk_token): 1,
  str(cls_token): 2,
  str(sep_token): 3,
  str(mask_token): 4,
  }
+ self._vocab = vocab

- if merges is not None:
- self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
- else:
- self._merges = []
+ self._merges = merges or []

  self._tokenizer = Tokenizer(
  BPE(
@@ -162,17 +156,8 @@ class RobertaTokenizer(TokenizersBackend):

  self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
  self._tokenizer.decoder = decoders.ByteLevel()
- self._tokenizer.post_processor = processors.RobertaProcessing(
- sep=(str(sep_token), self._vocab.get(str(sep_token), 3)),
- cls=(str(cls_token), self._vocab.get(str(cls_token), 2)),
- add_prefix_space=add_prefix_space,
- trim_offsets=trim_offsets,
- )
-
- tokenizer_object = self._tokenizer

  super().__init__(
- tokenizer_object=tokenizer_object,
  errors=errors,
  bos_token=bos_token,
  eos_token=eos_token,
@@ -185,6 +170,12 @@ class RobertaTokenizer(TokenizersBackend):
  trim_offsets=trim_offsets,
  **kwargs,
  )
+ self._tokenizer.post_processor = processors.RobertaProcessing(
+ sep=(str(sep_token), self.sep_token_id),
+ cls=(str(cls_token), self.cls_token_id),
+ add_prefix_space=add_prefix_space,
+ trim_offsets=trim_offsets,
+ )


  __all__ = ["RobertaTokenizer"]
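With the TokenizersBackend refactor, RobertaTokenizer builds its own Tokenizer(BPE(...)) instead of handing a tokenizer_object to super().__init__(), and the RobertaProcessing post-processor is now attached after the base-class init so it can use the resolved self.sep_token_id/self.cls_token_id rather than hard-coded fallbacks. A sketch of constructing one from in-memory resources, assuming the new signature above; the toy byte-level BPE vocab and merges are invented for illustration:

    from transformers import RobertaTokenizer

    # Toy byte-level BPE resources; real checkpoints ship ~50k-entry files.
    vocab = {
        "<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "<mask>": 4,
        "Ġ": 5, "h": 6, "i": 7, "Ġh": 8, "Ġhi": 9,
    }
    merges = ["Ġ h", "Ġh i"]  # each entry is a "left right" merge pair
    tokenizer = RobertaTokenizer(vocab=vocab, merges=merges)
    print(tokenizer(" hi")["input_ids"])  # [0, 9, 2] -> <s> Ġhi </s>
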
@@ -693,6 +693,7 @@ class RoFormerModel(RoFormerPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
  ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple[torch.Tensor]]:
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
@@ -821,6 +822,7 @@ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[MaskedLMOutput, tuple[torch.Tensor]]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1035,6 +1037,7 @@ class RoFormerForSequenceClassification(RoFormerPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[SequenceClassifierOutput, tuple[torch.Tensor]]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1114,6 +1117,7 @@ class RoFormerForMultipleChoice(RoFormerPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[MultipleChoiceModelOutput, tuple[torch.Tensor]]:
  r"""
  input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1210,6 +1214,7 @@ class RoFormerForTokenClassification(RoFormerPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[TokenClassifierOutput, tuple[torch.Tensor]]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1275,6 +1280,7 @@ class RoFormerForQuestionAnswering(RoFormerPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[QuestionAnsweringModelOutput, tuple[torch.Tensor]]:
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict