transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
--- a/transformers/models/gemma3n/modular_gemma3n.py
+++ b/transformers/models/gemma3n/modular_gemma3n.py
@@ -26,12 +26,12 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import BaseModelOutputWithPast
 from ...modeling_rope_utils import RopeParameters
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs
 from ..auto import AutoModel
 from ..gemma2.configuration_gemma2 import Gemma2Config
 from ..gemma2.modeling_gemma2 import (
@@ -1474,7 +1474,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
         )

     def forward(
-        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
+        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs
     ) -> tuple[torch.Tensor, torch.BoolTensor]:
         """Encodes a batch of MELs.

@@ -1742,7 +1742,7 @@ class Gemma3nTextAttention(Gemma3Attention):
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.config.head_dim)
@@ -1830,10 +1830,8 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         predictions = self.altup.predict(hidden_states)
         active_prediction = predictions[self.config.altup_active_idx]
@@ -1841,14 +1839,12 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
         active_prediction_normed = self.input_layernorm(active_prediction)
         laurel_output = self.laurel(active_prediction_normed)

-        attn, self_attn_weights = self.self_attn(
+        attn, _ = self.self_attn(
             hidden_states=active_prediction_normed,
             attention_mask=attention_mask,
             position_ids=position_ids,
             position_embeddings=position_embeddings,
             past_key_values=past_key_values,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
@@ -1877,18 +1873,17 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
         first_prediction = self.post_per_layer_input_norm(first_prediction)
         corrected_predictions[1:] += first_prediction

-        outputs = (corrected_predictions,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return corrected_predictions


 class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
     config: Gemma3nConfig
     input_modalities = ("image", "text", "audio")
     _no_split_modules = ["Gemma3nTextDecoderLayer"]
+    _can_record_outputs = {
+        "hidden_states": Gemma3nTextDecoderLayer,
+        "attentions": Gemma3nTextAttention,
+    }

     @torch.no_grad()
     def _init_weights(self, module):
@@ -1976,7 +1971,8 @@ class Gemma3nTextModel(Gemma3TextModel):
             dtype=inputs_embeds.dtype, device=per_layer_projection.device
         )

-    @can_return_tuple
+    # Last hidden states should be before reprojecting, to stay consistent with the other layer outputs
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
@@ -1987,8 +1983,6 @@ class Gemma3nTextModel(Gemma3TextModel):
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
@@ -1996,37 +1990,21 @@ class Gemma3nTextModel(Gemma3TextModel):
         per_layer_inputs (torch.Tensor, *optional*, defaults to None):
             Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
         if input_ids is not None:
             inputs_embeds = self.embed_tokens(input_ids)
             per_layer_inputs = self.get_per_layer_inputs(input_ids)

         per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)

-        if use_cache and past_key_values is None and not self.training:
+        if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens,
-                past_seen_tokens + inputs_embeds.shape[1],
-                device=inputs_embeds.device,
-            )
+            cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
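A side note on the `cache_position` rewrite in the hunk above: the one-liner is arithmetically identical to the removed four-line `torch.arange` call. A standalone check:

```python
import torch

past_seen_tokens, seq_len = 7, 5

# removed form: explicit start/stop range
old = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)
# new form: zero-based range shifted by the offset
new = torch.arange(seq_len) + past_seen_tokens

assert torch.equal(old, new)  # tensor([ 7,  8,  9, 10, 11]) either way
```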
@@ -2070,39 +2048,21 @@ class Gemma3nTextModel(Gemma3TextModel):
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
             causal_mask = causal_mask_mapping[decoder_layer.attention_type]
             per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]

-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 position_embeddings[decoder_layer.attention_type],
                 per_layer_input,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 **kwargs,
             )

-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         # Per-layer inputs to single output
         target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
         temp_hidden_states = [hidden_states[0]]
@@ -2122,8 +2082,6 @@
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )

@@ -2284,7 +2242,7 @@ class Gemma3nModel(PaliGemmaModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Gemma3nCausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -2456,7 +2414,7 @@ class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **lm_kwargs,
+        **lm_kwargs: Unpack[TransformersKwargs],
     ) -> Gemma3nCausalLMOutputWithPast:
         r"""
         input_features_mask (torch.Tensor, *optional*, defaults to None):
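The Gemma3n hunks above all serve one refactor: manual `output_attentions`/`output_hidden_states` plumbing is replaced by the `@check_model_inputs` decorator plus a `_can_record_outputs` class attribute, so decoder layers return plain tensors and intermediate outputs are captured from the declared submodule types. A minimal sketch of the recording idea, using a hypothetical hook-based collector rather than the actual transformers decorator:

```python
import torch
from torch import nn

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)  # plain tensor, no output tuple

class ToyModel(nn.Module):
    # as in the diff: declare which submodule type produces "hidden_states"
    _can_record_outputs = {"hidden_states": ToyLayer}

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer() for _ in range(3)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def run_with_recording(model, x):
    """Collect per-layer outputs via forward hooks instead of return tuples."""
    recorded = []
    wanted = model._can_record_outputs["hidden_states"]
    hooks = [
        m.register_forward_hook(lambda mod, args, out: recorded.append(out))
        for m in model.modules()
        if isinstance(m, wanted)
    ]
    try:
        out = model(x)
    finally:
        for h in hooks:
            h.remove()
    return out, tuple(recorded)

out, hidden_states = run_with_recording(ToyModel(), torch.randn(2, 4))
assert len(hidden_states) == 3
```

The real decorator presumably also takes over the config defaulting that the removed boilerplate used to do, and handles `tie_last_hidden_states` (note the comment above about recording the last hidden state before reprojection); the sketch shows only the hook-capture core.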
--- a/transformers/models/git/modeling_git.py
+++ b/transformers/models/git/modeling_git.py
@@ -827,6 +827,7 @@ class GitVisionModel(GitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Examples:
@@ -972,6 +973,7 @@ class GitModel(GitPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
         r"""
         Examples:
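The bare `**kwargs,` additions here (and in the GLPN and GPT2 hunks further down) follow a single pattern: call sites now forward one shared kwargs bundle to every submodule, so forwards that do not consume those kwargs must still accept them. A toy illustration of the failure mode being avoided:

```python
def old_forward(pixel_values, return_dict=None):
    return pixel_values

def new_forward(pixel_values, return_dict=None, **kwargs):
    return pixel_values  # extra kwargs are accepted and ignored

shared_kwargs = {"output_attentions": False}
new_forward(0, **shared_kwargs)  # fine
try:
    old_forward(0, **shared_kwargs)
except TypeError as exc:
    print(exc)  # unexpected keyword argument 'output_attentions'
```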
--- a/transformers/models/glm/modeling_glm.py
+++ b/transformers/models/glm/modeling_glm.py
@@ -28,7 +28,7 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import (
     GenericForSequenceClassification,
@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm import GlmConfig


@@ -120,7 +120,7 @@ class GlmRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
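`maybe_autocast` (imported from `transformers/utils/generic.py` in these hunks) replaces direct `torch.autocast(...)` contexts throughout the rotary embeddings. The diff does not show its implementation; the sketch below is only a guess at the shape of such a wrapper, under the assumption that it falls back to a null context where `torch.autocast` would not apply:

```python
import contextlib
import torch

def maybe_autocast(device_type: str, enabled: bool = True):
    """Hypothetical stand-in for the transformers helper: use
    torch.autocast when the device type supports it, else do nothing."""
    if device_type in ("cuda", "cpu"):
        return torch.autocast(device_type=device_type, enabled=enabled)
    return contextlib.nullcontext()

# Same call shape as the hunk above: autocast disabled to force float32.
with maybe_autocast(device_type="cpu", enabled=False):
    inv_freq = torch.randn(4, 1)
    positions = torch.arange(8, dtype=torch.float32)[None, :]
    freqs = inv_freq.float() @ positions.float()  # stays float32 under amp
assert freqs.dtype == torch.float32
```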
@@ -216,6 +216,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class GlmAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -239,7 +240,6 @@ class GlmAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
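`@use_kernelized_func(apply_rotary_pos_emb)` replaces the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment removed a few lines above, moving the binding to the class so a hub kernel can presumably be substituted in one place. A guessed, heavily simplified shape of such a decorator (not the actual implementation in `transformers/integrations`):

```python
def use_kernelized_func(func):
    """Hypothetical simplification: expose `func` as cls.rotary_fn so a
    kernel registry could later swap in an optimized implementation."""
    def decorator(cls):
        cls.rotary_fn = staticmethod(func)
        return cls
    return decorator

def apply_rotary_pos_emb(q, k, cos, sin):  # stub for the sketch
    return q, k

@use_kernelized_func(apply_rotary_pos_emb)
class AttentionSketch:
    pass

assert AttentionSketch.rotary_fn is apply_rotary_pos_emb
```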
--- a/transformers/models/glm4/modeling_glm4.py
+++ b/transformers/models/glm4/modeling_glm4.py
@@ -28,7 +28,7 @@ import torch.nn as nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm4 import Glm4Config


@@ -198,6 +198,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Glm4Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -221,7 +222,6 @@ class Glm4Attention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
@@ -325,7 +325,7 @@ class Glm4RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
--- a/transformers/models/glm4_moe/modeling_glm4_moe.py
+++ b/transformers/models/glm4_moe/modeling_glm4_moe.py
@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -39,7 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm4_moe import Glm4MoeConfig


@@ -101,7 +101,7 @@ class Glm4MoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -193,6 +193,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Glm4MoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -491,6 +492,7 @@ class Glm4MoePreTrainedModel(PreTrainedModel):
         "hidden_states": Glm4MoeDecoderLayer,
         "attentions": Glm4MoeAttention,
     }
+    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]

     @torch.no_grad()
     def _init_weights(self, module):
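`_keep_in_fp32_modules_strict = ["e_score_correction_bias"]` pins the MoE router's score-correction bias to float32 even when the rest of the checkpoint is loaded in half precision, presumably because routing scores are numerically sensitive. A rough illustration of the effect (toy code, not the actual loading path):

```python
import torch
from torch import nn

KEEP_IN_FP32 = ["e_score_correction_bias"]

class ToyRouter(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate = nn.Linear(8, 4)
        self.e_score_correction_bias = nn.Parameter(torch.zeros(4))

def cast_with_fp32_exceptions(model: nn.Module, dtype=torch.bfloat16):
    model.to(dtype)  # cast everything down first
    for name, param in model.named_parameters():
        if any(key in name for key in KEEP_IN_FP32):
            param.data = param.data.float()  # restore float32 for matches
    return model

router = cast_with_fp32_exceptions(ToyRouter())
assert router.gate.weight.dtype == torch.bfloat16
assert router.e_score_correction_bias.dtype == torch.float32
```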
--- a/transformers/models/glm4v/configuration_glm4v.py
+++ b/transformers/models/glm4v/configuration_glm4v.py
@@ -234,7 +234,9 @@ class Glm4vTextConfig(PreTrainedConfig):
         self.attention_dropout = attention_dropout
         self.rope_parameters = rope_parameters

-        super().__init__(tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope"}, **kwargs)
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+        )


 class Glm4vConfig(PreTrainedConfig):
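The config hunks in this group all make the same correction: RoPE validation should skip the key actually stored in `rope_parameters`, which is `mrope_section` (the per-axis split of rotary dimensions used for multimodal positions), not `mrope`. A toy version of why the ignore set matters; the key inventory and section values here are invented for the sketch:

```python
rope_parameters = {
    "rope_type": "default",
    "mrope_section": [16, 24, 24],  # e.g. temporal/height/width split
}

KNOWN_ROPE_KEYS = {"rope_type", "rope_theta", "factor"}  # illustrative set

def validate_rope(params, ignore_keys=frozenset()):
    unknown = set(params) - KNOWN_ROPE_KEYS - set(ignore_keys)
    if unknown:
        raise ValueError(f"Unrecognized RoPE keys: {unknown}")

try:
    validate_rope(rope_parameters, ignore_keys={"mrope"})  # old key: raises
except ValueError as exc:
    print(exc)  # Unrecognized RoPE keys: {'mrope_section'}

validate_rope(rope_parameters, ignore_keys={"mrope_section"})  # passes
```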
--- a/transformers/models/glm4v/modeling_glm4v.py
+++ b/transformers/models/glm4v/modeling_glm4v.py
@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig


@@ -446,7 +446,7 @@ class Glm4vTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -768,7 +768,7 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids

-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
--- a/transformers/models/glm4v/modular_glm4v.py
+++ b/transformers/models/glm4v/modular_glm4v.py
@@ -36,7 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ...video_utils import VideoInput
 from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, Glm4RotaryEmbedding, eager_attention_forward
 from ..qwen2_5_vl.modeling_qwen2_5_vl import (
@@ -271,7 +271,9 @@ class Glm4vTextConfig(PreTrainedConfig):
         self.attention_dropout = attention_dropout
         self.rope_parameters = rope_parameters

-        super().__init__(tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope"}, **kwargs)
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+        )


 class Glm4vConfig(PreTrainedConfig):
@@ -509,7 +511,7 @@ class Glm4vTextRotaryEmbedding(Glm4RotaryEmbedding):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -786,7 +788,7 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids

-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
--- a/transformers/models/glm4v_moe/configuration_glm4v_moe.py
+++ b/transformers/models/glm4v_moe/configuration_glm4v_moe.py
@@ -280,7 +280,9 @@ class Glm4vMoeTextConfig(PreTrainedConfig):
         self.first_k_dense_replace = first_k_dense_replace
         self.norm_topk_prob = norm_topk_prob
         self.router_aux_loss_coef = router_aux_loss_coef
-        super().__init__(tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope"}, **kwargs)
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+        )


 class Glm4vMoeConfig(PreTrainedConfig):
--- a/transformers/models/glm4v_moe/modeling_glm4v_moe.py
+++ b/transformers/models/glm4v_moe/modeling_glm4v_moe.py
@@ -32,7 +32,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig


@@ -150,7 +150,7 @@ class Glm4vMoeTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -299,6 +299,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Glm4vMoeTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -322,7 +323,6 @@ class Glm4vMoeTextAttention(nn.Module):
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb
         self.rope_parameters = config.rope_parameters

     def forward(
@@ -594,6 +594,7 @@ class Glm4vMoePreTrainedModel(PreTrainedModel):
         "attentions": Glm4vMoeTextAttention,
         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
     }
+    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
     input_modalities = ("text", "image", "video")

     @torch.no_grad()
@@ -975,7 +976,7 @@ class Glm4vMoeVisionModel(Glm4vMoePreTrainedModel):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids

-    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
--- a/transformers/models/glm4v_moe/modular_glm4v_moe.py
+++ b/transformers/models/glm4v_moe/modular_glm4v_moe.py
@@ -227,7 +227,7 @@ class Glm4vMoeTextConfig(Glm4MoeConfig, RotaryEmbeddingConfigMixin):
         self.norm_topk_prob = norm_topk_prob
         self.router_aux_loss_coef = router_aux_loss_coef
         PreTrainedConfig.__init__(
-            self, tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope"}, **kwargs
+            self, tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
         )


--- a/transformers/models/glpn/modeling_glpn.py
+++ b/transformers/models/glpn/modeling_glpn.py
@@ -411,6 +411,7 @@ class GLPNModel(GLPNPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -597,6 +598,7 @@ class GLPNForDepthEstimation(GLPNPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], DepthEstimatorOutput]:
         r"""
         labels (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
--- a/transformers/models/gpt2/modeling_gpt2.py
+++ b/transformers/models/gpt2/modeling_gpt2.py
@@ -45,6 +45,7 @@ from ...utils import (
     auto_docstring,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from .configuration_gpt2 import GPT2Config


@@ -150,7 +151,7 @@ class GPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.autocast(query.device.type, enabled=False):
+        with maybe_autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
@@ -1021,6 +1022,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -1148,6 +1150,7 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -1228,6 +1231,7 @@ class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):