transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -1379,6 +1379,7 @@ class LEDEncoder(LEDPreTrainedModel):
1379
1379
  output_attentions=None,
1380
1380
  output_hidden_states=None,
1381
1381
  return_dict=None,
1382
+ **kwargs,
1382
1383
  ):
1383
1384
  r"""
1384
1385
  Args:
@@ -1573,6 +1574,7 @@ class LEDDecoder(LEDPreTrainedModel):
1573
1574
  output_hidden_states=None,
1574
1575
  return_dict=None,
1575
1576
  cache_position=None,
1577
+ **kwargs,
1576
1578
  ):
1577
1579
  r"""
1578
1580
  Args:
@@ -1788,6 +1790,7 @@ class LEDModel(LEDPreTrainedModel):
1788
1790
  output_hidden_states: Optional[bool] = None,
1789
1791
  return_dict: Optional[bool] = None,
1790
1792
  cache_position: Optional[torch.Tensor] = None,
1793
+ **kwargs,
1791
1794
  ) -> Union[tuple[torch.Tensor], LEDSeq2SeqModelOutput]:
1792
1795
  r"""
1793
1796
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1938,6 +1941,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin):
1938
1941
  output_hidden_states: Optional[bool] = None,
1939
1942
  return_dict: Optional[bool] = None,
1940
1943
  cache_position: Optional[torch.Tensor] = None,
1944
+ **kwargs,
1941
1945
  ) -> Union[tuple[torch.Tensor], LEDSeq2SeqLMOutput]:
1942
1946
  r"""
1943
1947
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -2120,6 +2124,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
2120
2124
  output_attentions: Optional[bool] = None,
2121
2125
  output_hidden_states: Optional[bool] = None,
2122
2126
  return_dict: Optional[bool] = None,
2127
+ **kwargs,
2123
2128
  ) -> Union[tuple[torch.Tensor], LEDSeq2SeqSequenceClassifierOutput]:
2124
2129
  r"""
2125
2130
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -2258,6 +2263,7 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
2258
2263
  output_attentions: Optional[bool] = None,
2259
2264
  output_hidden_states: Optional[bool] = None,
2260
2265
  return_dict: Optional[bool] = None,
2266
+ **kwargs,
2261
2267
  ) -> Union[tuple[torch.Tensor], LEDSeq2SeqQuestionAnsweringModelOutput]:
2262
2268
  r"""
2263
2269
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -489,6 +489,7 @@ class LevitModel(LevitPreTrainedModel):
489
489
  pixel_values: Optional[torch.FloatTensor] = None,
490
490
  output_hidden_states: Optional[bool] = None,
491
491
  return_dict: Optional[bool] = None,
492
+ **kwargs,
492
493
  ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
