transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
transformers/core_model_loading.py

@@ -20,7 +20,7 @@ import os
 import re
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import MutableMapping, MutableSet
+from collections.abc import Callable, MutableMapping, MutableSet
 from concurrent.futures import Future, ThreadPoolExecutor
 from contextlib import contextmanager
 from copy import deepcopy
@@ -31,7 +31,7 @@ import torch

 from .integrations.accelerate import offload_weight
 from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
-from .utils import is_torch_greater_or_equal, logging
+from .utils import is_env_variable_true, is_torch_greater_or_equal, logging


 _torch_distributed_available = torch.distributed.is_available()
@@ -327,10 +327,6 @@ class WeightTransform:
         self.collected_tensors[source_pattern].append(future)
         self.layer_targets[target_key].add(source_key)

-    def reset(self) -> None:
-        """Clean-up the collected tensors to make sure we don't keep references to past tensors in memory."""
-        self.collected_tensors = defaultdict(list)
-
     def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
         """
         Return a tuple (renamed_key, source_pattern_producing_the_match).
@@ -375,6 +371,32 @@ class WeightTransform:

         return reverse_transform

+    def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
+        """
+        Materialize all the tensors that were saved in `self.collected_tensors`. This function removes them from the
+        internal attribute to avoid keeping them in memory during the different `self.convert` operations, and returns
+        a new dictionary (otherwise we use more memory than needed during loading).
+
+        We basically have 3 cases here:
+        - async loading (default): the tensors are Future instances that we need to wait for
+        - sync loading: the tensors are Callable, we need to call the Callable to actually load them from disk
+        - saving: the tensors are already torch.Tensor instances (the existing model weights)
+        """
+        collected_tensors = {}
+        for key in set(self.collected_tensors.keys()):
+            # Remove from internal attribute
+            tensors = self.collected_tensors.pop(key)
+            # Async loading
+            if isinstance(tensors[0], Future):
+                tensors = [future.result() for future in tensors]
+            # Sync loading
+            elif callable(tensors[0]):
+                tensors = [func() for func in tensors]
+            # Add them to the new dictionary
+            collected_tensors[key] = tensors
+
+        return collected_tensors
+

 @dataclass(slots=True)
 class WeightRenaming(WeightTransform):
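
For orientation, the three-way dispatch that `materialize_tensors` performs can be reproduced in isolation. A minimal sketch with toy stand-ins (nothing below is a transformers API):

    from concurrent.futures import Future, ThreadPoolExecutor

    import torch


    def load():
        return torch.ones(2)  # stands in for a deferred read from disk


    pool = ThreadPoolExecutor(max_workers=2)
    pending = {
        "a": [pool.submit(load)],  # async loading: Futures to wait on
        "b": [load],               # sync loading: Callables to invoke
        "c": [torch.zeros(2)],     # saving: already-materialized tensors
    }

    materialized = {}
    for key in list(pending):
        tensors = pending.pop(key)  # drop the internal reference so memory can be freed
        if isinstance(tensors[0], Future):
            tensors = [f.result() for f in tensors]
        elif callable(tensors[0]):
            tensors = [fn() for fn in tensors]
        materialized[key] = tensors

    pool.shutdown()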
@@ -387,21 +409,21 @@ class WeightRenaming(WeightTransform):
         config=None,
         hf_quantizer=None,
         missing_keys: Optional[MutableSet[str]] = None,
-        misc: Optional[MutableMapping[str, str]] = None,
+        conversion_errors: Optional[MutableMapping[str, str]] = None,
     ):
-        # Collect the tensor if using threading
-        for pattern, futures in self.collected_tensors.items():
-            self.collected_tensors[pattern] = (
-                futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
-            )
+        # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
+        # attribute during the whole process
+        collected_tensors = self.materialize_tensors()

         # Perform renaming op (for a simple WeightRenaming, `self.source_patterns` and `self.target_patterns` can
         # only be of length 1, and are actually the full key names - we also have only 1 single related tensor)
         target_key = self.target_patterns[0]
-        collected_tensors = {target_key: self.collected_tensors[self.source_patterns[0]]}
+        collected_tensors = {target_key: collected_tensors[self.source_patterns[0]]}

         if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_to_misc(layer_name, misc, (self.collected_tensors, layer_name), self.quantization_operation):
+            with log_conversion_errors(
+                layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
+            ):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -412,7 +434,7 @@ class WeightRenaming(WeightTransform):
                     missing_keys=missing_keys,
                 )

-        return collected_tensors, misc
+        return collected_tensors, conversion_errors


 @dataclass(slots=True)
@@ -435,17 +457,14 @@ class WeightConverter(WeightTransform):
         config=None,
         hf_quantizer=None,
         missing_keys: Optional[MutableSet[str]] = None,
-        misc: Optional[MutableMapping[str, str]] = None,
+        conversion_errors: Optional[MutableMapping[str, str]] = None,
     ):
-        # Collect all tensors if using threading
-        for pattern, futures in self.collected_tensors.items():
-            self.collected_tensors[pattern] = (
-                futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
-            )
+        # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
+        # attribute during the whole process
+        collected_tensors = self.materialize_tensors()

-        collected_tensors = self.collected_tensors
         for op in self.operations:
-            with log_to_misc(layer_name, misc, (collected_tensors, layer_name), op):
+            with log_conversion_errors(layer_name, conversion_errors, (len(collected_tensors), layer_name), op):
                 collected_tensors = op.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -462,11 +481,19 @@ class WeightConverter(WeightTransform):
         full_name = layer_name
         if ".*." in layer_name:
             full_name = layer_name.replace(".*.", ".0.")
-        prefix, _, suffix = next(full_name.partition(k) for k in collected_tensors.keys() if k in full_name)
-        # Rename the tensors
-        collected_tensors = {prefix + k + suffix: v for k, v in collected_tensors.items()}
+
+        try:
+            prefix, _, suffix = next(full_name.partition(k) for k in collected_tensors.keys() if k in full_name)
+            # Rename the tensors
+            collected_tensors = {prefix + k + suffix: v for k, v in collected_tensors.items()}
+        # some quantizers need to already rename in `convert` as they cannot only rely on prefix and suffix
+        except StopIteration:
+            pass
+
         if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_to_misc(layer_name, misc, (collected_tensors, layer_name), self.quantization_operation):
+            with log_conversion_errors(
+                layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
+            ):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -476,7 +503,7 @@ class WeightConverter(WeightTransform):
                     model=model,
                     missing_keys=missing_keys,
                 )
-        return collected_tensors, misc
+        return collected_tensors, conversion_errors


 # For I/O bound operations (i.e. here reading files), it is better to have fewer threads, e.g. 4 is a good default.
@@ -485,25 +512,46 @@ class WeightConverter(WeightTransform):
 GLOBAL_WORKERS = min(4, os.cpu_count() or 4)


-def _materialize_copy(tensor, device=None, dtype=None):
+def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Tensor:
+    # This slicing is what actually loads the tensor from the safetensors slice object
     tensor = tensor[...]
     if dtype is not None or device is not None:
        tensor = tensor.to(device=device, dtype=dtype)
     return tensor


-def spawn_materialize(thread_pool, tensor, device=None, dtype=None) -> Future:
+def spawn_materialize(
+    thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, device=None, dtype=None
+) -> Future | Callable:
+    """Materialize a tensor from file asynchronously if `thread_pool` is provided, or return a Callable that will
+    load the tensor synchronously when called."""
+
     def _job():
         return _materialize_copy(tensor, device, dtype)

-    return thread_pool.submit(_job)
+    if thread_pool is not None:
+        return thread_pool.submit(_job)
+    else:
+        # Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
+        # memory during Conversion
+        return _job


-def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
+def spawn_tp_materialize(
+    thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, sharding_method, tensor_idx, dtype=None
+) -> Future | Callable:
+    """Materialize and shard a tensor (according to the TP-plan) from file asynchronously if `thread_pool` is
+    provided, or return a Callable that will load the tensor synchronously when called."""
+
     def _job():
         return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]

-    return thread_pool.submit(_job)
+    if thread_pool is not None:
+        return thread_pool.submit(_job)
+    else:
+        # Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
+        # memory during Conversion
+        return _job


 def dot_natural_key(s: str):
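
Both `spawn_*` helpers now share a submit-or-defer pattern: with a pool they schedule the read immediately and return a Future; without one they hand back the closure untouched, so the read happens only at materialization time. A condensed sketch of that pattern (hypothetical `spawn` and `path`, not the real signatures):

    from collections.abc import Callable
    from concurrent.futures import Future, ThreadPoolExecutor


    def spawn(thread_pool: ThreadPoolExecutor | None, path: str) -> Future | Callable:
        def _job():
            return f"contents of {path}"  # stands in for reading one tensor from file

        if thread_pool is not None:
            return thread_pool.submit(_job)  # eager: runs on a worker thread now
        return _job  # lazy: runs only when the loader finally calls it


    pool = ThreadPoolExecutor(max_workers=4)
    print(spawn(pool, "model-00001.safetensors").result())  # threaded path
    print(spawn(None, "model-00001.safetensors")())         # sequential path
    pool.shutdown()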
@@ -516,13 +564,14 @@ def dot_natural_key(s: str):


 @contextmanager
-def log_to_misc(
+def log_conversion_errors(
     first_target_key: str,
-    misc: MutableMapping[str, str],
+    conversion_errors: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
 ):
-    # A simple helper to handle errors with contextual messages.
+    """Catch all exceptions during `convert` calls, and log the errors for later. Re-raise a `SkipParameters` exception
+    that will be caught later to skip the parameters that raised the original Exception."""
     try:
         yield
     except Exception as e:
@@ -539,19 +588,21 @@ def log_to_misc(

        op_name = _format_op_name(op)
        if isinstance(extras, tuple) and len(extras) == 2:
-           values, target_keys = extras
+           length, target_keys = extras
            descriptor = f"{op_name} " if op_name else ""
-           misc[first_target_key] = (
-               f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
+           conversion_errors[first_target_key] = (
+               f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {length}"
            )
        elif isinstance(extras, str):
            suffix = f" via {op_name}" if op_name else ""
-           misc[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
+           conversion_errors[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
        elif extras is None and op_name:
-           misc[first_target_key] = f"{op_name}: {e}"
+           conversion_errors[first_target_key] = f"{op_name}: {e}"
        else:
-           misc[first_target_key] = f"{extras} |Error: {e}"
-       raise SkipLayer()
+           conversion_errors[first_target_key] = f"{extras} |Error: {e}"
+
+       # Raise a specific Exception that we can catch easily
+       raise SkipParameters()


 def set_param_for_module(
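
The renamed context manager keeps the existing contract: record the error under the first target key, then raise a sentinel so the caller can skip just those parameters. A condensed reproduction of the pattern (hypothetical `collect_errors` helper, not the library code):

    from contextlib import contextmanager


    class SkipParameters(Exception):
        """Sentinel: skip the parameters whose conversion failed."""


    @contextmanager
    def collect_errors(key: str, errors: dict):
        try:
            yield
        except Exception as e:
            errors[key] = str(e)    # keep the message for the final loading report
            raise SkipParameters()  # uniform signal the loading loop can catch


    errors = {}
    for key, value in {"w1": "3.0", "w2": "oops"}.items():
        try:
            with collect_errors(key, errors):
                float(value)  # stands in for a convert() call that may fail
        except SkipParameters:
            continue
    print(errors)  # {'w2': "could not convert string to float: 'oops'"}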
@@ -560,44 +611,42 @@ def set_param_for_module(
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
-    misc: MutableMapping[str, Any],
     unexpected_keys: MutableSet[str],
     distributed_operation: Optional[TensorParallelLayer],
     hf_quantizer: HfQuantizer,
 ):
-    with log_to_misc(target_name, misc, target_name):
-        module_path, _, param_name = target_name.rpartition(".")
-        module_obj = model.get_submodule(module_path) if module_path else model
-
-        ref = getattr(module_obj, param_name)
-        if ref is None:
-            unexpected_keys.add(target_name)
+    module_path, _, param_name = target_name.rpartition(".")
+    module_obj = model.get_submodule(module_path) if module_path else model
+
+    ref = getattr(module_obj, param_name)
+    if ref is None:
+        unexpected_keys.add(target_name)
+    else:
+        use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
+        if not isinstance(param_value, torch.nn.Parameter):
+            if distributed_operation is not None:
+                param_value = DTensor.from_local(
+                    param_value,
+                    distributed_operation.device_mesh,
+                    getattr(distributed_operation, "shard", Replicate()),
+                    run_check=False,
+                    shape=ref.size(),
+                    stride=ref.stride(),
+                )
+                if not use_dtensor:
+                    # we convert to local
+                    param_value = param_value.to_local()
+            if param_name not in module_obj._buffers:
+                param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
+
+        # Remove from missing keys (it's either mismatched, or all good)
+        missing_keys.discard(target_name)
+        if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
+            mismatch_keys.add((target_name, param_value.shape, ref.shape))
         else:
-            use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
-            if not isinstance(param_value, torch.nn.Parameter):
-                if distributed_operation is not None:
-                    param_value = DTensor.from_local(
-                        param_value,
-                        distributed_operation.device_mesh,
-                        getattr(distributed_operation, "shard", Replicate()),
-                        run_check=False,
-                        shape=ref.size(),
-                        stride=ref.stride(),
-                    )
-                    if not use_dtensor:
-                        # we convert to local
-                        param_value = param_value.to_local()
-                if param_name not in module_obj._buffers:
-                    param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
-
-            # Remove from missing keys (it's either mismatched, or all good)
-            missing_keys.discard(target_name)
-            if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
-                mismatch_keys.add((target_name, param_value.shape, ref.shape))
-            else:
-                # super important otherwise _init_weight will re-init the param
-                param_value._is_hf_initialized = True
-                setattr(module_obj, param_name, param_value)
+            # super important otherwise _init_weight will re-init the param
+            param_value._is_hf_initialized = True
+            setattr(module_obj, param_name, param_value)


 def offload_and_maybe_resave_param(
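
The `_is_hf_initialized` flag mentioned in the hunk is what lets a later init pass skip weights that came from the checkpoint. A toy illustration of that guard-flag pattern (the init loop below is a stand-in, not the transformers one):

    import torch

    linear = torch.nn.Linear(2, 2)

    loaded = torch.nn.Parameter(torch.ones(2, 2))
    loaded._is_hf_initialized = True  # mark as coming from the checkpoint
    linear.weight = loaded

    for name, param in linear.named_parameters():
        if getattr(param, "_is_hf_initialized", False):
            continue  # never re-initialize loaded weights
        torch.nn.init.zeros_(param)  # stand-in for _init_weight

    print(linear.weight[0, 0].item())  # 1.0 -> the loaded value survived
    print(linear.bias)                 # zeros -> untouched params were re-initialized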
@@ -619,8 +668,9 @@ def offload_and_maybe_resave_param(
     return disk_offload_index


-class SkipLayer(Exception):
-    """Control-flow sentinel: abort processing of the current layer only."""
+class SkipParameters(Exception):
+    """Control-flow sentinel: abort processing of the current parameters only (that were supposed to be created
+    by a WeightConverter)."""

     pass

@@ -688,7 +738,7 @@ def convert_and_load_state_dict_in_model(
                 target_patterns=["q", "k", "v"],
                 operations=[Chunk(dim=0, chunks=3)]),
             collected_tensors={
-                "qkv": [Future, Future, Future]},
+                "qkv": [Future]},
             layer_targets={
                 "model.layers.0.attention.q.weight": {"model.layers.0.attention.qkv.weight"},
                 "model.layers.0.attention.k.weight": {"model.layers.0.attention.qkv.weight"},
@@ -774,16 +824,20 @@ def convert_and_load_state_dict_in_model(
     meta_model_state_dict = model.state_dict()
     missing_keys = set(meta_model_state_dict.keys())

-    misc = {}
+    conversion_errors = {}
     mismatch_keys = set()
     unexpected_keys = set()
-    # Global thread_pool
-    thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
+
+    # We use threading by default, if not explicitly deactivated via env variable. If we have to offload,
+    # we cannot use it either to control the memory as we are under memory constraints, so we need to be sequential
+    if is_env_variable_true("HF_DEACTIVATE_ASYNC_LOAD") or "disk" in device_map.values():
+        thread_pool = None
+    else:
+        thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)

     renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
     converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
-
-    param_name_to_load: dict[str, Union[WeightRenaming | WeightConverter]] = {}
+    param_name_to_load: dict[str, WeightRenaming | WeightConverter] = {}

     # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
     # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
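
Assuming `is_env_variable_true` accepts the usual truthy spellings (e.g. "1"), threaded weight loading can be opted out of like this:

    import os

    # force the sequential, Callable-based loading path
    os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"

    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-uncased")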
@@ -826,16 +880,17 @@ def convert_and_load_state_dict_in_model(
             if hf_quantizer and hf_quantizer.pre_quantized and original_key != renamed_key:
                 # if the key was renamed as it is not available in the state dict otherwise, it means that we are deserializing it,
                 # so we need to make sure to load the tensor with the same dtype from the checkpoint
+                # TODO: make the condition more strict for native fp8 models such as qwen2moe fp8
                 _dtype = None
             elif dtype_plan != {} and dtype_policy_alt.search(renamed_key):
                 matched_dtype_pattern = dtype_policy_alt.search(renamed_key)
                 if matched_dtype_pattern is not None:
-                    _dtype = dtype_plan[matched_dtype_pattern.group()]
+                    _dtype = dtype_plan[dtype_policy_by_group_name[matched_dtype_pattern.lastgroup]]
             elif empty_param is not None and empty_param.dtype != _dtype:
                 _dtype = empty_param.dtype  # usually correct when initializing

-            # 4. Handle TP sharding or device_map placement -> scheduled materialization
-            future = None
+            # 4. Handle TP sharding or device_map placement
+            future_or_tensor = None
             if device_mesh:
                 if matched_tp_pattern := tp_plan_alt.search(renamed_key):
                     matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup]
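
The `lastgroup` lookups rely on the alternation-of-named-groups trick described in the comment earlier in this file: all patterns are compiled into one regex, and the name of the group that matched identifies the original pattern. A minimal reproduction with toy patterns (not the real plan):

    import re

    patterns = {r".*\.mlp\..*": "bfloat16", r".*\.lm_head\..*": "float32"}
    by_group_name = {f"g{i}": p for i, p in enumerate(patterns)}
    combined = re.compile("|".join(f"(?P<g{i}>{p})" for i, p in enumerate(patterns)))

    m = combined.search("model.layers.0.mlp.up_proj.weight")
    if m is not None:
        print(patterns[by_group_name[m.lastgroup]])  # bfloat16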
@@ -845,7 +900,7 @@ def convert_and_load_state_dict_in_model(
                         device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
                     )
                     shard_index = len(mapping.collected_tensors.get(original_key, []))
-                    future = spawn_tp_materialize(
+                    future_or_tensor = spawn_tp_materialize(
                         thread_pool,
                         tensor,
                         mapping.distributed_operation,
@@ -853,14 +908,14 @@ def convert_and_load_state_dict_in_model(
                         _dtype,
                     )

-            if future is None:
+            if future_or_tensor is None:
                 device_match = device_map_regex.match(renamed_key)
                 param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
                 # If disk, we need to materialize on cpu first
                 param_device = "cpu" if param_device == "disk" else param_device
-                future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
+                future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype)

-            mapping.add_tensor(renamed_key, original_key, source_pattern, future)
+            mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
         elif source_pattern is not None:  # add all target keys as unexpected
             mapping = pattern_to_converter[source_pattern]
             for k in mapping.target_patterns:
@@ -868,52 +923,58 @@ def convert_and_load_state_dict_in_model(
         else:
             unexpected_keys.add(renamed_key)

-    total_entries = len(param_name_to_load)
-    with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
-        for first_param_name, mapping in param_name_to_load.items():
-            pbar.update(1)
-            pbar.set_postfix({"Materializing param": first_param_name})
-            pbar.refresh()
-            try:
-                realized_value, misc = mapping.convert(
-                    first_param_name,
-                    model=model,
-                    config=model.config,
-                    hf_quantizer=hf_quantizer,
-                    missing_keys=missing_keys,
-                    misc=misc,
-                )
-                for target_name, param in realized_value.items():
-                    param = param[0] if isinstance(param, list) else param
-                    device_match = device_map_regex.match(target_name)
-                    param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
-                    # Offloading support
-                    if param_device == "disk":
-                        disk_offload_index = offload_and_maybe_resave_param(
-                            target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
-                        )
-                    else:
-                        set_param_for_module(
-                            model,
-                            target_name,
-                            param,
-                            mismatch_keys,
-                            missing_keys,
-                            misc,
-                            unexpected_keys,
-                            mapping.distributed_operation,
-                            hf_quantizer,
-                        )
-
-                # Cleanup the tensors
-                mapping.reset()
-            except SkipLayer:
-                continue
+    try:
+        total_entries = len(param_name_to_load)
+        with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
+            for first_param_name, mapping in param_name_to_load.items():
+                pbar.update(1)
+                pbar.set_postfix({"Materializing param": first_param_name})
+                pbar.refresh()
+                try:
+                    realized_value, conversion_errors = mapping.convert(
+                        first_param_name,
+                        model=model,
+                        config=model.config,
+                        hf_quantizer=hf_quantizer,
+                        missing_keys=missing_keys,
+                        conversion_errors=conversion_errors,
+                    )
+                    for target_name, param in realized_value.items():
+                        param = param[0] if isinstance(param, list) else param
+                        device_match = device_map_regex.match(target_name)
+                        param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+                        # Offloading support
+                        if param_device == "disk":
+                            disk_offload_index = offload_and_maybe_resave_param(
+                                target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
+                            )
+                        else:
+                            set_param_for_module(
+                                model,
+                                target_name,
+                                param,
+                                mismatch_keys,
+                                missing_keys,
+                                unexpected_keys,
+                                mapping.distributed_operation,
+                                hf_quantizer,
+                            )
+
+                    # Cleanup all the tensors that were gathered before next iteration
+                    del realized_value
+
+                except SkipParameters:
+                    continue
+
+    # Close the pool, independently of whether the code was interrupted or finished successfully
+    finally:
+        if thread_pool is not None:
+            # `cancel_futures=True` in case the program was interrupted, to avoid wasting time on exit
+            thread_pool.shutdown(wait=False, cancel_futures=True)

     # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
     model._weight_conversions = weight_mapping
-    thread_pool.shutdown(wait=False)
-    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc
+    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, conversion_errors


 def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch.Tensor]):
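
The new try/finally guarantees the pool is torn down even when loading is interrupted; the shutdown idiom itself is plain `concurrent.futures` (a sketch, not library code):

    import time
    from concurrent.futures import ThreadPoolExecutor

    pool = ThreadPoolExecutor(max_workers=4)
    futures = [pool.submit(time.sleep, 10) for _ in range(32)]
    try:
        futures[0].result(timeout=0.1)  # pretend the loading loop failed here
    except Exception:
        pass
    finally:
        # don't block on exit, and drop queued jobs that never started
        pool.shutdown(wait=False, cancel_futures=True)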
@@ -960,7 +1021,7 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
     new_state_dict = {}
     for first_param_name, reversed_converter in conversion_mapping.items():
         # Apply the reverse converter
-        realized_value, misc = reversed_converter.convert(first_param_name, model=model, config=model.config)
+        realized_value, _ = reversed_converter.convert(first_param_name, model=model, config=model.config)
         for target_name, param in realized_value.items():
             param = param[0] if isinstance(param, list) else param
             new_state_dict[target_name] = param
transformers/data/data_collator.py

@@ -711,9 +711,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         if self.random_replace_prob < 0 or self.random_replace_prob > 1:
             raise ValueError("random_replace_prob should be between 0 and 1.")

-        self.mask_replace_prob = float(self.mask_replace_prob)
-        self.random_replace_prob = float(self.random_replace_prob)
-
         if self.whole_word_mask:
             if not self.tokenizer.is_fast:
                 warnings.warn(
@@ -729,6 +726,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
             self.mask_replace_prob = 1
             self.random_replace_prob = 0

+        self.mask_replace_prob = float(self.mask_replace_prob)
+        self.random_replace_prob = float(self.random_replace_prob)
+
         self.generator = None

     def get_generator(self, seed):
@@ -1413,9 +1413,17 @@ class DataCollatorWithFlattening(DefaultDataCollator):
         max_length = 0
         for seq_idx, sample in enumerate(features):
             input_ids = sample["input_ids"]
+            # Convert to list if tensor
+            if hasattr(input_ids, "tolist"):
+                input_ids = input_ids.tolist()
             batch["input_ids"] += input_ids
+
             if is_labels_provided:
-                batch["labels"] += [separator_id] + sample["labels"][1:]
+                labels = sample["labels"]
+                # Convert to list if tensor
+                if hasattr(labels, "tolist"):
+                    labels = labels.tolist()
+                batch["labels"] += [separator_id] + labels[1:]
             else:
                 batch["labels"] += [separator_id] + input_ids[1:]
             if self.return_position_ids:
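
A toy illustration of why the collator now converts tensors with `.tolist()`: `+=` on a Python list extends it element by element, so a tensor would contribute 0-dim tensor scalars rather than plain ints:

    import torch

    batch = {"input_ids": []}
    sample = torch.tensor([101, 2023, 102])

    batch["input_ids"] += sample.tolist()
    print(batch["input_ids"])  # [101, 2023, 102] as plain ints, not tensor scalars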
transformers/dependency_versions_table.py

@@ -9,7 +9,6 @@ deps = {
     "blobfile": "blobfile",
     "codecarbon": "codecarbon>=2.8.1",
     "cookiecutter": "cookiecutter==1.7.3",
-    "dataclasses": "dataclasses",
     "datasets": "datasets>=2.15.0",
     "deepspeed": "deepspeed>=0.9.3",
     "diffusers": "diffusers",
@@ -23,7 +22,7 @@ deps = {
     "GitPython": "GitPython<3.1.19",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
     "hf_xet": "hf_xet",
-    "huggingface-hub": "huggingface-hub>=1.0.0,<2.0",
+    "huggingface-hub": "huggingface-hub>=1.2.1,<2.0",
     "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "jinja2": "jinja2>=3.1.0",
transformers/dynamic_module_utils.py

@@ -30,7 +30,7 @@ from pathlib import Path
 from types import ModuleType
 from typing import Any, Optional, Union

-from huggingface_hub import try_to_load_from_cache
+from huggingface_hub import is_offline_mode, try_to_load_from_cache
 from packaging import version

 from .utils import (
@@ -38,7 +38,6 @@ from .utils import (
     TRANSFORMERS_DYNAMIC_MODULE_NAME,
     cached_file,
     extract_commit_hash,
-    is_offline_mode,
     logging,
 )
 from .utils.import_utils import VersionComparison, split_package_version
transformers/feature_extraction_utils.py

@@ -22,7 +22,7 @@ from collections import UserDict
 from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

 import numpy as np
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, is_offline_mode

 from .dynamic_module_utils import custom_object_save
 from .utils import (
@@ -32,7 +32,6 @@ from .utils import (
     TensorType,
     copy_func,
     is_numpy_array,
-    is_offline_mode,
     is_torch_available,
     is_torch_device,
     is_torch_dtype,
transformers/file_utils.py

@@ -68,7 +68,6 @@ from .utils import (
     is_in_notebook,
     is_ipex_available,
     is_librosa_available,
-    is_offline_mode,
     is_onnx_available,
     is_pandas_available,
     is_phonemizer_available,
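
Across these files the pattern is the same: `is_offline_mode` now comes from the hub client instead of `transformers.utils`. Callers migrate like this (the offline switch is presumably the usual HF_HUB_OFFLINE variable):

    from huggingface_hub import is_offline_mode

    if is_offline_mode():
        print("offline mode: only locally cached files will be used")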
transformers/generation/__init__.py

@@ -86,7 +86,11 @@ else:
         "StopStringCriteria",
     ]
     _import_structure["continuous_batching"] = [
+        "ContinuousBatchingManager",
         "ContinuousMixin",
+        "FIFOScheduler",
+        "PrefillFirstScheduler",
+        "Scheduler",
     ]
     _import_structure["utils"] = [
         "GenerationMixin",
@@ -127,7 +131,13 @@ if TYPE_CHECKING:
         EarlyExitCandidateGenerator,
         PromptLookupCandidateGenerator,
     )
-    from .continuous_batching import ContinuousMixin
+    from .continuous_batching import (
+        ContinuousBatchingManager,
+        ContinuousMixin,
+        FIFOScheduler,
+        PrefillFirstScheduler,
+        Scheduler,
+    )
     from .logits_process import (
         AlternatingCodebooksLogitsProcessor,
         ClassifierFreeGuidanceLogitsProcessor,
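
With these exports the continuous-batching machinery becomes importable directly from `transformers.generation`; if, as the names suggest, `Scheduler` is the base class that `FIFOScheduler` and `PrefillFirstScheduler` implement, custom schedulers can subclass it:

    from transformers.generation import (
        ContinuousBatchingManager,
        FIFOScheduler,
        PrefillFirstScheduler,
        Scheduler,
    )


    class MyScheduler(Scheduler):  # hypothetical subclass; the Scheduler API is not shown in this diff
        ...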