transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -16,22 +16,24 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
+ import math
19
20
  import os
20
21
  import re
21
22
  from abc import abstractmethod
22
23
  from collections import defaultdict
23
- from collections.abc import MutableMapping, MutableSet
24
+ from collections.abc import Callable, MutableMapping, MutableSet
24
25
  from concurrent.futures import Future, ThreadPoolExecutor
25
26
  from contextlib import contextmanager
26
27
  from copy import deepcopy
27
28
  from dataclasses import dataclass, field
29
+ from itertools import chain
28
30
  from typing import TYPE_CHECKING, Any, Optional, Union
29
31
 
30
32
  import torch
31
33
 
32
- from .integrations.accelerate import offload_weight
34
+ from .integrations.accelerate import get_device, offload_weight
33
35
  from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
34
- from .utils import is_torch_greater_or_equal, logging
36
+ from .utils import is_env_variable_true, is_torch_greater_or_equal, logging
35
37
 
36
38
 
37
39
  _torch_distributed_available = torch.distributed.is_available()
@@ -278,6 +280,166 @@ class PermuteForRope(ConversionOps):
278
280
  return output
279
281
 
280
282
 
283
+ class ErnieFuseAndSplitTextVisionExperts(ConversionOps):
284
+ r"""
285
+ Special operation that splits a module list over all keys and fuses over the number of original modules.
286
+
287
+ Example with 2 original modules "Gate" and "Up" with 2 target keys "Text" and "Vision":
288
+
289
+ ModuleList 1 ModuleList 2
290
+ [ Gate ] [ Up ]
291
+ | | | |
292
+ [Gate_Text] [Gate_Vision] [Up_Text] [Up_Vision]
293
+ \ \ / /
294
+ \ \ / /
295
+ \ / \ /
296
+ \ / \ /
297
+ [GateUp_Text] [GateUp_Vision]
298
+
299
+ The splits are equal and are defined by the amount of target keys.
300
+ The final fusions are defined by the amount of original module lists.
301
+ """
302
+
303
+ def __init__(self, stack_dim: int = 0, concat_dim: int = 1):
304
+ self.stack_dim = stack_dim
305
+ self.concat_dim = concat_dim
306
+
307
+ def split_list_into_chunks(self, tensor_list: list[torch.Tensor], chunks: int = 2):
308
+ split_size = math.ceil(len(tensor_list) / chunks) # best effort split size
309
+ return [tensor_list[i * split_size : (i + 1) * split_size] for i in range(chunks)]
310
+
311
+ @torch.no_grad()
312
+ def convert(
313
+ self,
314
+ input_dict: dict[str, list[torch.Tensor]],
315
+ source_patterns: list[str],
316
+ target_patterns: list[str],
317
+ config,
318
+ **kwargs,
319
+ ) -> dict[str, list[torch.Tensor]]:
320
+ valid_keys = input_dict.keys()
321
+ split_and_fused = defaultdict(list)
322
+ for key in source_patterns:
323
+ if key not in valid_keys:
324
+ raise ValueError(
325
+ f"Expected pattern {key} in collected tensors but only found tensors for: {valid_keys}"
326
+ )
327
+
328
+ tensors = input_dict.get(key, [])
329
+ split_tensor_lists = self.split_list_into_chunks(tensors, chunks=len(target_patterns))
330
+ stacked_tensors = (torch.stack(tensor_group, dim=self.stack_dim) for tensor_group in split_tensor_lists)
331
+ for idx, tensor_group in enumerate(stacked_tensors):
332
+ split_and_fused[target_patterns[idx]].append(tensor_group)
333
+
334
+ for k, v in split_and_fused.items():
335
+ split_and_fused[k] = torch.cat(v, dim=self.concat_dim)
336
+
337
+ return split_and_fused
338
+
339
+ @property
340
+ def reverse_op(self) -> ConversionOps:
341
+ return ErnieSplitAndDecoupleTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim)
342
+
343
+
344
+ class ErnieSplitAndDecoupleTextVisionExperts(ConversionOps):
345
+ r"""
346
+ Special operation that splits a fused module list over all original modules and
347
+ then decouples them into a mixed module list each over all keys.
348
+
349
+ Example with 2 original modules "Gate" and "Up" with 2 target keys "Text" and "Vision":
350
+
351
+ [GateUp_Text] [GateUp_Vision]
352
+ / \ / \
353
+ / \ / \
354
+ / / \ \
355
+ / / \ \
356
+ [Gate_Text] [Gate_Vision] [Up_Text] [Up_Vision]
357
+ | | | |
358
+ [ Gate ] [ Up ]
359
+ ModuleList 1 ModuleList 2
360
+
361
+ The splits are equal and are defined by the amount of original module lists.
362
+ The final decoupled module lists are defined by the amount of keys.
363
+ """
364
+
365
+ def __init__(self, stack_dim: int = 0, concat_dim: int = 1):
366
+ self.stack_dim = stack_dim
367
+ self.concat_dim = concat_dim
368
+
369
+ @torch.no_grad()
370
+ def convert(
371
+ self,
372
+ input_dict: dict[str, list[torch.Tensor]],
373
+ source_patterns: list[str],
374
+ target_patterns: list[str],
375
+ config,
376
+ **kwargs,
377
+ ) -> dict[str, list[torch.Tensor]]:
378
+ fused_modules = len(target_patterns)
379
+ valid_keys = input_dict.keys()
380
+ split_tensors = []
381
+ for key in source_patterns:
382
+ if key not in valid_keys:
383
+ raise ValueError(
384
+ f"Expected pattern {key} in collected tensors but only found tensors for: {valid_keys}"
385
+ )
386
+
387
+ # Assuming that we get single sized lists here to index with 0
388
+ split_tensors.append(input_dict[key][0].chunk(fused_modules, dim=self.concat_dim))
389
+
390
+ decoupled = {}
391
+ for idx, key in enumerate(target_patterns):
392
+ tensor_groups = [
393
+ list(torch.unbind(tensor_group[idx], dim=self.stack_dim)) for tensor_group in split_tensors
394
+ ]
395
+ tensor_list = list(chain.from_iterable(tensor_groups))
396
+ targets = [key.replace("*", f"{i}") for i in range(len(tensor_list))]
397
+ decoupled |= dict(zip(targets, tensor_list))
398
+
399
+ return decoupled
400
+
401
+ @property
402
+ def reverse_op(self) -> ConversionOps:
403
+ return ErnieFuseAndSplitTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim)
404
+
405
+
406
+ class Transpose(ConversionOps):
407
+ """
408
+ Transposes the given tensor along dim0 and dim1.
409
+ """
410
+
411
+ def __init__(self, dim0: int = 0, dim1: int = 1):
412
+ self.dim0 = dim0
413
+ self.dim1 = dim1
414
+
415
+ @torch.no_grad()
416
+ def convert(
417
+ self,
418
+ input_dict: dict[str, list[torch.Tensor]],
419
+ source_patterns: list[str],
420
+ target_patterns: list[str],
421
+ config,
422
+ **kwargs,
423
+ ) -> dict[str, list[torch.Tensor]]:
424
+ if len(input_dict) != len(target_patterns):
425
+ raise ValueError(
426
+ f"Transpose conversion can only happen on each key ({len(input_dict)}) "
427
+ f"and should match exact one target ({len(target_patterns)})."
428
+ )
429
+
430
+ output: dict[str, list[torch.Tensor]] = {}
431
+ for key, target_pattern in zip(input_dict.keys(), target_patterns):
432
+ tensor = input_dict.get(key, [])
433
+ if len(tensor) != 1:
434
+ raise ValueError(f"Transpose conversion requires exactly one tensor, found {len(tensor)}.")
435
+ output[target_pattern] = torch.transpose(tensor[0], dim0=self.dim0, dim1=self.dim1).contiguous()
436
+ return output
437
+
438
+ @property
439
+ def reverse_op(self) -> ConversionOps:
440
+ return Transpose(dim0=self.dim1, dim1=self.dim0)
441
+
442
+
281
443
  @dataclass(slots=True)