493
494
  output_hidden_states = (
494
495
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -550,6 +551,7 @@ class LevitForImageClassification(LevitPreTrainedModel):
550
551
  labels: Optional[torch.LongTensor] = None,
551
552
  output_hidden_states: Optional[bool] = None,
552
553
  return_dict: Optional[bool] = None,
554
+ **kwargs,
553
555
  ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
554
556
  r"""
555
557
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -616,6 +618,7 @@ class LevitForImageClassificationWithTeacher(LevitPreTrainedModel):
616
618
  pixel_values: Optional[torch.FloatTensor] = None,
617
619
  output_hidden_states: Optional[bool] = None,
618
620
  return_dict: Optional[bool] = None,
621
+ **kwargs,
619
622
  ) -> Union[tuple, LevitForImageClassificationWithTeacherOutput]:
620
623
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
621
624
 
@@ -26,7 +26,7 @@ from torch import nn
26
26
 
27
27
  from ...cache_utils import Cache
28
28
  from ...generation import GenerationMixin
29
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
29
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
30
30
  from ...masking_utils import create_causal_mask
31
31
  from ...modeling_layers import GradientCheckpointingLayer
32
32
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -34,7 +34,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
34
34
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
35
35
  from ...processing_utils import Unpack
36
36
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
37
- from ...utils.generic import check_model_inputs
37
+ from ...utils.generic import check_model_inputs, maybe_autocast
38
38
  from ...utils.import_utils import is_causal_conv1d_available, is_torchdynamo_compiling
39
39
  from .configuration_lfm2 import Lfm2Config
40
40
 
@@ -122,7 +122,7 @@ class Lfm2RotaryEmbedding(nn.Module):
122
122
  position_ids_expanded = position_ids[:, None, :].float()
123
123
 
124
124
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
125
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
125
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
126
126
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
127
127
  emb = torch.cat((freqs, freqs), dim=-1)
128
128
  cos = emb.cos() * self.attention_scaling
@@ -358,6 +358,7 @@ def eager_attention_forward(
358
358
  return attn_output, attn_weights
359
359
 
360
360
 
361
+ @use_kernelized_func(apply_rotary_pos_emb)
361
362
  class Lfm2Attention(nn.Module):
362
363
  """Multi-headed attention from 'Attention Is All You Need' paper"""
363
364
 
@@ -372,7 +373,6 @@ class Lfm2Attention(nn.Module):
372
373
  self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
373
374
  self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
374
375
  self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
375
- self.rotary_fn = apply_rotary_pos_emb
376
376
  self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
377
377
  self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
378
378
  self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
@@ -384,7 +384,6 @@ class Lfm2Attention(nn.Module):
384
384
  attention_mask: Optional[torch.Tensor],
385
385
  past_key_values: Optional[Lfm2HybridConvCache] = None,
386
386
  cache_position: Optional[torch.LongTensor] = None,
387
- position_ids: Optional[torch.LongTensor] = None,
388
387
  **kwargs,
389
388
  ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
390
389
  input_shape = hidden_states.shape[:-1]
@@ -233,7 +233,6 @@ class Lfm2Attention(LlamaAttention):
233
233
  attention_mask: Optional[torch.Tensor],
234
234
  past_key_values: Optional[Lfm2HybridConvCache] = None,
235
235
  cache_position: Optional[torch.LongTensor] = None,
236
- position_ids: Optional[torch.LongTensor] = None,
237
236
  **kwargs,
238
237
  ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
239
238
  input_shape = hidden_states.shape[:-1]
@@ -27,7 +27,7 @@ from torch import nn
27
27
  from ... import initialization as init
28
28
  from ...cache_utils import Cache
29
29
  from ...generation import GenerationMixin
30
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
30
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
31
31
  from ...masking_utils import create_causal_mask
32
32
  from ...modeling_layers import GradientCheckpointingLayer
33
33
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast
@@ -35,7 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
35
35
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
36
  from ...processing_utils import Unpack
37
37
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
38
- from ...utils.generic import check_model_inputs
38
+ from ...utils.generic import check_model_inputs, maybe_autocast
39
39
  from ...utils.import_utils import is_causal_conv1d_available, is_torchdynamo_compiling
40
40
  from .configuration_lfm2_moe import Lfm2MoeConfig
41
41
 
@@ -123,7 +123,7 @@ class Lfm2MoeRotaryEmbedding(nn.Module):
123
123
  position_ids_expanded = position_ids[:, None, :].float()
124
124
 
125
125
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
126
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
126
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
127
127
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
128
128
  emb = torch.cat((freqs, freqs), dim=-1)
129
129
  cos = emb.cos() * self.attention_scaling
@@ -426,6 +426,7 @@ def eager_attention_forward(
426
426
  return attn_output, attn_weights
427
427
 
428
428
 
429
+ @use_kernelized_func(apply_rotary_pos_emb)
429
430
  class Lfm2MoeAttention(nn.Module):
430
431
  """Multi-headed attention from 'Attention Is All You Need' paper"""
431
432
 
@@ -440,7 +441,6 @@ class Lfm2MoeAttention(nn.Module):
440
441
  self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
441
442
  self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
442
443
  self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
443
- self.rotary_fn = apply_rotary_pos_emb
444
444
  self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
445
445
  self.q_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps)
446
446
  self.k_layernorm = Lfm2MoeRMSNorm(self.head_dim, eps=config.norm_eps)
@@ -452,7 +452,6 @@ class Lfm2MoeAttention(nn.Module):
452
452
  attention_mask: Optional[torch.Tensor],
453
453
  past_key_values: Optional[Lfm2MoeHybridConvCache] = None,
454
454
  cache_position: Optional[torch.LongTensor] = None,
455
- position_ids: Optional[torch.LongTensor] = None,
456
455
  **kwargs,
457
456
  ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
458
457
  input_shape = hidden_states.shape[:-1]
@@ -27,6 +27,7 @@ from torch import nn
27
27
  from torch.nn.utils.rnn import pad_sequence
28
28
 
29
29
  from ...activations import ACT2FN
30
+ from ...integrations import use_kernelized_func
30
31
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
31
32
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
32
33
  from ...processing_utils import Unpack
@@ -174,6 +175,7 @@ def eager_attention_forward(
174
175
  return attn_output, attn_weights
175
176
 
176
177
 
178
+ @use_kernelized_func(apply_rotary_pos_emb)
177
179
  class LightGlueAttention(nn.Module):
178
180
  """Multi-headed attention from 'Attention Is All You Need' paper"""
179
181
 
@@ -199,7 +201,6 @@ class LightGlueAttention(nn.Module):
199
201
  self.o_proj = nn.Linear(
200
202
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
201
203
  )
202
- self.rotary_fn = apply_rotary_pos_emb
203
204
 
204
205
  def forward(
205
206
  self,
@@ -870,6 +871,7 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
870
871
  labels: Optional[torch.LongTensor] = None,
871
872
  output_attentions: Optional[bool] = None,
872
873
  output_hidden_states: Optional[bool] = None,
874
+ **kwargs,
873
875
  ) -> Union[tuple, "LightGlueKeypointMatchingOutput"]:
874
876
  loss = None
875
877
  if labels is not None:
@@ -927,6 +927,7 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
927
927
  labels: Optional[torch.LongTensor] = None,
928
928
  output_attentions: Optional[bool] = None,
929
929
  output_hidden_states: Optional[bool] = None,
930
+ **kwargs,
930
931
  ) -> Union[tuple, "LightGlueKeypointMatchingOutput"]:
931
932
  loss = None
932
933
  if labels is not None:
@@ -538,6 +538,7 @@ class LiltModel(LiltPreTrainedModel):
538
538
  output_attentions: Optional[bool] = None,
539
539
  output_hidden_states: Optional[bool] = None,
540
540
  return_dict: Optional[bool] = None,
541
+ **kwargs,
541
542
  ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
542
543
  r"""
543
544
  bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
@@ -665,6 +666,7 @@ class LiltForSequenceClassification(LiltPreTrainedModel):
665
666
  output_attentions: Optional[bool] = None,
666
667
  output_hidden_states: Optional[bool] = None,
667
668
  return_dict: Optional[bool] = None,
669
+ **kwargs,
668
670
  ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
669
671
  r"""
670
672
  bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
@@ -780,6 +782,7 @@ class LiltForTokenClassification(LiltPreTrainedModel):
780
782
  output_attentions: Optional[bool] = None,
781
783
  output_hidden_states: Optional[bool] = None,
782
784
  return_dict: Optional[bool] = None,
785
+ **kwargs,
783
786
  ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
784
787
  r"""
785
788
  bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
@@ -897,6 +900,7 @@ class LiltForQuestionAnswering(LiltPreTrainedModel):
897
900
  output_attentions: Optional[bool] = None,
898
901
  output_hidden_states: Optional[bool] = None,
899
902
  return_dict: Optional[bool] = None,
903
+ **kwargs,
900
904
  ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
901
905
  r"""
902
906
  bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
@@ -26,7 +26,7 @@ from torch import nn
26
26
  from ...activations import ACT2FN
27
27
  from ...cache_utils import Cache, DynamicCache
28
28
  from ...generation import GenerationMixin
29
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
29
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
30
30
  from ...masking_utils import create_causal_mask
31
31
  from ...modeling_layers import (
32
32
  GenericForQuestionAnswering,
@@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
42
42
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
43
  from ...processing_utils import Unpack
44
44
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
45
- from ...utils.generic import check_model_inputs
45
+ from ...utils.generic import check_model_inputs, maybe_autocast
46
46
  from .configuration_llama import LlamaConfig
47
47
 
48
48
 
@@ -126,7 +126,7 @@ class LlamaRotaryEmbedding(nn.Module):
126
126
  position_ids_expanded = position_ids[:, None, :].float()
127
127
 
128
128
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
129
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
129
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
130
130
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
131
131
  emb = torch.cat((freqs, freqs), dim=-1)
132
132
  cos = emb.cos() * self.attention_scaling
@@ -224,6 +224,7 @@ def eager_attention_forward(
224
224
  return attn_output, attn_weights
225
225
 
226
226
 
227
+ @use_kernelized_func(apply_rotary_pos_emb)
227
228
  class LlamaAttention(nn.Module):
228
229
  """Multi-headed attention from 'Attention Is All You Need' paper"""
229
230
 
@@ -249,7 +250,6 @@ class LlamaAttention(nn.Module):
249
250
  self.o_proj = nn.Linear(
250
251
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
251
252
  )
252
- self.rotary_fn = apply_rotary_pos_emb
253
253
 
254
254
  def forward(
255
255
  self,
@@ -12,11 +12,12 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+ from typing import Optional, Union
15
16
 
16
- from tokenizers import AddedToken, Tokenizer, decoders, pre_tokenizers
17
+ from tokenizers import Tokenizer, decoders, pre_tokenizers
17
18
  from tokenizers.models import BPE
18
19
 
19
- from ...tokenization_utils_base import _get_prepend_scheme, generate_merges
20
+ from ...tokenization_utils_base import _get_prepend_scheme
20
21
  from ...tokenization_utils_tokenizers import TokenizersBackend
21
22
  from ...utils import logging
22
23
 
@@ -61,6 +62,10 @@ class LlamaTokenizer(TokenizersBackend):
61
62
  refer to this superclass for more information regarding those methods.
62
63
 
63
64
  Args:
65
+ vocab (`str`, `dict` or `list`, *optional*):
66
+ Path to the vocabulary file, a dictionary or a list of tokens.
67
+ merges (`str` or `list`, *optional*):
68
+ Path to the merges file or a list of merges.
64
69
  clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
65
70
  Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
66
71
  extra spaces.
@@ -84,42 +89,32 @@ class LlamaTokenizer(TokenizersBackend):
84
89
  vocab_files_names = VOCAB_FILES_NAMES
85
90
  padding_side = "left"
86
91
  model_input_names = ["input_ids", "attention_mask"]
92
+ model = BPE
87
93
 
88
94
  def __init__(
89
95
  self,
96
+ vocab: Optional[Union[str, dict, list]] = None,
97
+ merges: Optional[Union[str, list]] = None,
90
98
  clean_up_tokenization_spaces=False,
91
99
  unk_token="<unk>",
92
100
  bos_token="<s>",
93
101
  eos_token="</s>",
94
- add_bos_token=True,
95
- add_eos_token=False,
96
102
  use_default_system_prompt=False,
97
103
  legacy=False,
98
104
  add_prefix_space=None,
99
- vocab=None,
100
- merges=None,
101
105
  **kwargs,
102
106
  ):
103
107
  self.add_prefix_space = add_prefix_space if add_prefix_space is not None else True
104
-
105
- if vocab is not None:
106
- self._vocab = (
107
- {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
108
- )
109
- else:
108
+ self.legacy = legacy
109
+ self._vocab = vocab
110
+ if vocab is None:
110
111
  self._vocab = {
111
112
  str(unk_token): 0,
112
113
  str(bos_token): 1,
113
114
  str(eos_token): 2,
114
115
  }
115
116
 
116
- special_tokens = {str(eos_token), str(bos_token), str(unk_token)}
117
-
118
- filtered_vocab = {t: i for t, i in self._vocab.items() if t not in special_tokens}
119
- if merges is not None:
120
- self._merges = [tuple(merge) if isinstance(merge, list) else merge for merge in merges]
121
- else:
122
- self._merges = generate_merges(filtered_vocab)
117
+ self._merges = merges or []
123
118
  self._tokenizer = Tokenizer(
124
119
  BPE(vocab=self._vocab, merges=self._merges, fuse_unk=True, byte_fallback=True, dropout=None)
125
120
  )
@@ -138,40 +133,17 @@ class LlamaTokenizer(TokenizersBackend):
138
133
  sequence += [decoders.Strip(content=" ", left=1)]
139
134
 
140
135
  self._tokenizer.decoder = decoders.Sequence(sequence)
141
- tokenizer_object = self._tokenizer
142
-
136
+ self.use_default_system_prompt = use_default_system_prompt
143
137
  super().__init__(
144
- tokenizer_object=tokenizer_object,
145
138
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
146
139
  unk_token=unk_token,
147
140
  bos_token=bos_token,
148
141
  eos_token=eos_token,
149
- add_bos_token=add_bos_token,
150
- add_eos_token=add_eos_token,
151
142
  use_default_system_prompt=use_default_system_prompt,
152
143
  add_prefix_space=add_prefix_space,
153
144
  **kwargs,
154
145
  )
155
146
 
156
- self._add_bos_token = add_bos_token
157
- self._add_eos_token = add_eos_token
158
- self.use_default_system_prompt = use_default_system_prompt
159
-
160
- self._post_init()
161
-
162
- def _post_init(self):
163
- """Post-initialization setup that needs to run after _tokenizer is set."""
164
- # Only set pre_tokenizer/normalizer for Llama-3 style tokenizers (use Sequence)
165
- pre_tok = self._tokenizer.pre_tokenizer
166
- if pre_tok is None or type(pre_tok).__name__ != "Sequence":
167
- self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
168
- replacement="▁", prepend_scheme="first", split=False
169
- )
170
- self._tokenizer.normalizer = None
171
- self.add_tokens([AddedToken(token, special=True) for token in self.all_special_tokens])
172
- super()._post_init()
173
- self.update_post_processor()
174
-
175
147
 
176
148
  __all__ = ["LlamaTokenizer", "LlamaTokenizerFast"]
177
149
 
@@ -40,7 +40,7 @@ from ...modeling_rope_utils import (
40
40
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
41
41
  from ...processing_utils import Unpack
42
42
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
43
- from ...utils.generic import check_model_inputs
43
+ from ...utils.generic import check_model_inputs, maybe_autocast
44
44
  from .configuration_llama4 import Llama4Config, Llama4TextConfig
45
45
 
46
46
 
@@ -228,7 +228,7 @@ class Llama4TextRotaryEmbedding(nn.Module):
228
228
  position_ids_expanded = position_ids[:, None, :].float()
229
229
 
230
230
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
231
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
231
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
232
232
  freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
233
233
  freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation
234
234
  freqs_cis = freqs_cis * self.attention_scaling
@@ -1072,6 +1072,7 @@ class Llama4VisionModel(Llama4PreTrainedModel):
1072
1072
  output_attentions: Optional[bool] = None,
1073
1073
  output_hidden_states: Optional[bool] = None,
1074
1074
  return_dict: Optional[bool] = None,
1075
+ **kwargs,
1075
1076
  ) -> Union[BaseModelOutput, tuple[torch.Tensor, ...]]:
1076
1077
  r"""
1077
1078
 
@@ -40,7 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
40
40
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
41
41
  from ...processing_utils import Unpack
42
42
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
43
- from ...utils.generic import check_model_inputs
43
+ from ...utils.generic import check_model_inputs, maybe_autocast
44
44
  from .configuration_longcat_flash import LongcatFlashConfig
45
45
 
46
46
 
@@ -121,7 +121,7 @@ class LongcatFlashRotaryEmbedding(nn.Module):
121
121
  position_ids_expanded = position_ids[:, None, :].float()
122
122
 
123
123
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
124
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
124
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
125
125
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
126
126
  emb = torch.cat((freqs, freqs), dim=-1)
127
127
  cos = emb.cos() * self.attention_scaling
@@ -431,7 +431,7 @@ class LongcatFlashMLA(nn.Module):
431
431
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
432
432
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
433
433
 
434
- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
434
+ if "flash" in self.config._attn_implementation and self.qk_head_dim != self.v_head_dim:
435
435
  value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
436
436
 
437
437
  attention_interface: Callable = eager_attention_forward
@@ -449,7 +449,7 @@ class LongcatFlashMLA(nn.Module):
449
449
  **kwargs,
450
450
  )
