transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -340,9 +340,17 @@ class TrainingArguments:
  `save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained
  alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two
  checkpoints are saved: the last one and the best one (if they are different).
- save_safetensors (`bool`, *optional*, defaults to `True`):
- Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
- default `torch.load` and `torch.save`.
+ enable_jit_checkpoint (`bool`, *optional*, defaults to `False`):
+ Whether to enable Just-In-Time (JIT) checkpointing on SIGTERM signal. When enabled, training will
+ checkpoint upon receiving SIGTERM, allowing for graceful termination without losing
+ progress. This is particularly useful for shared clusters with preemptible workloads (e.g., Kueue).
+ **Important**: You must configure your orchestrator's graceful shutdown period to allow sufficient time
+ for checkpoint completion. For Kubernetes, set `terminationGracePeriodSeconds` in your job definition
+ (method varies by cloud-native trainer: Kubeflow, Ray, etc.). Note: the default is only 30 seconds,
+ which is typically insufficient. For Slurm, use `--signal=USR1@<seconds>` in your sbatch script to send
+ SIGTERM with adequate time before the job time limit. Calculate the required grace period as: longest
+ possible iteration time + checkpoint saving time. For example, if an iteration takes 2 minutes and
+ checkpoint saving takes 2 minutes, set at least 4 minutes (240 seconds) of grace time.
  save_on_each_node (`bool`, *optional*, defaults to `False`):
  When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
  the main one.
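For orientation, a minimal sketch of opting into the new flag from user code, assuming only the documented behaviour above; the output directory and save cadence are placeholders, and the SIGTERM grace period still has to be granted by the orchestrator (Kubernetes `terminationGracePeriodSeconds`, Slurm `--signal=USR1@<seconds>`):

```python
# Minimal sketch (not taken from the diff): enable JIT checkpointing in 5.0.0rc1.
# Grace-period rule of thumb from the docstring above:
# longest iteration time + checkpoint saving time, e.g. 120 s + 120 s -> 240 s.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                # placeholder
    save_strategy="steps",
    save_steps=500,
    enable_jit_checkpoint=True,      # checkpoint on SIGTERM instead of losing progress
)
```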
@@ -585,9 +593,9 @@ class TrainingArguments:
  instance of `Dataset`.
  report_to (`str` or `list[str]`, *optional*, defaults to `"none"`):
  The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
- `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
- `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all integrations
- installed, `"none"` for no integrations.
+ `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"swanlab"`,
+ `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"`
+ for no integrations.
  project (`str`, *optional*, defaults to `"huggingface"`):
  The name of the project to use for logging. Currently, only used by Trackio.
  trackio_space_id (`str` or `None`, *optional*, defaults to `"trackio"`):
@@ -852,7 +860,7 @@ class TrainingArguments:
  warmup_ratio: float | None = field(
  default=None,
  metadata={
- "help": "This argument is deprecated and will be removed in v5. Use `warmup_steps` instead as it also works with float values."
+ "help": "This argument is deprecated and will be removed in v5.2. Use `warmup_steps` instead as it also works with float values."
  },
  )

@@ -929,14 +937,24 @@ class TrainingArguments:
  " for `save_total_limit=5` and `load_best_model_at_end=True`, the four last checkpoints will always be"
  " retained alongside the best model. When `save_total_limit=1` and `load_best_model_at_end=True`,"
  " it is possible that two checkpoints are saved: the last one and the best one (if they are different)."
- " Default is unlimited checkpoints"
+ " Default is unlimited checkpoints."
  )
  },
  )
- save_safetensors: bool = field(
- default=True,
+ enable_jit_checkpoint: bool = field(
+ default=False,
  metadata={
- "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
+ "help": (
+ "Whether to enable Just-In-Time (JIT) checkpointing on SIGTERM signal. "
+ "When enabled, training will checkpoint upon receiving SIGTERM, "
+ "allowing for graceful termination without losing progress. "
+ "This is particularly useful for shared clusters with preemptible workloads (Kueue). "
+ "IMPORTANT: You must configure your orchestrator's graceful shutdown period. "
+ "Kubernetes: set terminationGracePeriodSeconds (default 30s is insufficient!) in your job definition. "
+ "Slurm: use --signal=USR1@<seconds> in sbatch to send SIGTERM before time limit. "
+ "Calculate required grace period as: iteration time + checkpoint saving time. "
+ "Example: 2min iteration + 2min checkpoint = 240 seconds minimum."
+ )
  },
  )
  save_on_each_node: bool = field(
@@ -1504,14 +1522,6 @@ class TrainingArguments:
  f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
  )

- if not self.save_safetensors:
- logger.info(
- f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. "
- f"Safetensors should be a preferred weights saving format due to security and performance reasons. "
- f"If your model cannot be saved by safetensors please feel free to open an issue at "
- f"https://github.com/huggingface/safetensors!"
- )
-
  if (
  self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU
  ) and self.metric_for_best_model is None:
@@ -2359,8 +2369,8 @@ class TrainingArguments:
  report_to (`str` or `list[str]`, *optional*, defaults to `"none"`):
  The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
  `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
- `"neptune"`, `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all
- integrations installed, `"none"` for no integrations.
+ `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all integrations
+ installed, `"none"` for no integrations.
  first_step (`bool`, *optional*, defaults to `False`):
  Whether to log and evaluate the first `global_step` or not.
  nan_inf_filter (`bool`, *optional*, defaults to `True`):
@@ -2565,7 +2575,7 @@ class TrainingArguments:
  ```
  """
  if warmup_ratio is not None:
- logger.warning("warmup_ratio is deprecated and will be removed in v5. Use `warmup_steps` instead.")
+ logger.warning("warmup_ratio is deprecated and will be removed in v5.2 . Use `warmup_steps` instead.")
  warmup_steps = warmup_ratio

  self.lr_scheduler_type = SchedulerType(name)
@@ -2742,10 +2752,24 @@ class TrainingArguments:
  fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
  self.fsdp_config["transformer_layer_cls_to_wrap"]
  )
- fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("fsdp_version", 1)
+ fsdp_version = int(self.fsdp_config.get("version", 1))
+ fsdp_plugin_args["fsdp_version"] = fsdp_version
  prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
- fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper()
- fsdp_plugin_args["forward_prefetch"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()
+ if fsdp_version == 2:
+ fsdp_plugin_args["reshard_after_forward"] = str_to_bool(
+ str(self.fsdp_config.get("reshard_after_forward", "false")).lower()
+ )
+ else:
+ fsdp_plugin_args["forward_prefetch"] = str_to_bool(
+ str(self.fsdp_config.get("forward_prefetch", "false")).lower()
+ )
+ fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper()
+ fsdp_plugin_args["reshard_after_forward"] = str(
+ self.fsdp_config.get("reshard_after_forward", "FULL_SHARD")
+ ).lower()
+ fsdp_plugin_args["use_orig_params"] = str_to_bool(
+ str(self.fsdp_config.get("use_orig_params", "true")).lower()
+ )

  sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
  cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
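A hedged sketch of an `fsdp_config` that would exercise the new parsing above; the values are illustrative, not defaults from the diff. Note the version key is now read as `version` (rather than `fsdp_version`), and for FSDP2 `reshard_after_forward` is coerced to a boolean with `str_to_bool`:

```python
# Illustrative only: keys mirror the self.fsdp_config.get(...) calls above.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "version": 2,                      # read via fsdp_config.get("version", 1)
        "reshard_after_forward": "true",   # FSDP2 branch: parsed with str_to_bool
        "sync_module_states": "true",
        "cpu_ram_efficient_loading": "true",
    },
)
```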
@@ -2755,11 +2779,10 @@ class TrainingArguments:
  raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')

  # we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers
- fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading
+ fsdp_plugin_args["cpu_ram_efficient_loading"] = str_to_bool(cpu_ram_efficient_loading)
  os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading

- fsdp_plugin_args["sync_module_states"] = sync_module_states
- fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
+ fsdp_plugin_args["sync_module_states"] = str_to_bool(sync_module_states)

  return fsdp_plugin_args

@@ -2771,3 +2794,18 @@ class ParallelMode(Enum):
  SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel"
  SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel"
  TPU = "tpu"
+
+
+ def str_to_bool(value, to_bool: bool = True) -> int | bool:
+ """
+ Converts a string representation of truth to `True` (1) or `False` (0).
+
+ True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
+ """
+ value = value.lower()
+ if value in ("y", "yes", "t", "true", "on", "1"):
+ return 1 if not to_bool else True
+ elif value in ("n", "no", "f", "false", "off", "0"):
+ return 0 if not to_bool else False
+ else:
+ raise ValueError(f"invalid truth value {value}")
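The behaviour of the new helper follows directly from the definition above; the import path below is an assumption (the function is added near the bottom of `training_args.py`):

```python
from transformers.training_args import str_to_bool  # assumed location of the new helper

assert str_to_bool("yes") is True
assert str_to_bool("OFF") is False           # matching is case-insensitive
assert str_to_bool("1", to_bool=False) == 1  # integer form when to_bool=False
try:
    str_to_bool("maybe")
except ValueError as err:
    print(err)  # invalid truth value maybe
```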
@@ -91,7 +91,6 @@ from .hub import (
  extract_commit_hash,
  has_file,
  http_user_agent,
- is_offline_mode,
  list_repo_templates,
  try_to_load_from_cache,
  )
@@ -114,8 +113,6 @@ from .import_utils import (
  is_apex_available,
  is_apollo_torch_available,
  is_aqlm_available,
- is_auto_awq_available,
- is_auto_gptq_available,
  is_auto_round_available,
  is_av_available,
  is_bitsandbytes_available,
@@ -129,7 +126,8 @@ from .import_utils import (
  is_datasets_available,
  is_decord_available,
  is_detectron2_available,
- is_eetq_available,
+ is_env_variable_false,
+ is_env_variable_true,
  is_essentia_available,
  is_faiss_available,
  is_fbgemm_gpu_available,
@@ -161,6 +159,7 @@ from .import_utils import (
  is_libcst_available,
  is_librosa_available,
  is_liger_kernel_available,
+ is_llm_awq_available,
  is_lomo_available,
  is_matplotlib_available,
  is_mistral_common_available,
@@ -67,6 +67,7 @@ HARDCODED_CONFIG_FOR_MODELS = {
  "donut": "DonutSwinConfig",
  "esmfold": "EsmConfig",
  "parakeet": "ParakeetCTCConfig",
+ "lasr": "LasrCTCConfig",
  }

  _re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
@@ -21,7 +21,7 @@ import os
  import warnings
  from collections import OrderedDict, UserDict, defaultdict
  from collections.abc import Callable, Iterable, MutableMapping
- from contextlib import AbstractContextManager, ExitStack
+ from contextlib import AbstractContextManager, ExitStack, nullcontext
  from dataclasses import dataclass, fields, is_dataclass
  from enum import Enum
  from functools import partial, wraps
@@ -42,6 +42,7 @@ _is_torch_available = False
  if is_torch_available():
  # required for @can_return_tuple decorator to work with torchdynamo
  import torch
+ from torch.types import _dtype

  from ..model_debugging_utils import model_addition_debugger_context

@@ -154,6 +155,28 @@ def is_torch_dtype(x):
  return isinstance(x, torch.dtype)


+ def maybe_autocast(
+ device_type: str,
+ dtype: Optional["_dtype"] = None,
+ enabled: bool = True,
+ cache_enabled: Optional[bool] = None,
+ ):
+ """
+ Context manager that only autocasts if:
+
+ - `autocast` is already enabled in this context
+ - Or this call to `maybe_autocast` has `enabled=True`
+
+ This prevents `autocast` being added to the graph when it is effectively a no-op.
+ Which makes graph splitting in `torch.compile` more flexible as it removes the
+ requirement that partition IDs be monotonically increasing.
+ """
+ if torch.is_autocast_enabled(device_type) or enabled:
+ return torch.autocast(device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
+ else:
+ return nullcontext()
+
+
  def _is_mlx(x):
  import mlx.core as mx

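A hedged usage sketch for the new helper, assuming it stays importable from `transformers.utils.generic` where this hunk adds it; when neither an outer autocast is active nor `enabled=True`, the call degrades to a `nullcontext` instead of inserting a no-op autocast region into a compiled graph:

```python
import torch
from transformers.utils.generic import maybe_autocast  # assumed import path

x = torch.randn(4, 4)
with maybe_autocast("cpu", dtype=torch.bfloat16, enabled=True):
    y = x @ x   # executed under torch.autocast
with maybe_autocast("cpu", enabled=False):
    z = x @ x   # plain nullcontext, nothing added to the graph
```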
@@ -680,6 +703,8 @@ class TransformersKwargs(TypedDict, total=False):
  Maximum sequence length for query state.
  max_length_k (`int`, *optional*):
  Maximum sequence length for key state.
+ position_ids (`torch.LongTensor`, *optional*)
+ Indices of positions of each input sequence tokens.
  """

  num_items_in_batch: Optional["torch.Tensor"]
@@ -690,6 +715,7 @@ class TransformersKwargs(TypedDict, total=False):
  cu_seq_lens_k: Optional["torch.LongTensor"]
  max_length_q: int | None
  max_length_k: int | None
+ position_ids: Optional["torch.LongTensor"]


  def is_timm_config_dict(config_dict: dict[str, Any]) -> bool:
transformers/utils/hub.py CHANGED
@@ -37,6 +37,7 @@ from huggingface_hub import (
  create_repo,
  hf_hub_download,
  hf_hub_url,
+ is_offline_mode,
  list_repo_tree,
  snapshot_download,
  try_to_load_from_cache,
@@ -83,13 +84,6 @@ class DownloadKwargs(TypedDict, total=False):
  commit_hash: str | None


- def is_offline_mode():
- # Import inside the function so test patches on `huggingface_hub.constants` are picked up.
- from huggingface_hub import constants as hf_hub_constants
-
- return hf_hub_constants.HF_HUB_OFFLINE
-
-
  # Determine default cache directory.
  # The best way to set the cache path is with the environment variable HF_HOME. For more details, check out this
  # documentation page: https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables.
@@ -727,8 +721,7 @@ class PushToHubMixin:
  revision: str | None = None,
  create_pr: bool = False,
  # Serialization details
- max_shard_size: int | str | None = "5GB",
- safe_serialization: bool = True,
+ max_shard_size: int | str | None = "50GB",
  tags: list[str] | None = None,
  ) -> str:
  """
@@ -751,13 +744,10 @@ class PushToHubMixin:
  Branch to push the uploaded files to.
  create_pr (`bool`, *optional*, defaults to `False`):
  Whether or not to create a PR with the uploaded files or directly commit.
- max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
+ max_shard_size (`int` or `str`, *optional*, defaults to `"50GB"`):
  Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
  will then be each of size lower than this size. If expressed as a string, needs to be digits followed
- by a unit (like `"5MB"`). We default it to `"5GB"` so that users can easily load models on free-tier
- Google Colab instances without any CPU OOM issues.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether or not to convert the model weights in safetensors format for safer serialization.
+ by a unit (like `"5MB"`).
  tags (`list[str]`, *optional*):
  List of tags to push on the Hub.

@@ -783,7 +773,7 @@ class PushToHubMixin:

  with tempfile.TemporaryDirectory() as tmp_dir:
  # Save all files.
- self.save_pretrained(tmp_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
+ self.save_pretrained(tmp_dir, max_shard_size=max_shard_size)

  # Update model card
  model_card.save(os.path.join(tmp_dir, "README.md"))
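A hedged sketch of the rc1 call surface implied by this hunk: `safe_serialization` is no longer a `push_to_hub()` argument and the default shard size is now `"50GB"`; the checkpoint and repo id below are placeholders:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")       # placeholder checkpoint
model.push_to_hub("my-user/my-model", max_shard_size="2GB")  # override the 50GB default
```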
@@ -55,9 +55,15 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
  # importlib.metadata works with the distribution package, which may be different from the import
  # name (e.g. `PIL` is the import name, but `pillow` is the distribution name)
  distributions = PACKAGE_DISTRIBUTION_MAPPING[pkg_name]
- # In most cases, the packages are well-behaved and both have the same name. If it's not the case, we
- # pick the first item of the list as best guess (it's almost always a list of length 1 anyway)
- distribution_name = pkg_name if pkg_name in distributions else distributions[0]
+ # Per PEP 503, underscores and hyphens are equivalent in package names.
+ # Prefer the distribution that matches the (normalized) package name.
+ normalized_pkg_name = pkg_name.replace("_", "-")
+ if normalized_pkg_name in distributions:
+ distribution_name = normalized_pkg_name
+ elif pkg_name in distributions:
+ distribution_name = pkg_name
+ else:
+ distribution_name = distributions[0]
  package_version = importlib.metadata.version(distribution_name)
  except (importlib.metadata.PackageNotFoundError, KeyError):
  # If we cannot find the metadata (because of editable install for example), try to import directly.
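The selection rule above, restated as a standalone sketch (names are illustrative, this is not the library's code path):

```python
def pick_distribution(pkg_name: str, distributions: list[str]) -> str:
    # Prefer the PEP 503-normalized import name, then the raw name, then the first candidate.
    normalized = pkg_name.replace("_", "-")
    if normalized in distributions:
        return normalized
    if pkg_name in distributions:
        return pkg_name
    return distributions[0]

assert pick_distribution("sentence_transformers", ["sentence-transformers"]) == "sentence-transformers"
assert pick_distribution("PIL", ["pillow"]) == "pillow"
```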
@@ -71,6 +77,16 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
  return package_exists


+ def is_env_variable_true(env_variable: str) -> bool:
+ """Detect whether `env_variable` has been set to a true value in the environment"""
+ return os.getenv(env_variable, "false").lower() in ("true", "1", "y", "yes", "on")
+
+
+ def is_env_variable_false(env_variable: str) -> bool:
+ """Detect whether `env_variable` has been set to a false value in the environment"""
+ return os.getenv(env_variable, "true").lower() in ("false", "0", "n", "no", "off")
+
+
  ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
  ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

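Expected behaviour of the two new helpers, derived from the definitions above; an unset variable counts as neither true nor false because the `getenv` defaults differ. The import path assumes the re-export added to `transformers/utils/__init__.py` in this release:

```python
import os
from transformers.utils import is_env_variable_true, is_env_variable_false  # assumed re-export

os.environ["MY_FLAG"] = "ON"
assert is_env_variable_true("MY_FLAG")          # "on" is a true value
assert not is_env_variable_false("MY_FLAG")
assert not is_env_variable_true("UNSET_FLAG")   # unset: getenv default is "false"
assert not is_env_variable_false("UNSET_FLAG")  # unset: getenv default is "true"
```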
@@ -978,7 +994,7 @@ def is_optimum_available() -> bool:


  @lru_cache
- def is_auto_awq_available() -> bool:
+ def is_llm_awq_available() -> bool:
  return _is_package_available("awq")


@@ -1015,21 +1031,11 @@ def is_compressed_tensors_available() -> bool:
  return _is_package_available("compressed_tensors")


- @lru_cache
- def is_auto_gptq_available() -> bool:
- return _is_package_available("auto_gptq")
-
-
  @lru_cache
  def is_gptqmodel_available() -> bool:
  return _is_package_available("gptqmodel")


- @lru_cache
- def is_eetq_available() -> bool:
- return _is_package_available("eetq")
-
-
  @lru_cache
  def is_fbgemm_gpu_available() -> bool:
  return _is_package_available("fbgemm_gpu")
@@ -1297,6 +1303,34 @@ def is_torch_fx_proxy(x):
  return False


+ def is_jax_jitting(x):
+ """returns True if we are inside of `jax.jit` context, False otherwise.
+
+ When a torch model is being compiled with `jax.jit` using torchax,
+ the tensor that goes through the model would be an instance of
+ `torchax.tensor.Tensor`, which is a tensor subclass. This tensor has
+ a `jax` method to return the inner Jax array
+ (https://github.com/google/torchax/blob/13ce870a1d9adb2430333c27bb623469e3aea34e/torchax/tensor.py#L134).
+ Here we use ducktyping to detect if the inner jax array is a jax Tracer
+ then we are in tracing context. (See more at: https://github.com/jax-ml/jax/discussions/9241)
+
+ Args:
+ x: torch.Tensor
+
+ Returns:
+ bool: whether we are inside of jax jit tracing.
+ """
+
+ if not hasattr(x, "jax"):
+ return False
+ try:
+ import jax
+
+ return isinstance(x.jax(), jax.core.Tracer)
+ except Exception:
+ return False
+
+
  def is_jit_tracing() -> bool:
  try:
  import torch
@@ -1306,13 +1340,24 @@ def is_jit_tracing() -> bool:
1306
1340
  return False
1307
1341
 
1308
1342
 
1343
+ def is_cuda_stream_capturing() -> bool:
1344
+ try:
1345
+ import torch
1346
+
1347
+ return torch.cuda.is_current_stream_capturing()
1348
+ except Exception:
1349
+ return False
1350
+
1351
+
1309
1352
  def is_tracing(tensor=None) -> bool:
1310
- """Checks whether we are tracing a graph with dynamo (compile or export), torch.jit, or torch.fx"""
1353
+ """Checks whether we are tracing a graph with dynamo (compile or export), torch.jit, torch.fx, jax.jit (with torchax) or
1354
+ CUDA stream capturing"""
1311
1355
  # Note that `is_torchdynamo_compiling` checks both compiling and exporting (the export check is stricter and
1312
1356
  # only checks export)
1313
- _is_tracing = is_torchdynamo_compiling() or is_jit_tracing()
1357
+ _is_tracing = is_torchdynamo_compiling() or is_jit_tracing() or is_cuda_stream_capturing()
1314
1358
  if tensor is not None:
1315
1359
  _is_tracing |= is_torch_fx_proxy(tensor)
1360
+ _is_tracing |= is_jax_jitting(tensor)
1316
1361
  return _is_tracing
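A hedged sketch of the guard pattern this widened check supports (the import path is assumed from the hunk's location in `utils/import_utils.py`): data-dependent Python branches are skipped whenever dynamo/fx/jax tracing or CUDA graph capture is active:

```python
import torch
from transformers.utils.import_utils import is_tracing  # assumed import path

def clamp_logits(logits: torch.Tensor) -> torch.Tensor:
    # Branching on tensor values is only safe in eager mode, so skip it while tracing.
    if not is_tracing(logits) and torch.isinf(logits).any():
        logits = torch.nan_to_num(logits)
    return logits
```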
@@ -208,6 +208,7 @@ class KernelConfig(PushToHubMixin):
  from kernels import Mode

  compatible_mapping = {}
+ current_device = infer_device(model)
  for layer_name, kernel in self.kernel_mapping.items():
  # Infer Mode: use Mode.TRAINING if model is training, else use Mode.INFERENCE
  mode = Mode.TRAINING if model.training else Mode.INFERENCE
@@ -216,10 +217,11 @@ class KernelConfig(PushToHubMixin):

  if isinstance(kernel, str):
  repo_name = kernel
- device = infer_device(model)
- add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
+ add_to_mapping(layer_name, current_device, repo_name, mode, compatible_mapping)
  elif isinstance(kernel, dict):
  for device, repo_name in kernel.items():
+ if device != current_device:
+ continue
  add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)

  self.kernel_mapping = compatible_mapping
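For context, the shape of a per-device `kernel_mapping` that the filtering above prunes; repo ids and layer names are placeholders. Only the entry matching the device the model actually sits on (as returned by `infer_device`) is kept in `compatible_mapping`:

```python
# Illustrative input only; nothing here is downloaded or validated.
kernel_mapping = {
    "RMSNorm": {
        "cuda": "kernels-community/example-rmsnorm-cuda",  # kept when the model is on CUDA
        "rocm": "kernels-community/example-rmsnorm-rocm",  # dropped on a CUDA machine
    },
    "MLP": "kernels-community/example-mlp",  # plain string: attached to the current device
}
```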
@@ -148,9 +148,8 @@ def log_state_dict_report(
  mismatched_keys=None,
  mismatched_shapes=None,
  ignore_mismatched_sizes=True,
- misc=None,
+ conversion_errors=None,
  color=True, # allow disabling for plain logs
- min_width_full_table=60, # terminal min width to attempt full table
  ):
  """Log a readable report about state_dict loading issues.

@@ -165,12 +164,13 @@ def log_state_dict_report(
  missing_keys = missing_keys or []
  mismatched_keys = mismatched_keys or []
  mismatched_shapes = mismatched_shapes or []
- misc = misc or {}
+ conversion_errors = conversion_errors or {}

  # Detect whether the current stdout supports ANSI colors; allow callers to pass `color=False` to force no color
  color_enabled = bool(color and sys.stdout.isatty())
  ansi = ANSI(color_enabled)

+ # Re-raise errors early if needed
  if error_msgs:
  error_msg = "\n\t".join(error_msgs)
  if "size mismatch" in error_msg:
@@ -204,9 +204,9 @@ def log_state_dict_report(
  )
  rows.append(data)

- if misc:
- for k, v in update_key_name(misc).items():
- status = "MISC"
+ if conversion_errors:
+ for k, v in update_key_name(conversion_errors).items():
+ status = "CONVERSION"
  status = _color(status, "purple", ansi)
  _details = v[:term_w]
  rows.append([k, status, _details])
@@ -228,16 +228,25 @@ def log_state_dict_report(
  if unexpected_keys:
  tips += f"\n- {_color('UNEXPECTED', 'orange', ansi) + ansi['italic']}\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch."
  if missing_keys:
- tips += f"\n- {_color('MISSING', 'red', ansi) + ansi['italic']}\t:those params were newly initialized because missing form the checkpoint. Consider training on your downstream task."
+ tips += f"\n- {_color('MISSING', 'red', ansi) + ansi['italic']}\t:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task."
  if mismatched_keys:
- tips += f"\n- {_color('MISMATCH', 'yellow', ansi) + ansi['italic']}\t:ckpt weights were loaded, but they did not match the original empty weight."
- if misc:
- tips += f"\n- {_color('MISC', 'purple', ansi) + ansi['italic']}\t:originate from the conversion scheme"
+ tips += f"\n- {_color('MISMATCH', 'yellow', ansi) + ansi['italic']}\t:ckpt weights were loaded, but they did not match the original empty weight shapes."
+ if conversion_errors:
+ tips += f"\n- {_color('CONVERSION', 'purple', ansi) + ansi['italic']}\t:originate from the conversion scheme"
  tips += f"{ansi['reset']}"

+ # Log the report as warning
  logger.warning(prelude + table + tips)
+
+ # Re-raise in those case, after the report
+ if conversion_errors:
+ raise RuntimeError(
+ "We encountered some issues during automatic conversion of the weights. For details look at the `CONVERSION` entries of "
+ "the above report!"
+ )
  if not ignore_mismatched_sizes and mismatched_keys:
  raise RuntimeError(
  "You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!"
  )
+
  return prelude + table + tips
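A hedged caller-side sketch of the new failure mode: when the report contains `CONVERSION` entries, loading now raises a `RuntimeError` after the report is logged instead of silently continuing; the checkpoint id below is a placeholder:

```python
from transformers import AutoModel

try:
    model = AutoModel.from_pretrained("some-org/some-checkpoint")  # placeholder
except RuntimeError as err:
    # "We encountered some issues during automatic conversion of the weights. ..."
    print(f"Weight conversion failed: {err}")
```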