282
444
  class WeightTransform:
283
445
  source_patterns: Union[str, list[str]] = field(init=True)
@@ -302,8 +464,11 @@ class WeightTransform:
302
464
  for i, pattern in enumerate(self.target_patterns):
303
465
  # Some mapping contains `^` to notify start of string when matching -> remove it during reverse mapping
304
466
  pattern = pattern.removeprefix("^")
305
- # Remove negative lookahead if any. This is ugly but needed for reverse mapping of Qwen2.5 and Sam3!
306
- pattern = re.sub(r"\(\?!.+\)", "", pattern)
467
+ # Some mapping contains `$` to notify end of string when matching -> remove it during reverse mapping
468
+ pattern = pattern.removesuffix("$")
469
+ # Remove negative lookahead/behind if any. This is ugly but needed for reverse mapping of
470
+ # Qwen2.5, Sam3, Ernie4.5 VL MoE!
471
+ pattern = re.sub(r"\(\?.+\)", "", pattern)
307
472
  # Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
308
473
  if r"(.+)" in pattern:
309
474
  pattern = pattern.replace(r"(.+)", r"\1")
@@ -327,10 +492,6 @@ class WeightTransform:
327
492
  self.collected_tensors[source_pattern].append(future)
328
493
  self.layer_targets[target_key].add(source_key)