451
451
 
452
- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
452
+ if "flash" in self.config._attn_implementation and self.qk_head_dim != self.v_head_dim:
453
453
  attn_output = attn_output[:, :, :, : self.v_head_dim]
454
454
 
455
455
  attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
@@ -215,7 +215,7 @@ class LongcatFlashMLA(DeepseekV3Attention):
215
215
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
216
216
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
217
217
 
218
- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
218
+ if "flash" in self.config._attn_implementation and self.qk_head_dim != self.v_head_dim:
219
219
  value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
220
220
 
221
221
  attention_interface: Callable = eager_attention_forward
@@ -233,7 +233,7 @@ class LongcatFlashMLA(DeepseekV3Attention):
233
233
  **kwargs,
234
234
  )
235
235
 
236
- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
236
+ if "flash" in self.config._attn_implementation and self.qk_head_dim != self.v_head_dim:
237
237
  attn_output = attn_output[:, :, :, : self.v_head_dim]
238
238
 
239
239
  attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
@@ -1414,6 +1414,7 @@ class LongformerModel(LongformerPreTrainedModel):
1414
1414
  output_attentions: Optional[bool] = None,
1415
1415
  output_hidden_states: Optional[bool] = None,
1416
1416
  return_dict: Optional[bool] = None,
1417
+ **kwargs,
1417
1418
  ) -> Union[tuple, LongformerBaseModelOutputWithPooling]:
