transformers 5.0.0rc0-py3-none-any.whl → 5.0.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -36,7 +36,7 @@ from typing import Optional, TypeVar, Union, get_type_hints
  from zipfile import is_zipfile

  import torch
- from huggingface_hub import create_repo, split_torch_state_dict_into_shards
+ from huggingface_hub import create_repo, is_offline_mode, split_torch_state_dict_into_shards
  from packaging import version
  from safetensors import safe_open
  from safetensors.torch import save_file as safe_save_file
@@ -85,7 +85,7 @@ from .integrations.tensor_parallel import (
  verify_tp_plan,
  )
  from .loss.loss_utils import LOSS_MAPPING
- from .modeling_flash_attention_utils import lazy_import_flash_attention
+ from .modeling_flash_attention_utils import lazy_import_flash_attention, lazy_import_paged_flash_attention
  from .pytorch_utils import id_tensor_storage
  from .quantizers import HfQuantizer
  from .quantizers.auto import get_hf_quantizer
@@ -93,7 +93,6 @@ from .quantizers.quantizers_utils import get_module_from_name
  from .safetensors_conversion import auto_conversion
  from .utils import (
  ADAPTER_SAFE_WEIGHTS_NAME,
- ADAPTER_WEIGHTS_NAME,
  DUMMY_INPUTS,
  SAFE_WEIGHTS_INDEX_NAME,
  SAFE_WEIGHTS_NAME,
@@ -110,7 +109,6 @@ from .utils import (
  is_flash_attn_2_available,
  is_flash_attn_3_available,
  is_kernels_available,
- is_offline_mode,
  is_torch_flex_attn_available,
  is_torch_greater_or_equal,
  is_torch_mlu_available,
@@ -279,7 +277,9 @@ def get_state_dict_dtype(state_dict):
  return t.dtype

  # if no floating dtype was found return whatever the first dtype is
- return next(state_dict.values()).dtype
+ if len(state_dict) == 0:
+ return torch.float32
+ return next(iter(state_dict.values())).dtype


  str_to_torch_dtype = {
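The old fallback called `next()` directly on a `dict_values` view, which raises `TypeError` because the view is not an iterator; the new code wraps it in `iter()` and short-circuits empty state dicts. A minimal sketch of the behavior difference (plain PyTorch, not the transformers helper itself):

    import torch

    state_dict = {"bias": torch.zeros(2, dtype=torch.int64)}

    # next(state_dict.values())                    # TypeError: 'dict_values' object is not an iterator
    dtype = next(iter(state_dict.values())).dtype  # torch.int64 — the rc1 fallback path
    empty_fallback = torch.float32                 # what rc1 returns when the state dict is empty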
@@ -552,8 +552,7 @@ def _get_resolved_checkpoint_files(
  raise OSError(
  f"{pretrained_model_name_or_path} does not appear to have a file named"
  f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
- "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
- "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+ "and thus cannot be loaded with `safetensors`. Please do not set `use_safetensors=True`."
  )
  else:
  # This repo has no safetensors file of any kind, we switch to PyTorch.
@@ -772,7 +771,7 @@ def _get_dtype(
  for key in config.sub_configs:
  if (sub_config := getattr(config, key)) is not None:
  sub_config.dtype = default_dtype
-
+ dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
  return config, dtype, dtype_orig

@@ -799,7 +798,11 @@ class ModuleUtilsMixin:
  """
  `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
  """
- return next(param.dtype for param in self.parameters() if param.is_floating_point())
+ dtype = self._dtype or next(param.dtype for param in self.parameters() if param.is_floating_point())
+ if isinstance(dtype, str):
+ if hasattr(torch, dtype):
+ dtype = getattr(torch, dtype)
+ return dtype

  def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
  """
@@ -1078,6 +1081,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  _keep_in_fp32_modules_strict = None

  dtype_plan: Optional[dict[str, torch.dtype]] = None
+ _dtype: Optional[Union[str, torch.dtype]] = torch.get_default_dtype()

  # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
  # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -1222,6 +1226,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
  )
  self.config = config
+ default_dtype = torch.get_default_dtype()
+ self._dtype = default_dtype

  # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
  # setting it recursively)
@@ -1460,6 +1466,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
  `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
  """
+ if isinstance(dtype, str):
+ if hasattr(torch, dtype):
+ dtype = getattr(torch, dtype)
+ else:
+ raise ValueError(f"Received an invalid string dtype: {dtype}")
  if not dtype.is_floating_point:
  raise ValueError(
  f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
@@ -1468,6 +1479,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
  dtype_orig = torch.get_default_dtype()
  torch.set_default_dtype(dtype)
+ cls._dtype = dtype
  return dtype_orig

  @property
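These hunks let the tracked `_dtype` be a plain string that is resolved lazily via `getattr(torch, ...)`. A small illustration of that resolution pattern (standalone PyTorch code mirroring the new checks, not a call into the private helpers):

    import torch

    def resolve_dtype(dtype):
        # same pattern as the rc1 code: accept "bfloat16", "float16", ... or a torch.dtype
        if isinstance(dtype, str):
            if hasattr(torch, dtype):
                return getattr(torch, dtype)
            raise ValueError(f"Received an invalid string dtype: {dtype}")
        return dtype

    assert resolve_dtype("bfloat16") is torch.bfloat16
    assert resolve_dtype(torch.float32) is torch.float32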
@@ -1764,9 +1776,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  """
  applicable_attn_implementation = attn_implementation

+ is_paged = attn_implementation is not None and attn_implementation.startswith("paged|")
+
  # If FA not installed, do not fail but use kernels instead
  requested_original_flash_attn = attn_implementation is not None and (
- attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3"
+ attn_implementation.removeprefix("paged|") == "flash_attention_2"
+ or attn_implementation.removeprefix("paged|") == "flash_attention_3"
  )
  if (
  requested_original_flash_attn
@@ -1784,10 +1799,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  else:
  applicable_attn_implementation = "kernels-community/vllm-flash-attn3"

+ if is_paged:
+ applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
+
  if is_kernel(applicable_attn_implementation):
  try:
  # preload flash attention here to allow compile with fullgraph
- lazy_import_flash_attention(applicable_attn_implementation)
+ if is_paged:
+ lazy_import_paged_flash_attention(applicable_attn_implementation)
+ else:
+ lazy_import_flash_attention(applicable_attn_implementation)

  # log that we used kernel fallback if successful
  if requested_original_flash_attn:
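The kernel-fallback path now understands the `paged|` prefix used by the paged attention implementations: the prefix is stripped with `str.removeprefix` when checking for flash attention, then re-applied to the substituted kernel name. A sketch of just the string handling shown in this hunk (the surrounding kernel-loading logic is omitted):

    attn_implementation = "paged|flash_attention_3"

    is_paged = attn_implementation.startswith("paged|")
    base = attn_implementation.removeprefix("paged|")   # "flash_attention_3"

    applicable = "kernels-community/vllm-flash-attn3"   # the fallback kernel named in the diff
    if is_paged:
        applicable = f"paged|{applicable}"              # "paged|kernels-community/vllm-flash-attn3"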
@@ -2104,7 +2125,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  possible_module_names = ["language_model", "text_model", "decoder"]
  for name in possible_module_names:
  if hasattr(self, name):
- print(name)
  setattr(self, name, decoder)
  return

@@ -3002,10 +3022,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  save_directory: Union[str, os.PathLike],
  is_main_process: bool = True,
  state_dict: Optional[dict] = None,
- save_function: Callable = torch.save,
  push_to_hub: bool = False,
- max_shard_size: Union[int, str] = "5GB",
- safe_serialization: bool = True,
+ max_shard_size: Union[int, str] = "50GB",
  variant: Optional[str] = None,
  token: Optional[Union[str, bool]] = None,
  save_peft_format: bool = True,
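For callers, this signature change means `save_function`, `safe_serialization`, and the 5GB default shard size are gone; rc1 always writes safetensors and shards at 50GB. A hedged migration sketch (the checkpoint name is a placeholder):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("some-org/some-model")  # placeholder checkpoint

    # 5.0.0rc0 style — these kwargs no longer exist in rc1:
    # model.save_pretrained("out/", safe_serialization=True, max_shard_size="5GB", save_function=torch.save)

    # 5.0.0rc1 style:
    model.save_pretrained("out/", max_shard_size="50GB")  # 50GB is already the default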
@@ -3027,18 +3045,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
  save parts of the model or if special precautions need to be taken when recovering the state dictionary
  of a model (like when using model parallelism).
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful on distributed training like TPUs when one
- need to replace `torch.save` by another method.
  push_to_hub (`bool`, *optional*, defaults to `False`):
  Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
  repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
  namespace).
- max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
+ max_shard_size (`int` or `str`, *optional*, defaults to `"50GB"`):
  The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
  lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
- We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
- without CPU OOM issues.

  <Tip warning={true}>

@@ -3047,10 +3060,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  </Tip>

- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
  variant (`str`, *optional*):
- If specified, weights are saved in the format pytorch_model.<variant>.bin.
+ If specified, weights are saved in the format model.<variant>.safetensors.
  token (`str` or `bool`, *optional*):
  The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  the token generated when running `hf auth login` (stored in `~/.huggingface`).
@@ -3072,9 +3083,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  hf_quantizer = getattr(self, "hf_quantizer", None)
  quantization_serializable = (
- hf_quantizer is not None
- and isinstance(hf_quantizer, HfQuantizer)
- and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
+ hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable()
  )

  if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
@@ -3110,7 +3119,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  metadata = {}
  if hf_quantizer is not None:
- state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
+ state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self)
  metadata["format"] = "pt"

  # Only save the model itself if we are using distributed training
@@ -3202,75 +3211,72 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  if self._tp_size is not None:
  state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)

- if safe_serialization:
- # TODO: fix safe_serialization for tied weights
- # Safetensors does not allow tensor aliasing.
- # We're going to remove aliases before saving
- ptrs = collections.defaultdict(list)
- for name, tensor in state_dict.items():
- if not isinstance(tensor, torch.Tensor):
- # Sometimes in the state_dict we have non-tensor objects.
- # e.g. in bitsandbytes we have some `str` objects in the state_dict
- # In the non-tensor case, fall back to the pointer of the object itself
- ptrs[id(tensor)].append(name)
-
- elif tensor.device.type == "meta":
- # In offloaded cases, there may be meta tensors in the state_dict.
- # For these cases, key by the pointer of the original tensor object
- # (state_dict tensors are detached and therefore no longer shared)
- tensor = self.get_parameter(name)
- ptrs[id(tensor)].append(name)
+ # Safetensors does not allow tensor aliasing - we're going to remove aliases before saving
+ ptrs = collections.defaultdict(list)
+ for name, tensor in state_dict.items():
+ if not isinstance(tensor, torch.Tensor):
+ # Sometimes in the state_dict we have non-tensor objects.
+ # e.g. in bitsandbytes we have some `str` objects in the state_dict
+ # In the non-tensor case, fall back to the pointer of the object itself
+ ptrs[id(tensor)].append(name)
+
+ elif tensor.device.type == "meta":
+ # In offloaded cases, there may be meta tensors in the state_dict.
+ # For these cases, key by the pointer of the original tensor object
+ # (state_dict tensors are detached and therefore no longer shared)
+ tensor = self.get_parameter(name)
+ ptrs[id(tensor)].append(name)

- else:
- ptrs[id_tensor_storage(tensor)].append(name)
-
- shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-
- # Recursively descend to find tied weight keys
- _tied_weights_keys = set(_get_tied_weight_keys(self))
- error_names = []
- to_delete_names = set()
- for names in shared_ptrs.values():
- # Removing the keys which are declared as known duplicates on
- # load. This allows to make sure the name which is kept is consistent.
- if _tied_weights_keys is not None:
- found = 0
- for name in sorted(names):
- matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
- if matches_pattern and name in state_dict:
- found += 1
- if found < len(names):
- to_delete_names.add(name)
- # We are entering a place where the weights and the transformers configuration do NOT match.
- shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
- # Those are actually tensor sharing but disjoint from each other, we can safely clone them
- # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
- for name in disjoint_names:
- state_dict[name] = state_dict[name].clone()
-
- # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
- # If the link between tensors was done at runtime then `from_pretrained` will not get
- # the key back leading to random tensor. A proper warning will be shown
- # during reload (if applicable), but since the file is not necessarily compatible with
- # the config, better show a proper warning.
- shared_names, identical_names = _find_identical(shared_names, state_dict)
- # delete tensors that have identical storage
- for inames in identical_names:
- known = inames.intersection(to_delete_names)
- for name in known:
- del state_dict[name]
- unknown = inames.difference(to_delete_names)
- if len(unknown) > 1:
- error_names.append(unknown)
-
- if shared_names:
- error_names.extend(shared_names)
-
- if len(error_names) > 0:
- raise RuntimeError(
- f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
- "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
- )
+ else:
+ ptrs[id_tensor_storage(tensor)].append(name)
+
+ shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
+ # Recursively descend to find tied weight keys
+ _tied_weights_keys = set(_get_tied_weight_keys(self))
+ error_names = []
+ to_delete_names = set()
+ for names in shared_ptrs.values():
+ # Removing the keys which are declared as known duplicates on
+ # load. This allows to make sure the name which is kept is consistent.
+ if _tied_weights_keys is not None:
+ found = 0
+ for name in sorted(names):
+ matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
+ if matches_pattern and name in state_dict:
+ found += 1
+ if found < len(names):
+ to_delete_names.add(name)
+ # We are entering a place where the weights and the transformers configuration do NOT match.
+ shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+ # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+ # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+ for name in disjoint_names:
+ state_dict[name] = state_dict[name].clone()
+
+ # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+ # If the link between tensors was done at runtime then `from_pretrained` will not get
+ # the key back leading to random tensor. A proper warning will be shown
+ # during reload (if applicable), but since the file is not necessarily compatible with
+ # the config, better show a proper warning.
+ shared_names, identical_names = _find_identical(shared_names, state_dict)
+ # delete tensors that have identical storage
+ for inames in identical_names:
+ known = inames.intersection(to_delete_names)
+ for name in known:
+ del state_dict[name]
+ unknown = inames.difference(to_delete_names)
+ if len(unknown) > 1:
+ error_names.append(unknown)
+
+ if shared_names:
+ error_names.extend(shared_names)
+
+ if len(error_names) > 0:
+ raise RuntimeError(
+ f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
+ "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
+ )

  # Revert all renaming and/or weight operations
  if save_original_format:
@@ -3278,10 +3284,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  # Shard the model if it is too big.
  if not _hf_peft_config_loaded:
- weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+ weights_name = SAFE_WEIGHTS_NAME
  weights_name = _add_variant(weights_name, variant)
  else:
- weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
+ weights_name = ADAPTER_SAFE_WEIGHTS_NAME

  filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
  state_dict_split = split_torch_state_dict_into_shards(
@@ -3350,13 +3356,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  del shard_state_dict
  gc.collect()

- if safe_serialization:
- # At some point we will need to deal better with save_function (used for TPU and other distributed
- # joyfulness), but for now this enough. # TODO: we should def parallelize this we are otherwise just waiting
- # too much before scheduling the next write when its in a different file
- safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
- else:
- save_function(shard, os.path.join(save_directory, shard_file))
+ # TODO: we should def parallelize this we are otherwise just waiting
+ # too much before scheduling the next write when its in a different file
+ safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)

  del state_dict

@@ -3364,7 +3366,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  path_to_weights = os.path.join(save_directory, weights_name)
  logger.info(f"Model weights saved in {path_to_weights}")
  else:
- save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+ save_index_file = SAFE_WEIGHTS_INDEX_NAME
  save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
  # Save the index as well
  with open(save_index_file, "w", encoding="utf-8") as f:
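With the pickle path gone, `save_pretrained` only produces safetensors artifacts: a single `model.safetensors` (or `model.<variant>.safetensors`) for small models, and numbered shards plus a `model.safetensors.index.json` when the state dict exceeds `max_shard_size`. A hedged sketch of inspecting the output (directory name is a placeholder):

    import json
    import os

    out_dir = "out/"  # wherever save_pretrained wrote to
    print(sorted(os.listdir(out_dir)))
    # expect e.g. ["config.json", "model.safetensors"] or
    #             ["config.json", "model-00001-of-00002.safetensors", ..., "model.safetensors.index.json"]

    index_path = os.path.join(out_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path) as f:
            index = json.load(f)
        print(index["weight_map"])  # parameter name -> shard file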
@@ -3835,6 +3837,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  # For BC on torch_dtype argument
  if torch_dtype is not None:
  dtype = dtype if dtype is not None else torch_dtype
+ if dtype is None:
+ dtype = "auto"

  if is_offline_mode() and not local_files_only:
  local_files_only = True
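In `from_pretrained`, when neither `dtype` nor the legacy `torch_dtype` is passed, rc1 now falls back to `dtype="auto"` rather than leaving it unset. A hedged illustration of the call-site implications (the checkpoint name is a placeholder; the resolved dtype depends on what the checkpoint stores):

    from transformers import AutoModelForCausalLM

    # rc1: equivalent to passing dtype="auto" — the dtype recorded in the checkpoint is used
    model = AutoModelForCausalLM.from_pretrained("some-org/some-model")

    # forcing a dtype explicitly still works, as a string or a torch.dtype
    model_bf16 = AutoModelForCausalLM.from_pretrained("some-org/some-model", dtype="bfloat16")

    print(model.dtype, model_bf16.dtype)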
@@ -4039,7 +4043,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4039
4043
  hf_quantizer.postprocess_model(model, config=config) # usually a no-op but sometimes needed
4040
4044
 
4041
4045
  if _adapter_model_path is not None:
4042
- adapter_kwargs["key_mapping"] = weight_conversions # TODO: Dynamic weight loader for adapters
4046
+ adapter_kwargs["key_mapping"] = key_mapping
4043
4047
  model.load_adapter(
4044
4048
  _adapter_model_path,
4045
4049
  adapter_name=adapter_name,
@@ -4090,10 +4094,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
4090
4094
  # Prepare parameters offloading if needed
4091
4095
  if device_map is not None and "disk" in device_map.values():
4092
4096
  disk_offload_index = accelerate_disk_offload(
4097
+ model,
4093
4098
  disk_offload_folder,
4094
4099
  checkpoint_files,
4095
4100
  device_map,
4096
- expected_keys,
4097
4101
  sharded_metadata,
4098
4102
  dtype,
4099
4103
  weight_mapping,
@@ -4115,7 +4119,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  state_dict = merged_state_dict
  error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
  # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
- missing_keys, unexpected_keys, mismatched_keys, misc = set(), set(), set(), set()
+ missing_keys, unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set(), set()
  else:
  all_pointer = set()
  # Checkpoints are safetensors
@@ -4137,7 +4141,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  else:
  raise ValueError("Neither a state dict nor checkpoint files were found.")
 
- missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, misc = (
+ missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
  convert_and_load_state_dict_in_model(
  model,
  merged_state_dict,
@@ -4180,7 +4184,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  tp_device = list(device_map.values())[0]
  # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
  # not part of the state_dict (persistent=False)
- for buffer in model.buffers(): # TODO to avaoid this buffer could be added to the ckpt
+ for buffer in model.buffers(): # TODO to avoid this buffer could be added to the ckpt
  if buffer.device != tp_device:
  buffer.data = buffer.to(tp_device)
 
@@ -4211,7 +4215,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  missing_keys=missing_keys,
  mismatched_keys=mismatched_keys,
  mismatched_shapes=mismatched_keys,
- misc=misc,
+ conversion_errors=conversion_errors,
  ignore_mismatched_sizes=ignore_mismatched_sizes,
  )
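
Across these hunks the catch-all misc bucket produced by weight loading is renamed to conversion_errors and forwarded to the final reporting step; this is an internal rename. Callers who want loading diagnostics can keep using output_loading_info, as in this sketch (only the long-standing keys are shown, since the diff does not spell out how conversion errors are exposed):

    # Sketch relying only on the stable output_loading_info contract.
    from transformers import AutoModel

    model, loading_info = AutoModel.from_pretrained("bert-base-uncased", output_loading_info=True)
    print(loading_info["missing_keys"])      # weights the model expected but the checkpoint lacked
    print(loading_info["unexpected_keys"])   # checkpoint weights the model did not use
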
 
@@ -126,6 +126,7 @@ if TYPE_CHECKING:
  from .falcon import *
  from .falcon_h1 import *
  from .falcon_mamba import *
+ from .fast_vlm import *
  from .fastspeech2_conformer import *
  from .flaubert import *
  from .flava import *
@@ -185,6 +186,7 @@ if TYPE_CHECKING:
  from .jetmoe import *
  from .kosmos2 import *
  from .kyutai_speech_to_text import *
+ from .lasr import *
  from .layoutlm import *
  from .layoutlmv2 import *
  from .layoutlmv3 import *
@@ -263,6 +265,7 @@ if TYPE_CHECKING:
  from .ovis2 import *
  from .owlv2 import *
  from .owlvit import *
+ from .paddleocr_vl import *
  from .paligemma import *
  from .parakeet import *
  from .patchtsmixer import *
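
These three hunks register the new fast_vlm, lasr, and paddleocr_vl packages under transformers.models, matching the model additions elsewhere in this release candidate. Since the diff does not show the concrete classes they export, the sketch below only introspects the subpackages:

    # Sketch: the hunks guarantee these subpackages exist; their exported classes are
    # not shown in the diff, so we simply list whatever each module makes public.
    import importlib

    for name in ("fast_vlm", "lasr", "paddleocr_vl"):
        module = importlib.import_module(f"transformers.models.{name}")
        public = [symbol for symbol in dir(module) if not symbol.startswith("_")]
        print(name, "->", public[:10])
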
@@ -28,7 +28,7 @@ from torch import nn
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
- from ...integrations import use_kernel_func_from_hub
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  from ...integrations.hub_kernels import use_kernel_forward_from_hub
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  from ...modeling_layers import GradientCheckpointingLayer
@@ -37,7 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.generic import check_model_inputs
+ from ...utils.generic import check_model_inputs, maybe_autocast
  from .configuration_afmoe import AfmoeConfig
 
 
@@ -97,7 +97,7 @@ class AfmoeRotaryEmbedding(nn.Module):
  position_ids_expanded = position_ids[:, None, :].float()
 
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = torch.cat((freqs, freqs), dim=-1)
  cos = emb.cos() * self.attention_scaling
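
In the AFMoE rotary embedding, the direct torch.autocast call is replaced by a maybe_autocast helper imported from utils.generic; the helper's implementation is not part of this diff. A purely illustrative sketch of the kind of wrapper such a rename usually implies (this is an assumption, not the library's code):

    # Illustrative sketch ONLY: an assumed shape for a "maybe autocast" wrapper that
    # degrades to a no-op context when autocast is unavailable for the device type.
    import contextlib

    import torch

    def maybe_autocast_sketch(device_type: str, enabled: bool = True, **kwargs):
        try:
            return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
        except RuntimeError:
            return contextlib.nullcontext()  # e.g. backends without autocast support

    # Usage mirroring the hunk above: force float32 math inside the block.
    with maybe_autocast_sketch(device_type="cpu", enabled=False):
        freqs = torch.randn(2, 4).float() @ torch.randn(4, 3).float()
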
@@ -338,6 +338,7 @@ def eager_attention_forward(
  return attn_output, attn_weights
 
 
+ @use_kernelized_func(apply_rotary_pos_emb)
  class AfmoeAttention(nn.Module):
  """
  Multi-headed attention module with optional sliding window and gating.
@@ -369,7 +370,6 @@ class AfmoeAttention(nn.Module):
  self.o_proj = nn.Linear(
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  )
- self.rotary_fn = apply_rotary_pos_emb
  # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
  # We only add AFMoE-specific attributes
  self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
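
The per-instance self.rotary_fn assignment disappears in favor of a class-level @use_kernelized_func(apply_rotary_pos_emb) decorator, so the rotary helper is attached declaratively and can presumably be swapped for a hub kernel. The decorator's internals are not shown in the diff; the sketch below is an assumed illustration of the pattern, not the transformers.integrations implementation:

    # Illustrative sketch ONLY of a "kernelized function" class decorator.
    from typing import Callable, Type

    def use_kernelized_func_sketch(func: Callable) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            cls.rotary_fn = staticmethod(func)  # attach once at class level
            return cls
        return decorator

    def apply_rotary_pos_emb_stub(q, k, cos, sin):
        return q, k  # placeholder standing in for the real rotary helper

    @use_kernelized_func_sketch(apply_rotary_pos_emb_stub)
    class AttentionSketch:
        pass

    print(AttentionSketch.rotary_fn)  # available without any __init__ wiring
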
@@ -14,7 +14,7 @@
  # limitations under the License.
  """Tokenization classes for ALBERT model."""
 
- from typing import Optional
+ from typing import Optional, Union
 
  from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
  from tokenizers.models import Unigram
@@ -73,8 +73,8 @@ class AlbertTokenizer(TokenizersBackend):
  other word.
  trim_offsets (`bool`, *optional*, defaults to `True`):
  Whether the post processing step should trim offsets to avoid including whitespaces.
- vocab (`dict`, *optional*):
- Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
+ vocab (`str` or `list[tuple[str, float]]`, *optional*):
+ Custom vocabulary with `(token, score)` tuples. If not provided, vocabulary is loaded from `vocab_file`.
  vocab_file (`str`, *optional*):
  [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
  contains the vocabulary necessary to instantiate a tokenizer.
@@ -82,10 +82,11 @@ class AlbertTokenizer(TokenizersBackend):
 
  vocab_files_names = VOCAB_FILES_NAMES
  model_input_names = ["input_ids", "attention_mask"]
- slow_tokenizer_class = None
+ model = Unigram
 
  def __init__(
  self,
+ vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
  do_lower_case: bool = True,
  keep_accents: bool = False,
  bos_token: str = "[CLS]",
@@ -97,19 +98,15 @@ class AlbertTokenizer(TokenizersBackend):
  mask_token: str = "[MASK]",
  add_prefix_space: bool = True,
  trim_offsets: bool = True,
- vocab: Optional[dict] = None,
- vocab_file: Optional[str] = None,
  **kwargs,
  ):
- self.vocab_file = vocab_file
  self.add_prefix_space = add_prefix_space
  self.trim_offsets = trim_offsets
-
  self.do_lower_case = do_lower_case
  self.keep_accents = keep_accents
 
  if vocab is not None:
- self._vocab_scores = [(token, 0.0) for token in vocab.keys()] if isinstance(vocab, dict) else list(vocab)
+ self._vocab_scores = vocab
  else:
  self._vocab_scores = [
  (str(pad_token), 0.0),
@@ -163,10 +160,7 @@ class AlbertTokenizer(TokenizersBackend):
  ],
  )
 
- tokenizer_object = self._tokenizer
-
  super().__init__(
- tokenizer_object=tokenizer_object,
  do_lower_case=self.do_lower_case,
  keep_accents=self.keep_accents,
  bos_token=bos_token,
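
In the ALBERT tokenizer, vocab moves to the front of __init__ and is now a list of (token, score) pairs in the Unigram style, while vocab_file and the explicit tokenizer_object plumbing are dropped. A hedged sketch of constructing the tokenizer under the new signature; the toy vocabulary and scores below are invented for illustration:

    # Sketch under the signature shown above; entries and scores are made up, and the
    # special tokens are included because the default vocabulary normally provides them.
    from transformers import AlbertTokenizer

    toy_vocab = [
        ("[PAD]", 0.0),
        ("[CLS]", 0.0),
        ("[SEP]", 0.0),
        ("[UNK]", 0.0),
        ("[MASK]", 0.0),
        ("\u2581hello", -2.0),  # "\u2581" marks a word start in SentencePiece-style vocabs
        ("\u2581world", -2.5),
    ]

    tokenizer = AlbertTokenizer(vocab=toy_vocab)
    print(tokenizer.tokenize("hello world"))
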
@@ -1004,6 +1004,7 @@ class AlignVisionModel(AlignPreTrainedModel):
  pixel_values: Optional[torch.FloatTensor] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
  r"""
  Examples:
@@ -1169,6 +1170,7 @@ class AlignModel(AlignPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, AlignOutput]:
  r"""
  return_loss (`bool`, *optional*):
@@ -891,6 +891,7 @@ class AltCLIPVisionModel(AltCLIPPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  interpolate_pos_encoding: bool = False,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, BaseModelOutputWithPooling]:
  r"""
  Examples:
@@ -970,6 +971,7 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
@@ -1061,6 +1063,7 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
  output_attentions: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, BaseModelOutputWithPoolingAndProjection]:
  r"""
  Examples:
@@ -1236,6 +1239,7 @@ class AltCLIPModel(AltCLIPPreTrainedModel):
  output_hidden_states: Optional[bool] = None,
  interpolate_pos_encoding: bool = False,
  return_dict: Optional[bool] = None,
+ **kwargs,
  ) -> Union[tuple, AltCLIPOutput]:
  r"""
  return_loss (`bool`, *optional*):
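
All of the forward signatures above, in both the ALIGN and AltCLIP modeling files, gain a trailing **kwargs. The diff does not show what, if anything, is done with those extra keywords, so the sketch below only assumes that an unrecognized keyword no longer fails at the signature; the keyword name is deliberately made up:

    # Sketch: the added **kwargs means the call below is accepted at the signature level;
    # whether "some_future_flag" is ignored or forwarded further is not shown by the diff.
    import torch

    from transformers import AltCLIPTextModel

    model = AltCLIPTextModel.from_pretrained("BAAI/AltCLIP")
    input_ids = torch.tensor([[0, 2]])  # <s> and </s> in the XLM-R vocabulary used by AltCLIP

    outputs = model(input_ids=input_ids, some_future_flag=True)
    print(outputs.last_hidden_state.shape)
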