329
494
 
330
- def reset(self) -> None:
331
- """Clean-up the collected tensors to make sure we don't keep references to past tensors in memory."""
332
- self.collected_tensors = defaultdict(list)
333
-
334
495
  def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
335
496
  """
336
497
  Return a tuple (renamed_key, source_pattern_producing_the_match).
@@ -342,19 +503,19 @@ class WeightTransform:
342
503
  match_object = self.compiled_sources.search(source_key)
343
504
  if match_object is None:
344
505
  return source_key, None
506
+
345
507
  # Find the source that produced the match (it's the first group that matched, as the search stops after first branch match)
346
508
  matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None)
347
509
  source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
348
510
  # If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
349
511
  replacement = self.target_patterns[0]
350
- # # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
512
+ # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
351
513
  if r"\1" in replacement:
352
514
  # The index of the internal group we need to replace is the index of the matched named group as it comes
353
515
  # inside that matched named group
354
516
  replaced_group_idx = self.compiled_sources.groupindex[matching_group_name] + 1
355
517
  replacement = replacement.replace(r"\1", match_object.group(replaced_group_idx))
356
518
  renamed_key = source_key.replace(match_object.group(0), replacement)
357
-
358
519
  return renamed_key, source_pattern_that_matched
359
520
 
360
521
  def reverse_transform(self) -> WeightTransform:
@@ -375,6 +536,32 @@ class WeightTransform:
375
536
 
376
537
  return reverse_transform
377
538
 
539
+ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
540
+ """
541
+ Materialize all the tensors that were saved in `self.collected_tensors`. This function removes them from the
542
+ internal attribute to avoid keeping them in memory during the different `self.convert` operations, and return
543
+ a new dictionary (otherwise we use more memory than needed during loading).
544
+
545
+ We basically have 3 cases here:
546
+ - async loading (default): the tensors are Future instances that we need to wait for
547
+ - sync loading: the tensors are Callable, we need to call the Callable to actually load them from disk
548
+ - saving: the tensors are already torch.Tensor instances (the existing model weights)
549
+ """
550
+ collected_tensors = {}
551
+ for key in set(self.collected_tensors.keys()):
552
+ # Remove from internal attribute
553
+ tensors = self.collected_tensors.pop(key)
554
+ # Async loading
555
+ if isinstance(tensors[0], Future):
556
+ tensors = [future.result() for future in tensors]
557
+ # Sync loading
558
+ elif callable(tensors[0]):
559
+ tensors = [func() for func in tensors]
560
+ # Add them to the new dictionary
561
+ collected_tensors[key] = tensors
562
+
563
+ return collected_tensors
564
+
378
565
 
379
566
  @dataclass(slots=True)
380
567
  class WeightRenaming(WeightTransform):
@@ -387,21 +574,21 @@ class WeightRenaming(WeightTransform):
387
574
  config=None,
388
575
  hf_quantizer=None,
389
576
  missing_keys: Optional[MutableSet[str]] = None,
390
- misc: Optional[MutableMapping[str, str]] = None,
577
+ conversion_errors: Optional[MutableMapping[str, str]] = None,
391
578
  ):
392
- # Collect the tensor if using threading
393
- for pattern, futures in self.collected_tensors.items():
394
- self.collected_tensors[pattern] = (
395
- futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
396
- )
579
+ # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
580
+ # attribute during the whole process
581
+ collected_tensors = self.materialize_tensors()
397
582
 
398
583
  # Perform renaming op (for a simple WeightRenaming, `self.source_patterns` and `self.target_patterns` can
399
584
  # only be of length 1, and are actually the full key names - we also have only 1 single related tensor)
400
585
  target_key = self.target_patterns[0]
401
- collected_tensors = {target_key: self.collected_tensors[self.source_patterns[0]]}
586
+ collected_tensors = {target_key: collected_tensors[self.source_patterns[0]]}
402
587
 
403
588
  if hf_quantizer is not None and self.quantization_operation is not None:
404
- with log_to_misc(layer_name, misc, (self.collected_tensors, layer_name), self.quantization_operation):
589
+ with log_conversion_errors(
590
+ layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
591
+ ):
405
592
  collected_tensors = self.quantization_operation.convert(
406
593
  collected_tensors,
407
594
  source_patterns=self.source_patterns,
@@ -412,7 +599,14 @@ class WeightRenaming(WeightTransform):
412
599
  missing_keys=missing_keys,
413
600
  )
414
601
 
415
- return collected_tensors, misc
602
+ return collected_tensors, conversion_errors
603
+
604
+
605
+ # List of classes that are known to be able to use m:n
606
+ _INTERNAL_MANY_TO_MANY_CONVERSIONS = (
607
+ ErnieFuseAndSplitTextVisionExperts,
608
+ ErnieSplitAndDecoupleTextVisionExperts,
609
+ )
416
610
 
417
611
 
418
612
  @dataclass(slots=True)
@@ -422,9 +616,11 @@ class WeightConverter(WeightTransform):
422
616
  def __post_init__(self):
423
617
  WeightTransform.__post_init__(self)
424
618
  if bool(len(self.source_patterns) - 1) + bool(len(self.target_patterns) - 1) >= 2:
425
- raise ValueError(
426
- f"source keys={self.source_patterns}, target_patterns={self.target_patterns} but you can only have one to many, one to one or many to one."
427
- )
619
+ # We allow many-to-many only if we use an internal operation that can handle it
620
+ if not any(isinstance(op, _INTERNAL_MANY_TO_MANY_CONVERSIONS) for op in self.operations):
621
+ raise ValueError(
622
+ f"source keys={self.source_patterns}, target_patterns={self.target_patterns} but you can only have one to many, one to one or many to one."
623
+ )
428
624
  if not self.operations:
429
625
  raise ValueError("WeightConverter requires at least one operation.")
430
626
 
@@ -435,17 +631,14 @@ class WeightConverter(WeightTransform):
435
631
  config=None,
436
632
  hf_quantizer=None,
437
633
  missing_keys: Optional[MutableSet[str]] = None,
438
- misc: Optional[MutableMapping[str, str]] = None,
634
+ conversion_errors: Optional[MutableMapping[str, str]] = None,
439
635
  ):
440
- # Collect all tensors if using threading
441
- for pattern, futures in self.collected_tensors.items():
442
- self.collected_tensors[pattern] = (
443
- futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
444
- )
636
+ # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
637
+ # attribute during the whole process
638
+ collected_tensors = self.materialize_tensors()
445
639
 
446
- collected_tensors = self.collected_tensors
447
640
  for op in self.operations:
448
- with log_to_misc(layer_name, misc, (collected_tensors, layer_name), op):
641
+ with log_conversion_errors(layer_name, conversion_errors, (len(collected_tensors), layer_name), op):
449
642
  collected_tensors = op.convert(
450
643
  collected_tensors,
451
644
  source_patterns=self.source_patterns,
@@ -462,11 +655,19 @@ class WeightConverter(WeightTransform):
462
655
  full_name = layer_name
463
656
  if ".*." in layer_name:
464
657
  full_name = layer_name.replace(".*.", ".0.")
465
- prefix, _, suffix = next(full_name.partition(k) for k in collected_tensors.keys() if k in full_name)
466
- # Rename the tensors
467
- collected_tensors = {prefix + k + suffix: v for k, v in collected_tensors.items()}
658
+
659
+ try:
660
+ prefix, _, suffix = next(full_name.partition(k) for k in collected_tensors.keys() if k in full_name)
661
+ # Rename the tensors
662
+ collected_tensors = {prefix + k + suffix: v for k, v in collected_tensors.items()}
663
+ # some quantizers need to already rename in `convert` as they cannot only rely on prefix and suffix
664
+ except StopIteration:
665
+ pass
666
+
468
667
  if hf_quantizer is not None and self.quantization_operation is not None:
469
- with log_to_misc(layer_name, misc, (collected_tensors, layer_name), self.quantization_operation):
668
+ with log_conversion_errors(
669
+ layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
670
+ ):
470
671
  collected_tensors = self.quantization_operation.convert(
471
672
  collected_tensors,
472
673
  source_patterns=self.source_patterns,
@@ -476,7 +677,7 @@ class WeightConverter(WeightTransform):
476
677
  model=model,
477
678
  missing_keys=missing_keys,
478
679
  )
479
- return collected_tensors, misc
680
+ return collected_tensors, conversion_errors
480
681
 
481
682
 
482
683
  # For I/O bound operations (i.e. here reading files), it is better to have fewer threads, e.g. 4 is a good default.
@@ -485,25 +686,46 @@ class WeightConverter(WeightTransform):
485
686
  GLOBAL_WORKERS = min(4, os.cpu_count() or 4)
486
687
 
487
688
 
488
- def _materialize_copy(tensor, device=None, dtype=None):
689
+ def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Tensor:
690
+ # This slicing is what actually loads the tensor from the safetensors slice object
489
691
  tensor = tensor[...]
490
692
  if dtype is not None or device is not None:
491
693
  tensor = tensor.to(device=device, dtype=dtype)
492
694
  return tensor
493
695
 
494
696
 
495
- def spawn_materialize(thread_pool, tensor, device=None, dtype=None) -> Future:
697
+ def spawn_materialize(
698
+ thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, device=None, dtype=None
699
+ ) -> Future | Callable:
700
+ """Materialize a tensor from file asynchronously if `thread_pool` is provided, or return a Callable that will
701
+ load the tensor synchronously when called."""
702
+
496
703
  def _job():
497
704
  return _materialize_copy(tensor, device, dtype)
498
705
 
499
- return thread_pool.submit(_job)
706
+ if thread_pool is not None:
707
+ return thread_pool.submit(_job)
708
+ else:
709
+ # Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
710
+ # memory during Conversion
711
+ return _job
712
+
500
713
 
714
+ def spawn_tp_materialize(
715
+ thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, sharding_method, tensor_idx, device=None, dtype=None
716
+ ) -> Future | Callable:
717
+ """Materialize and shard a tensor (according to the TP-plan) from file asynchronously if `thread_pool` is provided, or
718
+ return a Callable that will load the tensor synchronously when called."""
501
719
 
502
- def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
503
720
  def _job():
504
- return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]
721
+ return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype)
505
722
 
506
- return thread_pool.submit(_job)
723
+ if thread_pool is not None:
724
+ return thread_pool.submit(_job)
725
+ else:
726
+ # Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
727
+ # memory during Conversion
728
+ return _job
507
729
 
508
730
 
509
731
  def dot_natural_key(s: str):
@@ -516,13 +738,14 @@ def dot_natural_key(s: str):
516
738
 
517
739
 
518
740
  @contextmanager
519
- def log_to_misc(
741
+ def log_conversion_errors(
520
742
  first_target_key: str,
521
- misc: MutableMapping[str, str],
743
+ conversion_errors: MutableMapping[str, str],
522
744
  extras: Any = None,
523
745
  op: Union[list[ConversionOps], ConversionOps, None] = None,
524
746
  ):
525
- # A simple helper to handle errors with contextual messages.
747
+ """Catch all exceptions during `convert` calls, and log the errors for later. Re-raise a `SkipParameters` exception
748
+ that will be catched later to skip the parameters that raised the original Exception."""
526
749
  try:
527
750
  yield
528
751
  except Exception as e:
@@ -539,19 +762,21 @@ def log_to_misc(
539
762
 
540
763
  op_name = _format_op_name(op)
541
764
  if isinstance(extras, tuple) and len(extras) == 2:
542
- values, target_keys = extras
765
+ length, target_keys = extras
543
766
  descriptor = f"{op_name} " if op_name else ""
544
- misc[first_target_key] = (
545
- f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}"
767
+ conversion_errors[first_target_key] = (
768
+ f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {length}"
546
769
  )
547
770
  elif isinstance(extras, str):
548
771
  suffix = f" via {op_name}" if op_name else ""
549
- misc[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
772
+ conversion_errors[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
550
773
  elif extras is None and op_name:
551
- misc[first_target_key] = f"{op_name}: {e}"
774
+ conversion_errors[first_target_key] = f"{op_name}: {e}"
552
775
  else:
553
- misc[first_target_key] = f"{extras} |Error: {e}"
554
- raise SkipLayer()
776
+ conversion_errors[first_target_key] = f"{extras} |Error: {e}"
777
+
778
+ # Raise a specific Exception that we can catch easily
779
+ raise SkipParameters()
555
780
 
556
781
 
557
782
  def set_param_for_module(
@@ -560,22 +785,20 @@ def set_param_for_module(
560
785
  param_value: torch.Tensor,
561
786
  mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
562
787
  missing_keys: MutableSet[str],
563
- misc: MutableMapping[str, Any],
564
788
  unexpected_keys: MutableSet[str],
565
789
  distributed_operation: Optional[TensorParallelLayer],
566
790
  hf_quantizer: HfQuantizer,
567
791
  ):
568
- with log_to_misc(target_name, misc, target_name):
569
- module_path, _, param_name = target_name.rpartition(".")
570
- module_obj = model.get_submodule(module_path) if module_path else model
571
-
572
- ref = getattr(module_obj, param_name)
573
- if ref is None:
574
- unexpected_keys.add(target_name)
575
- else:
576
- use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
577
- if not isinstance(param_value, torch.nn.Parameter):
578
- if distributed_operation is not None:
792
+ module_path, _, param_name = target_name.rpartition(".")
793
+ module_obj = model.get_submodule(module_path) if module_path else model
794
+
795
+ ref = getattr(module_obj, param_name)
796
+ if ref is None:
797
+ unexpected_keys.add(target_name)
798
+ else:
799
+ if not isinstance(param_value, torch.nn.Parameter):
800
+ if distributed_operation is not None:
801
+ if getattr(distributed_operation, "use_dtensor", False):
579
802
  param_value = DTensor.from_local(
580
803
  param_value,
581
804
  distributed_operation.device_mesh,
@@ -584,20 +807,17 @@ def set_param_for_module(
584
807
  shape=ref.size(),
585
808
  stride=ref.stride(),
586
809
  )
587
- if not use_dtensor:
588
- # we convert to local
589
- param_value = param_value.to_local()
590
- if param_name not in module_obj._buffers:
591
- param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
592
-
593
- # Remove from missing keys (it's either mismatched, or all good)
594
- missing_keys.discard(target_name)
595
- if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
596
- mismatch_keys.add((target_name, param_value.shape, ref.shape))
597
- else:
598
- # super important otherwise _init_weight will re-init the param
599
- param_value._is_hf_initialized = True
600
- setattr(module_obj, param_name, param_value)
810
+ if param_name not in module_obj._buffers:
811
+ param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
812
+
813
+ # Remove from missing keys (it's either mismatched, or all good)
814
+ missing_keys.discard(target_name)
815
+ if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
816
+ mismatch_keys.add((target_name, param_value.shape, ref.shape))
817
+ else:
818
+ # super important otherwise _init_weight will re-init the param
819
+ param_value._is_hf_initialized = True
820
+ setattr(module_obj, param_name, param_value)
601
821
 
602
822
 
603
823
  def offload_and_maybe_resave_param(
@@ -619,8 +839,9 @@ def offload_and_maybe_resave_param(
619
839
  return disk_offload_index
620
840
 
621
841
 
622
- class SkipLayer(Exception):
623
- """Control-flow sentinel: abort processing of the current layer only."""
842
+ class SkipParameters(Exception):
843
+ """Control-flow sentinel: abort processing of the current parameters only (that were supposed to be created
844
+ by a WeightConverter)."""
624
845
 
625
846
  pass
626
847
 
@@ -675,6 +896,7 @@ def convert_and_load_state_dict_in_model(
675
896
  device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
676
897
  disk_offload_index: dict | None = None,
677
898
  disk_offload_folder: str | None = None,
899
+ offload_buffers: bool = False,
678
900
  ):
679
901
  r"""