1418
1419
  r"""
1419
1420
  global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1567,6 +1568,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
1567
1568
  output_attentions: Optional[bool] = None,
1568
1569
  output_hidden_states: Optional[bool] = None,
1569
1570
  return_dict: Optional[bool] = None,
1571
+ **kwargs,
1570
1572
  ) -> Union[tuple, LongformerMaskedLMOutput]:
1571
1573
  r"""
1572
1574
  global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1678,6 +1680,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
1678
1680
  output_attentions: Optional[bool] = None,
1679
1681
  output_hidden_states: Optional[bool] = None,
1680
1682
  return_dict: Optional[bool] = None,
1683
+ **kwargs,
1681
1684
  ) -> Union[tuple, LongformerSequenceClassifierOutput]:
1682
1685
  r"""
1683
1686
  global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1800,6 +1803,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
1800
1803
  output_attentions: Optional[bool] = None,
1801
1804
  output_hidden_states: Optional[bool] = None,
1802
1805
  return_dict: Optional[bool] = None,
1806
+ **kwargs,
1803
1807
  ) -> Union[tuple, LongformerQuestionAnsweringModelOutput]:
1804
1808
  r"""
1805
1809
  global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1928,6 +1932,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
1928
1932
  output_attentions: Optional[bool] = None,
1929
1933
  output_hidden_states: Optional[bool] = None,
1930
1934
  return_dict: Optional[bool] = None,
1935
+ **kwargs,
1931
1936
  ) -> Union[tuple, LongformerTokenClassifierOutput]:
1932
1937
  r"""
1933
1938
  global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -2007,6 +2012,7 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
2007
2012
  output_attentions: Optional[bool] = None,
2008
2013
  output_hidden_states: Optional[bool] = None,
2009
2014
  return_dict: Optional[bool] = None,
2015
+ **kwargs,
2010
2016
  ) -> Union[tuple, LongformerMultipleChoiceModelOutput]:
2011
2017
  r"""
2012
2018
  input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1283,6 +1283,7 @@ class LongT5Stack(LongT5PreTrainedModel):
1283
1283
  output_hidden_states=None,
1284
1284
  return_dict=None,
1285
1285
  cache_position=None,
1286
+ **kwargs,
1286
1287
  ):
1287
1288
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1288
1289
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1618,6 +1619,7 @@ class LongT5Model(LongT5PreTrainedModel):
1618
1619
  output_hidden_states: Optional[bool] = None,
1619
1620
  return_dict: Optional[bool] = None,
1620
1621
  cache_position: Optional[torch.LongTensor] = None,
1622
+ **kwargs,
1621
1623
  ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