680
902
  We build a mapping from the keys obtained by renaming each of the checkpoint keys according to the weight_mapping rules.
@@ -688,7 +910,7 @@ def convert_and_load_state_dict_in_model(
688
910
  target_patterns=["q", "k","v"],
689
911
  operations=[Chunk(dim=0, chunks=3)]),
690
912
  collected_tensors={
691
- "qkv": [Future, Future, Future]},
913
+ "qkv": [Future]},
692
914
  layer_targets={
693
915
  "model.layers.0.attention.q.weight": {"model.layers.0.attention.qkv.weight"},
694
916
  "model.layers.0.attention.k.weight": {"model.layers.0.attention.qkv.weight"},
@@ -765,25 +987,26 @@ def convert_and_load_state_dict_in_model(
765
987
  prefix = model.base_model_prefix
766
988
  tp_plan = tp_plan or {}
767
989
  device_map = device_map or {"": "cpu"}
768
- # Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
769
- device_map_regex = re.compile(
770
- "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
771
- )
772
990
  dtype_plan = dtype_plan or {}
773
991
  weight_mapping = weight_mapping or []
774
992
  meta_model_state_dict = model.state_dict()
775
- missing_keys = set(meta_model_state_dict.keys())
993
+ model_buffers = {k for k, _ in model.named_buffers()}
776
994
 
777
- misc = {}
995
+ missing_keys = set(meta_model_state_dict.keys())
996
+ conversion_errors = {}
778
997
  mismatch_keys = set()
779
998
  unexpected_keys = set()
780
- # Global thread_pool
781
- thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
999
+
1000
+ # We use threading by default, if not explicitly deactivated via env variable. If we have to offload,
1001
+ # we cannot use it either to control the memory as we are under memory constraints, so we need to be sequential
1002
+ if is_env_variable_true("HF_DEACTIVATE_ASYNC_LOAD") or "disk" in device_map.values():
1003
+ thread_pool = None
1004
+ else:
1005
+ thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
782
1006
 
783
1007
  renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
784
1008
  converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
785
-
786
- param_name_to_load: dict[str, Union[WeightRenaming | WeightConverter]] = {}
1009
+ param_name_to_load: dict[str, WeightRenaming | WeightConverter] = {}
787
1010
 
788
1011
  # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
789
1012
  # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
@@ -826,41 +1049,40 @@ def convert_and_load_state_dict_in_model(
826
1049
  if hf_quantizer and hf_quantizer.pre_quantized and original_key != renamed_key:
827
1050
  # if the key was renamed as it is not available in the state dict otherwise, it means that we are deserializing it,
828
1051
  # so we need to make sure to load the tensor with the same dtype from the checkpoint
1052
+ # TODO: make the condition more srict for native fp8 model such as qwen2moe fp8
829
1053
  _dtype = None
830
1054
  elif dtype_plan != {} and dtype_policy_alt.search(renamed_key):
831
1055
  matched_dtype_pattern = dtype_policy_alt.search(renamed_key)
832
1056
  if matched_dtype_pattern is not None:
833
- _dtype = dtype_plan[matched_dtype_pattern.group()]
1057
+ _dtype = dtype_plan[dtype_policy_by_group_name[matched_dtype_pattern.lastgroup]]
834
1058
  elif empty_param is not None and empty_param.dtype != _dtype:
835
1059
  _dtype = empty_param.dtype # usually correct when initializing
836
1060
 
837
- # 4. Handle TP sharding or device_map placement -> scheduled materialization
838
- future = None
1061
+ # 4. Handle TP sharding or device_map placement
1062
+ future_or_tensor = None
839
1063
  if device_mesh:
840
1064
  if matched_tp_pattern := tp_plan_alt.search(renamed_key):
841
1065
  matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup]
842
1066
  if getattr(mapping, "distributed_operation", None) is None:
843
1067
  tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
844
1068
  mapping.distributed_operation = tp_layer(
845
- device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
1069
+ device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone()
846
1070
  )
847
1071
  shard_index = len(mapping.collected_tensors.get(original_key, []))
848
- future = spawn_tp_materialize(
1072
+ future_or_tensor = spawn_tp_materialize(
849
1073
  thread_pool,
850
1074
  tensor,
851
1075
  mapping.distributed_operation,
852
1076
  shard_index,
1077
+ device_map[""],
853
1078
  _dtype,
854
1079
  )
855
1080
 
856
- if future is None:
857
- device_match = device_map_regex.match(renamed_key)
858
- param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
859
- # If disk, we need to materialize on cpu first
860
- param_device = "cpu" if param_device == "disk" else param_device
861
- future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
1081
+ if future_or_tensor is None:
1082
+ param_device = get_device(device_map, renamed_key, valid_torch_device=True)
1083
+ future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype)
862
1084
 
863
- mapping.add_tensor(renamed_key, original_key, source_pattern, future)
1085
+ mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
864
1086
  elif source_pattern is not None: # add all target keys as unexpected
865
1087
  mapping = pattern_to_converter[source_pattern]
866
1088
  for k in mapping.target_patterns:
@@ -868,52 +1090,57 @@ def convert_and_load_state_dict_in_model(
868
1090
  else:
869
1091
  unexpected_keys.add(renamed_key)
870
1092
 
871
- total_entries = len(param_name_to_load)
872
- with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
873
- for first_param_name, mapping in param_name_to_load.items():
874
- pbar.update(1)
875
- pbar.set_postfix({"Materializing param": first_param_name})
876
- pbar.refresh()
877
- try:
878
- realized_value, misc = mapping.convert(
879
- first_param_name,
880
- model=model,
881
- config=model.config,
882
- hf_quantizer=hf_quantizer,
883
- missing_keys=missing_keys,
884
- misc=misc,
885
- )
886
- for target_name, param in realized_value.items():
887
- param = param[0] if isinstance(param, list) else param
888
- device_match = device_map_regex.match(target_name)
889
- param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
890
- # Offloading support
891
- if param_device == "disk":
892
- disk_offload_index = offload_and_maybe_resave_param(
893
- target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
894
- )
895
- else:
896
- set_param_for_module(
897
- model,
898
- target_name,
899
- param,
900
- mismatch_keys,
901
- missing_keys,
902
- misc,
903
- unexpected_keys,
904
- mapping.distributed_operation,
905
- hf_quantizer,
906
- )
907
-
908
- # Cleanup the tensors
909
- mapping.reset()
910
- except SkipLayer:
911
- continue
1093
+ try:
1094
+ total_entries = len(param_name_to_load)
1095
+ with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
1096
+ for first_param_name, mapping in param_name_to_load.items():
1097
+ pbar.update(1)
1098
+ pbar.set_postfix({"Materializing param": first_param_name})
1099
+ pbar.refresh()
1100
+ try:
1101
+ realized_value, conversion_errors = mapping.convert(
1102
+ first_param_name,
1103
+ model=model,
1104
+ config=model.config,
1105
+ hf_quantizer=hf_quantizer,
1106
+ missing_keys=missing_keys,
1107
+ conversion_errors=conversion_errors,
1108
+ )
1109
+ for target_name, param in realized_value.items():
1110
+ param = param[0] if isinstance(param, list) else param
1111
+ param_device = get_device(device_map, target_name)
1112
+ # Offloading support
1113
+ if param_device == "disk" and (target_name not in model_buffers or offload_buffers):
1114
+ disk_offload_index = offload_and_maybe_resave_param(
1115
+ target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
1116
+ )
1117
+ else:
1118
+ set_param_for_module(
1119
+ model,
1120
+ target_name,
1121
+ param,
1122
+ mismatch_keys,
1123
+ missing_keys,
1124
+ unexpected_keys,
1125
+ mapping.distributed_operation,
1126
+ hf_quantizer,
1127
+ )
1128
+
1129
+ # Cleanup all the tensors that were gathered before next iteration
1130
+ del realized_value
1131
+
1132
+ except SkipParameters:
1133
+ continue
1134
+
1135
+ # Close the pool, independently of whether the code was interrupted or finished successfully
1136
+ finally:
1137
+ if thread_pool is not None:
1138
+ # `cancel_futures=True` in case the program was interupted, to avoid wasting time on exit
1139
+ thread_pool.shutdown(wait=False, cancel_futures=True)
912
1140
 
913
1141
  # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
914
1142
  model._weight_conversions = weight_mapping
915
- thread_pool.shutdown(wait=False)
916
- return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc
1143
+ return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, conversion_errors
917
1144
 
918
1145
 
919
1146
  def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch.Tensor]):
@@ -960,7 +1187,7 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
960
1187
  new_state_dict = {}
961
1188
  for first_param_name, reversed_converter in conversion_mapping.items():
962
1189
  # Apply the reverse converter
963
- realized_value, misc = reversed_converter.convert(first_param_name, model=model, config=model.config)
1190
+ realized_value, _ = reversed_converter.convert(first_param_name, model=model, config=model.config)
964
1191
  for target_name, param in realized_value.items():
965
1192
  param = param[0] if isinstance(param, list) else param
966
1193
  new_state_dict[target_name] = param