1622
1624
  r"""
1623
1625
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1783,6 +1785,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
1783
1785
  output_hidden_states: Optional[bool] = None,
1784
1786
  return_dict: Optional[bool] = None,
1785
1787
  cache_position: Optional[torch.LongTensor] = None,
1788
+ **kwargs,
1786
1789
  ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1787
1790
  r"""
1788
1791
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1946,6 +1949,7 @@ class LongT5EncoderModel(LongT5PreTrainedModel):
1946
1949
  output_attentions: Optional[bool] = None,
1947
1950
  output_hidden_states: Optional[bool] = None,
1948
1951
  return_dict: Optional[bool] = None,
1952
+ **kwargs,
1949
1953
  ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
1950
1954
  r"""
1951
1955
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -837,6 +837,7 @@ class LukeModel(LukePreTrainedModel):
837
837
  output_attentions: Optional[bool] = None,
838
838
  output_hidden_states: Optional[bool] = None,
839
839
  return_dict: Optional[bool] = None,
840
+ **kwargs,
840
841
  ) -> Union[tuple, BaseLukeModelOutputWithPooling]:
841
842
  r"""
842
843
  entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
@@ -1087,6 +1088,7 @@ class LukeForMaskedLM(LukePreTrainedModel):
1087
1088
  output_attentions: Optional[bool] = None,
1088
1089
  output_hidden_states: Optional[bool] = None,
1089
1090
  return_dict: Optional[bool] = None,
1091
+ **kwargs,
1090
1092
  ) -> Union[tuple, LukeMaskedLMOutput]:
1091
1093
  r"""
1092
1094
  entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
@@ -1220,6 +1222,7 @@ class LukeForEntityClassification(LukePreTrainedModel):
1220
1222
  output_attentions: Optional[bool] = None,
1221
1223
  output_hidden_states: Optional[bool] = None,
1222
1224
  return_dict: Optional[bool] = None,
1225
+ **kwargs,
1223
1226
  ) -> Union[tuple, EntityClassificationOutput]:
1224
1227
  r"""
1225
1228
  entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
@@ -1348,6 +1351,7 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
1348
1351
  output_attentions: Optional[bool] = None,
1349
1352
  output_hidden_states: Optional[bool] = None,
1350
1353
  return_dict: Optional[bool] = None,
1354
+ **kwargs,
1351
1355
  ) -> Union[tuple, EntityPairClassificationOutput]:
1352
1356
  r"""
1353
1357
  entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
@@ -1483,6 +1487,7 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
1483
1487
  output_attentions: Optional[bool] = None,
1484
1488
  output_hidden_states: Optional[bool] = None,
1485
1489
  return_dict: Optional[bool] = None,
1490
+ **kwargs,
1486
1491
  ) -> Union[tuple, EntitySpanClassificationOutput]:
1487
1492
  r"""
1488
1493
  entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
@@ -1638,6 +1643,7 @@ class LukeForSequenceClassification(LukePreTrainedModel):
1638
1643
  output_attentions: Optional[bool] = None,
1639
1644
  output_hidden_states: Optional[bool] = None,
1640
1645
  return_dict: Optional[bool] = None,
1646
+ **kwargs,
1641
1647
  ) -> Union[tuple, LukeSequenceClassifierOutput]:
1642
1648
  r"""
1643
1649
  entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
@@ -1764,6 +1770,7 @@ class LukeForTokenClassification(LukePreTrainedModel):
1764
1770
  output_attentions: Optional[bool] = None,
1765
1771
  output_hidden_states: Optional[bool] = None,
1766
1772
  return_dict: Optional[bool] = None,
1773
+ **kwargs,
1767
1774
  ) -> Union[tuple, LukeTokenClassifierOutput]:
1768
1775
  r"""
1769
1776
  entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
@@ -1865,6 +1872,7 @@ class LukeForQuestionAnswering(LukePreTrainedModel):
1865
1872
  output_attentions: Optional[bool] = None,
1866
1873
  output_hidden_states: Optional[bool] = None,
1867
1874
  return_dict: Optional[bool] = None,
1875
+ **kwargs,
1868
1876
  ) -> Union[tuple, LukeQuestionAnsweringModelOutput]:
1869
1877
  r"""
1870
1878
  entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
@@ -1982,6 +1990,7 @@ class LukeForMultipleChoice(LukePreTrainedModel):
1982
1990
  output_attentions: Optional[bool] = None,
1983
1991
  output_hidden_states: Optional[bool] = None,
1984
1992
  return_dict: Optional[bool] = None,
1993
+ **kwargs,
1985
1994
  ) -> Union[tuple, LukeMultipleChoiceModelOutput]:
1986
1995
  r"""
1987
1996
  input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):