transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/modeling_utils.py

@@ -16,7 +16,6 @@
 import collections
 import copy
 import functools
-import gc
 import importlib.metadata
 import inspect
 import json
@@ -26,7 +25,7 @@ import sys
 import warnings
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial, wraps
@@ -36,7 +35,7 @@ from typing import Optional, TypeVar, Union, get_type_hints
 from zipfile import is_zipfile
 
 import torch
-from huggingface_hub import create_repo, split_torch_state_dict_into_shards
+from huggingface_hub import create_repo, is_offline_mode, split_torch_state_dict_into_shards
 from packaging import version
 from safetensors import safe_open
 from safetensors.torch import save_file as safe_save_file
@@ -63,7 +62,8 @@ from .integrations.accelerate import (
     accelerate_dispatch,
     check_and_set_device_map,
     expand_device_map,
-    init_empty_weights,
+    get_device,
+    load_offloaded_parameter,
 )
 from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.eager_paged import eager_paged_attention_forward
@@ -85,7 +85,8 @@ from .integrations.tensor_parallel import (
     verify_tp_plan,
 )
 from .loss.loss_utils import LOSS_MAPPING
-from .modeling_flash_attention_utils import lazy_import_flash_attention
+from .modeling_flash_attention_utils import lazy_import_flash_attention, lazy_import_paged_flash_attention
+from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from .pytorch_utils import id_tensor_storage
 from .quantizers import HfQuantizer
 from .quantizers.auto import get_hf_quantizer
@@ -93,7 +94,6 @@ from .quantizers.quantizers_utils import get_module_from_name
 from .safetensors_conversion import auto_conversion
 from .utils import (
     ADAPTER_SAFE_WEIGHTS_NAME,
-    ADAPTER_WEIGHTS_NAME,
     DUMMY_INPUTS,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -109,8 +109,8 @@ from .utils (
     is_accelerate_available,
     is_flash_attn_2_available,
     is_flash_attn_3_available,
+    is_grouped_mm_available,
     is_kernels_available,
-    is_offline_mode,
     is_torch_flex_attn_available,
     is_torch_greater_or_equal,
     is_torch_mlu_available,
@@ -132,7 +132,6 @@ from .utils.quantization_config import QuantizationMethod
 if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import extract_model_from_parallel
-    from accelerate.utils.modeling import get_state_dict_from_offload
 
 
 _torch_distributed_available = torch.distributed.is_available()
@@ -154,10 +153,15 @@ logger = logging.get_logger(__name__)
 XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
 XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
 SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
-_init_weights = True
 _is_quantized = False
 _is_ds_init_called = False
 
+# Mapping from flash attention implementations to their kernel fallback repositories
+FLASH_ATTN_KERNEL_FALLBACK = {
+    "flash_attention_2": "kernels-community/flash-attn2",
+    "flash_attention_3": "kernels-community/vllm-flash-attn3",
+}
+
 
 def is_local_dist_rank_0():
     return (
@@ -167,51 +171,6 @@ def is_local_dist_rank_0():
167
171
  )
168
172
 
169
173
 
170
- TORCH_INIT_FUNCTIONS = {
171
- "uniform_": nn.init.uniform_,
172
- "normal_": nn.init.normal_,
173
- "trunc_normal_": nn.init.trunc_normal_,
174
- "constant_": nn.init.constant_,
175
- "xavier_uniform_": nn.init.xavier_uniform_,
176
- "xavier_normal_": nn.init.xavier_normal_,
177
- "kaiming_uniform_": nn.init.kaiming_uniform_,
178
- "kaiming_normal_": nn.init.kaiming_normal_,
179
- "uniform": nn.init.uniform,
180
- "normal": nn.init.normal,
181
- "xavier_uniform": nn.init.xavier_uniform,
182
- "xavier_normal": nn.init.xavier_normal,
183
- "kaiming_uniform": nn.init.kaiming_uniform,
184
- "kaiming_normal": nn.init.kaiming_normal,
185
- "orthogonal_": nn.init.orthogonal_,
186
- }
187
-
188
-
189
- @contextmanager
190
- def no_init_weights():
191
- """
192
- Context manager to globally disable weight initialization to speed up loading large models.
193
- """
194
- global _init_weights
195
- old_init_weights = _init_weights
196
-
197
- _init_weights = False
198
-
199
- def _skip_init(*args, **kwargs):
200
- pass
201
-
202
- # Save the original initialization functions
203
- for name, init_func in TORCH_INIT_FUNCTIONS.items():
204
- setattr(torch.nn.init, name, _skip_init)
205
-
206
- try:
207
- yield
208
- finally:
209
- _init_weights = old_init_weights
210
- # Restore the original initialization functions
211
- for name, init_func in TORCH_INIT_FUNCTIONS.items():
212
- setattr(torch.nn.init, name, init_func)
213
-
214
-
215
174
  @contextmanager
216
175
  def set_quantized_state():
217
176
  global _is_quantized
@@ -235,23 +194,28 @@ def set_zero3_state():
235
194
  _is_ds_init_called = False
236
195
 
237
196
 
238
- def restore_default_dtype(func):
197
+ @contextmanager
198
+ def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
239
199
  """
240
- Decorator to restore the default torch dtype
241
- at the end of the function. Serves
242
- as a backup in case calling the function raises
243
- an error after the function has changed the default dtype but before it could restore it.
200
+ Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
201
+ If `model_class_name` is provided, it's used to provide a more helpful error message if `dtype` is not valid.
244
202
  """
203
+ # Give a more helpful error before `torch.set_default_dtype` is called later on, which would crash in this case
204
+ if not dtype.is_floating_point:
205
+ if model_class_name is not None:
206
+ error_message = (
207
+ f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
208
+ )
209
+ else:
210
+ error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
211
+ raise ValueError(error_message)
245
212
 
246
- @wraps(func)
247
- def _wrapper(*args, **kwargs):
248
- old_dtype = torch.get_default_dtype()
249
- try:
250
- return func(*args, **kwargs)
251
- finally:
252
- torch.set_default_dtype(old_dtype)
253
-
254
- return _wrapper
213
+ original_dtype = torch.get_default_dtype()
214
+ try:
215
+ torch.set_default_dtype(dtype)
216
+ yield
217
+ finally:
218
+ torch.set_default_dtype(original_dtype)
255
219
 
256
220
 
257
221
  def get_torch_context_manager_or_global_device():
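The new `local_torch_dtype` context manager simply swaps torch's default floating dtype and restores it on exit; a minimal sketch of the same behaviour in plain PyTorch (illustration only, not the actual helper):

    import torch
    from contextlib import contextmanager

    @contextmanager
    def _local_dtype(dtype: torch.dtype):
        # Mirror the guard above: torch.set_default_dtype only accepts floating dtypes
        if not dtype.is_floating_point:
            raise ValueError(f"{dtype} is not a floating-point dtype")
        previous = torch.get_default_dtype()
        try:
            torch.set_default_dtype(dtype)
            yield
        finally:
            # Restored even if module construction raises
            torch.set_default_dtype(previous)

    with _local_dtype(torch.bfloat16):
        layer = torch.nn.Linear(4, 4)
    print(layer.weight.dtype)          # torch.bfloat16
    print(torch.get_default_dtype())   # back to the process default (torch.float32 unless changed)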
@@ -279,7 +243,9 @@ def get_state_dict_dtype(state_dict):
279
243
  return t.dtype
280
244
 
281
245
  # if no floating dtype was found return whatever the first dtype is
282
- return next(state_dict.values()).dtype
246
+ if len(state_dict) == 0:
247
+ return torch.float32
248
+ return next(iter(state_dict.values())).dtype
283
249
 
284
250
 
285
251
  str_to_torch_dtype = {
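The `get_state_dict_dtype` change above adds an empty-dict fallback; the detection rule it implements can be sketched as follows (illustration only, not the library function):

    import torch

    def detect_dtype(state_dict: dict[str, torch.Tensor]) -> torch.dtype:
        # First floating-point tensor wins; an empty dict falls back to fp32;
        # otherwise return the dtype of the first entry, whatever it is.
        for tensor in state_dict.values():
            if tensor.is_floating_point():
                return tensor.dtype
        if len(state_dict) == 0:
            return torch.float32
        return next(iter(state_dict.values())).dtype

    sd = {"ids": torch.zeros(2, dtype=torch.int64), "w": torch.zeros(2, dtype=torch.float16)}
    print(detect_dtype(sd))  # torch.float16
    print(detect_dtype({}))  # torch.float32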
@@ -405,11 +371,94 @@ def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]
405
371
  return shared_tensors, identical
406
372
 
407
373
 
374
+ def remove_tied_weights_from_state_dict(
375
+ state_dict: dict[str, torch.Tensor], model: "PreTrainedModel"
376
+ ) -> dict[str, torch.Tensor]:
377
+ """
378
+ Remove all tied weights from the given `state_dict`, making sure to keep only the main weight that `model`
379
+ will expect when reloading (even if we know tie weights symmetrically, it's better to keep the intended one).
380
+ This is because `safetensors` does not allow tensor aliasing - so we're going to remove aliases before saving.
381
+ """
382
+ # To avoid any potential mistakes and mismatches between config and actual tied weights, here we check the pointers
383
+ # of the Tensors themselves -> we are guaranteed to find all the actual tied weights
384
+ ptrs = collections.defaultdict(list)
385
+ for name, tensor in state_dict.items():
386
+ if not isinstance(tensor, torch.Tensor):
387
+ # Sometimes in the state_dict we have non-tensor objects.
388
+ # e.g. in bitsandbytes we have some `str` objects in the state_dict
389
+ # In the non-tensor case, fall back to the pointer of the object itself
390
+ ptrs[id(tensor)].append(name)
391
+
392
+ elif tensor.device.type == "meta":
393
+ # In offloaded cases, there may be meta tensors in the state_dict.
394
+ # For these cases, key by the pointer of the original tensor object
395
+ # (state_dict tensors are detached and therefore no longer shared)
396
+ tensor = model.get_parameter(name)
397
+ ptrs[id(tensor)].append(name)
398
+
399
+ else:
400
+ ptrs[id_tensor_storage(tensor)].append(name)
401
+
402
+ shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
403
+
404
+ # Recursively descend to find tied weight keys
405
+ all_potential_tied_weights_keys = set(_get_tied_weight_keys(model))
406
+ error_names = []
407
+ to_delete_names = set()
408
+ # Removing the keys which are declared as known duplicates on load. This allows to make sure the name which is
409
+ # kept is consistent
410
+ if all_potential_tied_weights_keys is not None:
411
+ for names in shared_ptrs.values():
412
+ found = 0
413
+ for name in sorted(names):
414
+ matches_pattern = any(re.search(pat, name) for pat in all_potential_tied_weights_keys)
415
+ if matches_pattern and name in state_dict:
416
+ found += 1
417
+ if found < len(names):
418
+ to_delete_names.add(name)
419
+ # We are entering a place where the weights and the transformers configuration do NOT match.
420
+ shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
421
+ # Those are actually tensor sharing but disjoint from each other, we can safely clone them
422
+ # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
423
+ for name in disjoint_names:
424
+ state_dict[name] = state_dict[name].clone()
425
+
426
+ # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
427
+ # If the link between tensors was done at runtime then `from_pretrained` will not get
428
+ # the key back leading to random tensor. A proper warning will be shown
429
+ # during reload (if applicable), but since the file is not necessarily compatible with
430
+ # the config, better show a proper warning.
431
+ shared_names, identical_names = _find_identical(shared_names, state_dict)
432
+ # delete tensors that have identical storage
433
+ for inames in identical_names:
434
+ known = inames.intersection(to_delete_names)
435
+ for name in known:
436
+ del state_dict[name]
437
+ unknown = inames.difference(to_delete_names)
438
+ if len(unknown) > 1:
439
+ error_names.append(unknown)
440
+
441
+ if shared_names:
442
+ error_names.extend(shared_names)
443
+
444
+ if len(error_names) > 0:
445
+ raise RuntimeError(
446
+ f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. "
447
+ f"We found all the potential target tied weights keys to be: {all_potential_tied_weights_keys}.\n"
448
+ "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
449
+ )
450
+
451
+ return state_dict
452
+
453
+
408
454
  def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
409
- """Cast a single parameter `param_name` into the `model`, with value `tensor`."""
410
- module, param_type = get_module_from_name(model, param_name)
411
- # This will check potential shape mismatch if skipped before
412
- module.load_state_dict({param_type: tensor}, strict=False, assign=True)
455
+ """Cast a single parameter or buffer `param_name` into the `model`, with value `tensor`."""
456
+ parent, param_type = get_module_from_name(model, param_name)
457
+ if param_type in parent._parameters and not isinstance(tensor, nn.Parameter):
458
+ tensor = nn.Parameter(tensor, requires_grad=tensor.is_floating_point())
459
+ # We need to use setattr here, as we set non-persistent buffers as well with this function (`load_state_dict`
460
+ # does not allow to do it)
461
+ setattr(parent, param_type, tensor)
413
462
 
414
463
 
415
464
  def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
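The new `remove_tied_weights_from_state_dict` helper starts by grouping parameters that share the same underlying storage before dropping aliases; that grouping step can be illustrated with plain PyTorch (a sketch under the assumption of CPU tensors, not the library code):

    import torch
    import torch.nn as nn

    def shared_groups(state_dict: dict[str, torch.Tensor]) -> list[list[str]]:
        # Keys that point at the same storage are aliases (i.e. tied weights)
        groups: dict[tuple, list[str]] = {}
        for name, tensor in state_dict.items():
            key = (tensor.device, tensor.untyped_storage().data_ptr())
            groups.setdefault(key, []).append(name)
        return [names for names in groups.values() if len(names) > 1]

    embed = nn.Embedding(10, 4)
    lm_head = nn.Linear(4, 10, bias=False)
    lm_head.weight = embed.weight  # tie the two weights
    sd = {"embed.weight": embed.weight, "lm_head.weight": lm_head.weight, "other": torch.zeros(3)}
    print(shared_groups(sd))  # [['embed.weight', 'lm_head.weight']]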
@@ -552,8 +601,7 @@ def _get_resolved_checkpoint_files(
552
601
  raise OSError(
553
602
  f"{pretrained_model_name_or_path} does not appear to have a file named"
554
603
  f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
555
- "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
556
- "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
604
+ "and thus cannot be loaded with `safetensors`. Please do not set `use_safetensors=True`."
557
605
  )
558
606
  else:
559
607
  # This repo has no safetensors file of any kind, we switch to PyTorch.
@@ -697,23 +745,22 @@ def _get_resolved_checkpoint_files(
697
745
 
698
746
 
699
747
  def _get_dtype(
700
- cls,
701
748
  dtype: Optional[Union[str, torch.dtype, dict]],
702
749
  checkpoint_files: Optional[list[str]],
703
750
  config: PreTrainedConfig,
704
751
  sharded_metadata: Optional[dict],
705
752
  state_dict: Optional[dict],
706
753
  weights_only: bool,
707
- ) -> tuple[PreTrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
754
+ hf_quantizer: Optional[HfQuantizer] = None,
755
+ ) -> tuple[PreTrainedConfig, torch.dtype]:
708
756
  """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
709
757
  inferred dtype. We do the following:
710
- 1. If dtype is not None, we use that dtype
711
- 2. If dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
712
- weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
713
- we also may have config.dtype available, but we won't rely on it till v5
758
+ 1. If dtype is "auto", we try to read the config, else auto-detect dtype from the loaded state_dict, by checking
759
+ its first weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
760
+ 2. Else, use the dtype provided as a dict or str
714
761
  """
715
- dtype_orig = None
716
762
  is_sharded = sharded_metadata is not None
763
+ asked_dtype = dtype
717
764
 
718
765
  if dtype is not None:
719
766
  if isinstance(dtype, str):
@@ -737,43 +784,49 @@ def _get_dtype(
737
784
  )
738
785
  elif hasattr(torch, dtype):
739
786
  dtype = getattr(torch, dtype)
740
- config.dtype = dtype
741
- for sub_config_key in config.sub_configs:
742
- if (sub_config := getattr(config, sub_config_key)) is not None:
743
- sub_config.dtype = dtype
744
- elif isinstance(dtype, torch.dtype):
745
- config.dtype = dtype
746
- for sub_config_key in config.sub_configs:
747
- if (sub_config := getattr(config, sub_config_key)) is not None:
748
- sub_config.dtype = dtype
749
- elif isinstance(dtype, dict):
750
- for key, curr_dtype in dtype.items():
751
- if hasattr(config, key):
752
- value = getattr(config, key)
753
- curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
754
- value.dtype = curr_dtype
755
- # main torch dtype for modules that aren't part of any sub-config
756
- dtype = dtype.get("")
757
- dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
758
- config.dtype = dtype
759
- if dtype is None:
760
- dtype = torch.float32
761
- else:
787
+ else:
788
+ raise ValueError(
789
+ "`dtype` provided as a `str` can only be `'auto'`, or a string representation of a valid `torch.dtype`"
790
+ )
791
+
792
+ # cast it to a proper `torch.dtype` object
793
+ dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
794
+ elif not isinstance(dtype, (dict, torch.dtype)):
762
795
  raise ValueError(
763
796
  f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
764
797
  f"for each sub-config in composite configs, but received {dtype}"
765
798
  )
799
+ else:
800
+ # set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
801
+ dtype = torch.get_default_dtype()
766
802
 
767
- dtype_orig = cls._set_default_dtype(dtype)
803
+ if hf_quantizer is not None:
804
+ hf_quantizer.update_dtype(dtype)
805
+
806
+ # Get the main dtype
807
+ if isinstance(dtype, dict):
808
+ main_dtype = dtype.get("", torch.get_default_dtype())
809
+ main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
768
810
  else:
769
- # set fp32 as the default dtype for BC
770
- default_dtype = torch.get_default_dtype()
771
- config.dtype = default_dtype
772
- for key in config.sub_configs:
773
- if (sub_config := getattr(config, key)) is not None:
774
- sub_config.dtype = default_dtype
811
+ main_dtype = dtype
812
+
813
+ # Set it on the config and subconfigs
814
+ config.dtype = main_dtype
815
+ for sub_config_key in config.sub_configs:
816
+ if (sub_config := getattr(config, sub_config_key)) is not None:
817
+ # The dtype was "auto" -> try to read the subconfig dtype value if any
818
+ if asked_dtype == "auto":
819
+ sub_dtype = getattr(sub_config, "dtype", main_dtype)
820
+ sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
821
+ # The dtype was provided as a dict, try to see if we match the subconfig name
822
+ elif isinstance(dtype, dict):
823
+ sub_dtype = dtype.get(sub_config_key, main_dtype)
824
+ sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
825
+ else:
826
+ sub_dtype = main_dtype
827
+ sub_config.dtype = sub_dtype
775
828
 
776
- return config, dtype, dtype_orig
829
+ return config, main_dtype
777
830
 
778
831
 
779
832
  class PipelineParallel(Enum):
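The reworked `_get_dtype` accepts `"auto"`, a `torch.dtype`, a dtype string, or a per-sub-config dict; passed through `from_pretrained`, usage would look roughly like the following (the repo id and the `vision_config` key are placeholders):

    import torch
    from transformers import AutoModelForCausalLM

    repo = "org/some-model"  # hypothetical repo id

    # "auto": read the dtype stored in the config/checkpoint
    model = AutoModelForCausalLM.from_pretrained(repo, dtype="auto")

    # a single dtype (object or string) for the whole model
    model = AutoModelForCausalLM.from_pretrained(repo, dtype=torch.bfloat16)

    # a dict: "" is the main model, other keys name sub-configs of composite models
    model = AutoModelForCausalLM.from_pretrained(
        repo, dtype={"": torch.bfloat16, "vision_config": torch.float16}
    )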
@@ -969,54 +1022,52 @@ class EmbeddingAccessMixin:
969
1022
  `nn.Module`: A torch module mapping vocabulary to hidden states.
970
1023
  """
971
1024
 
972
- # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
973
- # for most NLP models), and if so, return it.
974
-
975
1025
  name = getattr(self, "_input_embed_layer", "embed_tokens")
976
1026
 
1027
+ # 1) Direct attribute (most NLP models).
977
1028
  if (default_embedding := getattr(self, name, None)) is not None:
978
1029
  return default_embedding
979
- # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1030
+ # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision/audio models).
1031
+ if hasattr(self, "embeddings") and hasattr(self.embeddings, name):
1032
+ return getattr(self.embeddings, name)
1033
+ # 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides).
1034
+ if hasattr(self, "model") and hasattr(self.model, name):
1035
+ return getattr(self.model, name)
980
1036
 
981
- if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
982
- return self.model.embed_tokens
1037
+ if hasattr(self, "base_model"):
1038
+ base_model = self.base_model
1039
+ if base_model is not None and base_model is not self:
1040
+ return base_model.get_input_embeddings()
983
1041
 
984
- # 3) vanilla decoder‑only architectures
985
- elif hasattr(self, "embed_tokens"):
986
- return self.embed_tokens
987
- else:
988
- base_model = getattr(self, "base_model_prefix", None)
989
- if base_model is not None:
990
- base_model = getattr(self, base_model, None)
991
- if base_model is not None and base_model is not self:
992
- return base_model.get_input_embeddings()
993
- raise NotImplementedError(
994
- f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
995
- "please override in the subclass."
996
- )
1042
+ raise NotImplementedError(
1043
+ f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
1044
+ )
997
1045
 
998
1046
  def set_input_embeddings(self, value: nn.Module):
999
1047
  """Fallback setter that handles **~70%** of models in the code-base.
1000
1048
 
1001
1049
  Order of attempts:
1002
- 1. `self.model.embed_tokens`
1003
- 2. `self.embed_tokens`
1004
- 3. delegate to the *base model* if one exists
1005
- 4. otherwise raise `NotImplementedError` so subclasses still can (and
1050
+ 1. `self.<_input_embed_layer>` (direct attribute)
1051
+ 2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models)
1052
+ 3. `self.model.<_input_embed_layer>` (encoder/decoder models)
1053
+ 4. delegate to the *base model* if one exists
1054
+ 5. otherwise raise `NotImplementedError` so subclasses still can (and
1006
1055
  should) override for exotic layouts.
1007
1056
  """
1008
1057
 
1009
- # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1010
1058
  name = getattr(self, "_input_embed_layer", "embed_tokens")
1011
- if hasattr(self, "model") and hasattr(self.model, name):
1012
- setattr(self.model, name, value)
1013
- # 2) as well as vanilla decoder‑only architectures
1014
- elif hasattr(self, name):
1059
+ # 1) Direct attribute (most NLP models)
1060
+ if hasattr(self, name):
1015
1061
  setattr(self, name, value)
1016
- # 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
1017
- elif getattr(self, self.base_model_prefix, self) is not self:
1018
- base_model = getattr(self, self.base_model_prefix, self)
1019
- base_model.set_input_embeddings(value)
1062
+ # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models)
1063
+ elif hasattr(self, "embeddings") and hasattr(self.embeddings, name):
1064
+ setattr(self.embeddings, name, value)
1065
+ # 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
1066
+ elif hasattr(self, "model") and hasattr(self.model, name):
1067
+ setattr(self.model, name, value)
1068
+ # 4) recurse once into the registered *base* model (e.g. for encoder/decoder)
1069
+ elif hasattr(self, "base_model") and self.base_model is not self:
1070
+ self.base_model.set_input_embeddings(value)
1020
1071
  else:
1021
1072
  raise NotImplementedError(
1022
1073
  f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
@@ -1228,6 +1279,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1228
1279
  self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
1229
1280
  self.config._attn_implementation, is_init_check=True
1230
1281
  )
1282
+ # Check the experts implementation is supported, or set it if not yet set (on the internal attr, to avoid
1283
+ # setting it recursively)
1284
+ self.config._experts_implementation_internal = self._check_and_adjust_experts_implementation(
1285
+ self.config._experts_implementation
1286
+ )
1231
1287
  if self.can_generate():
1232
1288
  self.generation_config = GenerationConfig.from_model_config(config)
1233
1289
 
@@ -1343,7 +1399,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1343
1399
  def pp_plan(self, plan: dict[str, tuple[str, str]]):
1344
1400
  self._pp_plan = plan
1345
1401
 
1346
- def dequantize(self):
1402
+ def dequantize(self, dtype=None):
1347
1403
  """
1348
1404
  Potentially dequantize the model in case it has been quantized by a quantization method that support
1349
1405
  dequantization.
@@ -1353,7 +1409,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1353
1409
  if hf_quantizer is None:
1354
1410
  raise ValueError("You need to first quantize your model in order to dequantize it")
1355
1411
 
1356
- return hf_quantizer.dequantize(self)
1412
+ return hf_quantizer.dequantize(self, dtype=dtype)
1357
1413
 
1358
1414
  def _backward_compatibility_gradient_checkpointing(self):
1359
1415
  if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
@@ -1394,7 +1450,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1394
1450
  self.model_tags.append(tag)
1395
1451
 
1396
1452
  @classmethod
1397
- @restore_default_dtype
1398
1453
  def _from_config(cls, config, **kwargs):
1399
1454
  """
1400
1455
  All context managers that the model should be initialized under go here.
@@ -1403,9 +1458,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1403
1458
  dtype (`torch.dtype`, *optional*):
1404
1459
  Override the default `dtype` and load the model under this dtype.
1405
1460
  """
1406
- # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
1407
- # a warning is raised that dtype should be fp16. Since we never pass dtype from within
1408
- # modeling code, we can try to infer it here same way as done in `from_pretrained`
1409
1461
  # For BC on the old `torch_dtype`
1410
1462
  dtype = kwargs.pop("dtype", config.dtype)
1411
1463
  if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
@@ -1415,61 +1467,32 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1415
1467
  if isinstance(dtype, str):
1416
1468
  dtype = getattr(torch, dtype)
1417
1469
 
1418
- # override default dtype if needed
1419
- dtype_orig = None
1420
- if dtype is not None:
1421
- dtype_orig = cls._set_default_dtype(dtype)
1422
-
1423
1470
  # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
1424
1471
  if "attn_implementation" in kwargs:
1425
1472
  config._attn_implementation = kwargs.pop("attn_implementation")
1426
1473
 
1474
+ # If passing `experts_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
1475
+ if "experts_implementation" in kwargs:
1476
+ config._experts_implementation = kwargs.pop("experts_implementation")
1477
+
1478
+ init_contexts = []
1479
+ if dtype is not None:
1480
+ init_contexts.append(local_torch_dtype(dtype, cls.__name__))
1481
+
1427
1482
  if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
1428
1483
  logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
1429
1484
  # this immediately partitions the model across all gpus, to avoid the overhead in time
1430
1485
  # and memory copying it on CPU or each GPU first
1431
1486
  import deepspeed
1432
1487
 
1433
- init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
1434
- with ContextManagers(init_contexts):
1435
- model = cls(config, **kwargs)
1488
+ init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
1436
1489
 
1437
- else:
1490
+ # Instantiate the model
1491
+ with ContextManagers(init_contexts):
1438
1492
  model = cls(config, **kwargs)
1439
1493
 
1440
- # restore default dtype if it was modified
1441
- if dtype_orig is not None:
1442
- torch.set_default_dtype(dtype_orig)
1443
-
1444
1494
  return model
1445
1495
 
1446
- @classmethod
1447
- def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
1448
- """
1449
- Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
1450
- under specific dtype.
1451
-
1452
- Args:
1453
- dtype (`torch.dtype`):
1454
- a floating dtype to set to.
1455
-
1456
- Returns:
1457
- `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
1458
- modified. If it wasn't, returns `None`.
1459
-
1460
- Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
1461
- `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
1462
- """
1463
- if not dtype.is_floating_point:
1464
- raise ValueError(
1465
- f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
1466
- )
1467
-
1468
- logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
1469
- dtype_orig = torch.get_default_dtype()
1470
- torch.set_default_dtype(dtype)
1471
- return dtype_orig
1472
-
1473
1496
  @property
1474
1497
  def base_model(self) -> nn.Module:
1475
1498
  """
@@ -1546,7 +1569,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1546
1569
  return True
1547
1570
 
1548
1571
  if is_torch_xpu_available():
1549
- logger.info("Detect using FlashAttention2 (via kernel `kernels-community/flash-attn2`) on XPU.")
1572
+ logger.info(
1573
+ f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU."
1574
+ )
1550
1575
  return True
1551
1576
 
1552
1577
  if importlib.util.find_spec("flash_attn") is None:
@@ -1715,6 +1740,22 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1715
1740
 
1716
1741
  return True
1717
1742
 
1743
+ def _grouped_mm_can_dispatch(self) -> bool:
1744
+ """
1745
+ Check the availability of Grouped MM for a given model.
1746
+ """
1747
+
1748
+ if not self._can_set_experts_implementation():
1749
+ raise ValueError(f"{self.__class__.__name__} does not support setting experts implementation.")
1750
+
1751
+ if not is_grouped_mm_available():
1752
+ raise ImportError(
1753
+ "PyTorch Grouped MM requirements in Transformers are not met. Please install torch>=2.9.0."
1754
+ )
1755
+
1756
+ # If no error raised by this point, we can return `True`
1757
+ return True
1758
+
1718
1759
  def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
1719
1760
  """
1720
1761
  Check the availability of Flex Attention for a given model.
@@ -1764,9 +1805,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1764
1805
  """
1765
1806
  applicable_attn_implementation = attn_implementation
1766
1807
 
1808
+ is_paged = attn_implementation is not None and attn_implementation.startswith("paged|")
1809
+
1767
1810
  # If FA not installed, do not fail but use kernels instead
1768
1811
  requested_original_flash_attn = attn_implementation is not None and (
1769
- attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3"
1812
+ attn_implementation.removeprefix("paged|") == "flash_attention_2"
1813
+ or attn_implementation.removeprefix("paged|") == "flash_attention_3"
1770
1814
  )
1771
1815
  if (
1772
1816
  requested_original_flash_attn
@@ -1775,19 +1819,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1775
1819
  and is_kernels_available()
1776
1820
  and not is_torch_npu_available()
1777
1821
  ):
1778
- if attn_implementation.endswith("2"):
1779
- applicable_attn_implementation = "kernels-community/flash-attn2"
1780
- if is_torch_xpu_available():
1781
- # On XPU, kernels library is the native implementation
1782
- # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
1783
- requested_original_flash_attn = False
1784
- else:
1785
- applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
1822
+ applicable_attn_implementation = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]
1823
+
1824
+ if is_torch_xpu_available() and attn_implementation.removeprefix("paged|") == "flash_attention_2":
1825
+ # On XPU, kernels library is the native implementation
1826
+ # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
1827
+ requested_original_flash_attn = False
1828
+
1829
+ if is_paged:
1830
+ applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
1786
1831
 
1787
1832
  if is_kernel(applicable_attn_implementation):
1788
1833
  try:
1789
1834
  # preload flash attention here to allow compile with fullgraph
1790
- lazy_import_flash_attention(applicable_attn_implementation)
1835
+ if is_paged:
1836
+ lazy_import_paged_flash_attention(applicable_attn_implementation)
1837
+ else:
1838
+ lazy_import_flash_attention(applicable_attn_implementation)
1791
1839
 
1792
1840
  # log that we used kernel fallback if successful
1793
1841
  if requested_original_flash_attn:
@@ -1816,6 +1864,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1816
1864
 
1817
1865
  return applicable_attn_implementation
1818
1866
 
1867
+ def _check_and_adjust_experts_implementation(self, experts_implementation: Optional[str]) -> str:
1868
+ """
1869
+ Check that the `experts_implementation` exists and is supported by the models.
1870
+
1871
+ Args:
1872
+ experts_implementation (`str` or `None`):
1873
+ The experts implementation to check for existence/validity.
1874
+ Returns:
1875
+ `str`: The final experts implementation to use.
1876
+ """
1877
+ applicable_experts_implementation = self.get_correct_experts_implementation(experts_implementation)
1878
+ return applicable_experts_implementation
1879
+
1819
1880
  def get_correct_attn_implementation(self, requested_attention: Optional[str], is_init_check: bool = False) -> str:
1820
1881
  applicable_attention = "sdpa" if requested_attention is None else requested_attention
1821
1882
  if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
@@ -1850,6 +1911,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1850
1911
 
1851
1912
  return applicable_attention
1852
1913
 
1914
+ def get_correct_experts_implementation(self, requested_experts: Optional[str]) -> str:
1915
+ applicable_experts = "grouped_mm" if requested_experts is None else requested_experts
1916
+ if applicable_experts not in ["eager", "grouped_mm", "batched_mm"]:
1917
+ message = (
1918
+ f'Specified `experts_implementation="{applicable_experts}"` is not supported. The only possible arguments are '
1919
+ '`experts_implementation="eager"`, `"experts_implementation=grouped_mm"` and `"experts_implementation=batched_mm"`.'
1920
+ )
1921
+ raise ValueError(message)
1922
+
1923
+ # Perform relevant checks
1924
+ if applicable_experts == "grouped_mm":
1925
+ try:
1926
+ self._grouped_mm_can_dispatch()
1927
+ except (ValueError, ImportError) as e:
1928
+ if requested_experts == "grouped_mm":
1929
+ raise e
1930
+ applicable_experts = "eager"
1931
+
1932
+ return applicable_experts
1933
+
1853
1934
  @classmethod
1854
1935
  def _can_set_attn_implementation(cls) -> bool:
1855
1936
  """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
@@ -1868,6 +1949,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1868
1949
  # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
1869
1950
  return True
1870
1951
 
1952
+ @classmethod
1953
+ def _can_set_experts_implementation(cls) -> bool:
1954
+ """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
1955
+ opening the file, but avoids maintaining yet another property flag.
1956
+ """
1957
+ class_file = sys.modules[cls.__module__].__file__
1958
+ with open(class_file, "r") as f:
1959
+ code = f.read()
1960
+ # heuristic -> if the use_experts_implementation decorator is used, then we can set it
1961
+ return "@use_experts_implementation" in code
1962
+
1871
1963
  def set_attn_implementation(self, attn_implementation: Union[str, dict]):
1872
1964
  """
1873
1965
  Set the requested `attn_implementation` for this model.
@@ -1967,6 +2059,50 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1967
2059
  if hasattr(subconfig, "_attn_was_changed"):
1968
2060
  del subconfig._attn_was_changed
1969
2061
 
2062
+ def set_experts_implementation(self, experts_implementation: Union[str, dict]):
2063
+ """
2064
+ Set the requested `experts_implementation` for this model.
2065
+
2066
+ Args:
2067
+ experts_implementation (`str` or `dict`):
2068
+ The experts implementation to set for this model. It can be either a `str`, in which case it will be
2069
+ dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
2070
+ submodel will dispatch the corresponding value.
2071
+ """
2072
+ requested_implementation = (
2073
+ experts_implementation
2074
+ if not isinstance(experts_implementation, dict)
2075
+ else experts_implementation.get("", self.config._experts_implementation)
2076
+ )
2077
+
2078
+ if requested_implementation != self.config._experts_implementation:
2079
+ requested_implementation = self._check_and_adjust_experts_implementation(requested_implementation)
2080
+ # Apply the change (on the internal attr, to avoid setting it recursively)
2081
+ self.config._experts_implementation_internal = requested_implementation
2082
+
2083
+ # Apply it to all submodels as well
2084
+ for submodule in self.modules():
2085
+ # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
2086
+ # e.g. ForCausalLM has a Model inside, but no need to check it again)
2087
+ if (
2088
+ submodule is not self
2089
+ and isinstance(submodule, PreTrainedModel)
2090
+ and submodule.config.__class__ != self.config.__class__
2091
+ ):
2092
+ # Set the experts on the submodule
2093
+ sub_implementation = requested_implementation
2094
+ if isinstance(experts_implementation, dict):
2095
+ for subconfig_key in self.config.sub_configs:
2096
+ # We need to check for exact object match here, with `is`
2097
+ if getattr(self.config, subconfig_key) is submodule.config:
2098
+ sub_implementation = experts_implementation.get(
2099
+ subconfig_key, submodule.config._experts_implementation
2100
+ )
2101
+ break
2102
+ # Check the module can use correctly, otherwise we raise an error if requested experts can't be set for submodule
2103
+ sub_implementation = submodule.get_correct_experts_implementation(sub_implementation)
2104
+ submodule.config._experts_implementation_internal = sub_implementation
2105
+
1970
2106
  def enable_input_require_grads(self):
1971
2107
  """
1972
2108
  Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
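As with attention implementations, the experts implementation can be switched after loading via the new setter; a hedged usage sketch (the repo id and the `text_config` key are placeholders, and the `grouped_mm` path assumes torch>=2.9):

    from transformers import AutoModelForCausalLM

    repo = "org/some-moe-model"  # hypothetical MoE checkpoint
    model = AutoModelForCausalLM.from_pretrained(repo)

    # One implementation for the whole model: "grouped_mm" (default when available),
    # "batched_mm", or "eager"
    model.set_experts_implementation("eager")

    # Or per sub-model, "" being the top-level model
    model.set_experts_implementation({"": "grouped_mm", "text_config": "eager"})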
@@ -1978,14 +2114,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1978
2114
 
1979
2115
  hooks = []
1980
2116
  seen_modules = set()
2117
+ found_embeddings = False
1981
2118
 
1982
2119
  for module in self.modules():
1983
2120
  if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
1984
2121
  continue
1985
2122
 
1986
- input_embeddings = module.get_input_embeddings()
2123
+ try:
2124
+ input_embeddings = module.get_input_embeddings()
2125
+ except NotImplementedError:
2126
+ continue
1987
2127
 
1988
- if input_embeddings is None:
2128
+ if input_embeddings is None or not hasattr(input_embeddings, "register_forward_hook"):
1989
2129
  continue
1990
2130
 
1991
2131
  embedding_id = id(input_embeddings)
@@ -1994,11 +2134,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
1994
2134
 
1995
2135
  seen_modules.add(embedding_id)
1996
2136
  hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))
2137
+ found_embeddings = True
1997
2138
 
1998
2139
  self._require_grads_hooks = hooks
1999
2140
  if hooks:
2000
2141
  # for BC
2001
2142
  self._require_grads_hook = hooks[0]
2143
+ if not found_embeddings:
2144
+ logger.warning_once(
2145
+ f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
2146
+ "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
2147
+ "support those features, or set `_input_embed_layer` to the attribute name that holds the embeddings."
2148
+ )
2002
2149
 
2003
2150
  def disable_input_require_grads(self):
2004
2151
  """
@@ -2104,7 +2251,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2104
2251
  possible_module_names = ["language_model", "text_model", "decoder"]
2105
2252
  for name in possible_module_names:
2106
2253
  if hasattr(self, name):
2107
- print(name)
2108
2254
  setattr(self, name, decoder)
2109
2255
  return
2110
2256
 
@@ -2134,14 +2280,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2134
2280
  if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
2135
2281
  if getattr(module, "weight", None) is not None:
2136
2282
  init.normal_(module.weight, mean=0.0, std=std)
2137
- if getattr(module, "bias", None) is not None:
2283
+ if module.bias is not None:
2138
2284
  init.zeros_(module.bias)
2139
2285
  elif isinstance(module, nn.Embedding):
2140
- if getattr(module, "weight", None) is not None:
2141
- init.normal_(module.weight, mean=0.0, std=std)
2142
- # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
2143
- if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
2144
- init.zeros_(module.weight[module.padding_idx])
2286
+ init.normal_(module.weight, mean=0.0, std=std)
2287
+ # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
2288
+ if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
2289
+ init.zeros_(module.weight[module.padding_idx])
2145
2290
  elif isinstance(module, nn.MultiheadAttention):
2146
2291
  # This uses torch's original init
2147
2292
  module._reset_parameters()
@@ -2153,10 +2298,25 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2153
2298
  or "RMSNorm" in module.__class__.__name__
2154
2299
  ):
2155
2300
  # Norms can exist without weights (in which case they are None from torch primitives)
2156
- if hasattr(module, "weight") and module.weight is not None:
2301
+ if getattr(module, "weight", None) is not None:
2157
2302
  init.ones_(module.weight)
2158
- if hasattr(module, "bias") and module.bias is not None:
2303
+ if getattr(module, "bias", None) is not None:
2159
2304
  init.zeros_(module.bias)
2305
+ # And the potential buffers for the BatchNorms
2306
+ if getattr(module, "running_mean", None) is not None:
2307
+ init.zeros_(module.running_mean)
2308
+ init.ones_(module.running_var)
2309
+ init.zeros_(module.num_batches_tracked)
2310
+ # This matches all the usual RotaryEmbeddings modules
2311
+ elif "RotaryEmbedding" in module.__class__.__name__ and hasattr(module, "original_inv_freq"):
2312
+ rope_fn = (
2313
+ ROPE_INIT_FUNCTIONS[module.rope_type]
2314
+ if module.rope_type != "default"
2315
+ else module.compute_default_rope_parameters
2316
+ )
2317
+ buffer_value, _ = rope_fn(module.config)
2318
+ init.copy_(module.inv_freq, buffer_value)
2319
+ init.copy_(module.original_inv_freq, buffer_value)
2160
2320
 
2161
2321
  def _initialize_weights(self, module):
2162
2322
  """
@@ -2261,7 +2421,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2261
2421
 
2262
2422
  tied_mapping = self._tied_weights_keys
2263
2423
  # If the config does not specify any tying, return empty dict
2264
- if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder:
2424
+ if not self.config.tie_word_embeddings:
2265
2425
  return {}
2266
2426
  # If None, return empty dict
2267
2427
  elif tied_mapping is None:
@@ -2327,30 +2487,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2327
2487
 
2328
2488
  tied_keys = list(tied_keys.items())
2329
2489
  for i, (target_param_name, source_param_name) in enumerate(tied_keys):
2330
- # Usually we tie a single target to a single source, but when both are missing we may later tie
2331
- # both the source and target to a third "backup" parameter that is present in the checkpoint, so we use
2332
- # a list here
2333
- target_param_names = [target_param_name]
2334
-
2335
2490
  # This is `from_pretrained` -> let's check symmetrically in case the source key is not present
2336
2491
  if missing_keys is not None:
2337
2492
  remove_from_missing = True
2338
2493
  source_is_there = source_param_name not in missing_keys
2339
2494
  target_is_there = target_param_name not in missing_keys
2340
2495
  # Both are already present -> it means the config is wrong and do not reflect the actual
2341
- # checkpoint -> let's raise a warning and do nothing
2496
+ # checkpoint -> let's raise a warning and NOT tie them
2342
2497
  if source_is_there and target_is_there:
2343
2498
  logger.warning(
2344
2499
  f"The tied weights mapping and config for this model specifies to tie {source_param_name} to "
2345
2500
  f"{target_param_name}, but both are present in the checkpoints, so we will NOT tie them. "
2346
2501
  "You should update the config with `tie_word_embeddings=False` to silence this warning"
2347
2502
  )
2503
+ # Remove from internal attribute to correctly reflect actual tied weights
2504
+ self.all_tied_weights_keys.pop(target_param_name)
2348
2505
  # Skip to next iteration
2349
2506
  continue
2350
2507
  # We're missing the source but we have the target -> we swap them, tying the parameter that exists
2351
2508
  elif not source_is_there and target_is_there:
2352
2509
  target_param_name, source_param_name = source_param_name, target_param_name
2353
- target_param_names = [target_param_name]
2354
2510
  # Both are missing -> check other keys in case more than 2 keys are tied to the same weight
2355
2511
  elif not source_is_there and not target_is_there:
2356
2512
  for target_backup, source_backup in tied_keys[i + 1 :]:
@@ -2359,10 +2515,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2359
2515
  if source_backup == source_param_name:
2360
2516
  target_backup_is_there = target_backup not in missing_keys
2361
2517
  # If the target is present, we found the correct weight to tie into (we know the source is missing)
2518
+ # Note here that we do not tie the missing source right now as well, as it will be done anyway when
2519
+ # the pair (target_backup, source_backup) becomes the main pair (target_param_name, source_param_name)
2362
2520
  if target_backup_is_there:
2363
2521
  source_param_name = target_backup
2364
- # Append the source as well, since both are missing we'll tie both
2365
- target_param_names.append(source_param_name)
2366
2522
  break
2367
2523
  # If we did not break from the loop, it was impossible to find a source key -> let's raise
2368
2524
  else:
@@ -2378,19 +2534,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2378
2534
 
2379
2535
  # Perform the actual tying
2380
2536
  source_param = self.get_parameter_or_buffer(source_param_name)
2381
- for target_param_name in target_param_names:
2382
- if "." in target_param_name:
2383
- parent_name, name = target_param_name.rsplit(".", 1)
2384
- parent = self.get_submodule(parent_name)
2385
- else:
2386
- name = target_param_name
2387
- parent = self
2388
- # Tie the weights
2389
- setattr(parent, name, source_param)
2390
- self._adjust_bias(parent, source_param)
2391
- # Remove from missing if necesary
2392
- if missing_keys is not None and remove_from_missing:
2393
- missing_keys.discard(target_param_name)
2537
+ if "." in target_param_name:
2538
+ parent_name, name = target_param_name.rsplit(".", 1)
2539
+ parent = self.get_submodule(parent_name)
2540
+ else:
2541
+ name = target_param_name
2542
+ parent = self
2543
+ # Tie the weights
2544
+ setattr(parent, name, source_param)
2545
+ self._adjust_bias(parent, source_param)
2546
+ # Remove from missing if necessary
2547
+ if missing_keys is not None and remove_from_missing:
2548
+ missing_keys.discard(target_param_name)
2394
2549
 
2395
2550
  def _adjust_bias(self, output_embeddings, input_embeddings):
2396
2551
  if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"):
@@ -2903,7 +3058,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2903
3058
  Maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
2904
3059
  initialization logic in `_init_weights`.
2905
3060
  """
2906
- if _init_weights:
3061
+ # If we are initializing on meta device, there is no point in trying to run inits
3062
+ if get_torch_context_manager_or_global_device() != torch.device("meta"):
2907
3063
  # Initialize weights
2908
3064
  self.initialize_weights()
2909
3065
  # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
@@ -2941,7 +3097,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
2941
3097
  "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
2942
3098
  )
2943
3099
 
2944
- if getattr(self, "_hf_peft_config_loaded", False):
3100
+ needs_embedding_grads = self.main_input_name == "input_ids"
3101
+ # we also use this to decide whether we have to raise if embeddings are missing (the submodel might not have embeddings at all)
3102
+ enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False)
3103
+ if enable_input_grads:
2945
3104
  # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
2946
3105
  # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
2947
3106
  # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
@@ -3002,10 +3161,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3002
3161
  save_directory: Union[str, os.PathLike],
3003
3162
  is_main_process: bool = True,
3004
3163
  state_dict: Optional[dict] = None,
3005
- save_function: Callable = torch.save,
3006
3164
  push_to_hub: bool = False,
3007
- max_shard_size: Union[int, str] = "5GB",
3008
- safe_serialization: bool = True,
3165
+ max_shard_size: Union[int, str] = "50GB",
3009
3166
  variant: Optional[str] = None,
3010
3167
  token: Optional[Union[str, bool]] = None,
3011
3168
  save_peft_format: bool = True,
@@ -3027,18 +3184,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3027
3184
  The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
3028
3185
  save parts of the model or if special precautions need to be taken when recovering the state dictionary
3029
3186
  of a model (like when using model parallelism).
3030
- save_function (`Callable`):
3031
- The function to use to save the state dictionary. Useful on distributed training like TPUs when one
3032
- need to replace `torch.save` by another method.
3033
3187
  push_to_hub (`bool`, *optional*, defaults to `False`):
3034
3188
  Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
3035
3189
  repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
3036
3190
  namespace).
3037
- max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
3191
+ max_shard_size (`int` or `str`, *optional*, defaults to `"50GB"`):
3038
3192
  The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
3039
3193
  lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
3040
- We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
3041
- without CPU OOM issues.
3042
3194
 
3043
3195
  <Tip warning={true}>
3044
3196
 
@@ -3047,10 +3199,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3047
3199
 
3048
3200
  </Tip>
3049
3201
 
3050
- safe_serialization (`bool`, *optional*, defaults to `True`):
3051
- Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
3052
3202
  variant (`str`, *optional*):
3053
- If specified, weights are saved in the format pytorch_model.<variant>.bin.
3203
+ If specified, weights are saved in the format model.<variant>.safetensors.
3054
3204
  token (`str` or `bool`, *optional*):
3055
3205
  The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
3056
3206
  the token generated when running `hf auth login` (stored in `~/.huggingface`).
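With `safe_serialization` and `save_function` removed, saving always writes safetensors shards (50GB by default); a usage sketch with a hypothetical repo id:

    from transformers import AutoModel

    model = AutoModel.from_pretrained("org/some-model")  # hypothetical repo id
    model.save_pretrained(
        "./my-checkpoint",
        max_shard_size="5GB",   # override the new 50GB default if needed
        variant="fp16",         # written as model.fp16.safetensors (+ index when sharded)
    )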
@@ -3072,9 +3222,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3072
3222
 
3073
3223
  hf_quantizer = getattr(self, "hf_quantizer", None)
3074
3224
  quantization_serializable = (
3075
- hf_quantizer is not None
3076
- and isinstance(hf_quantizer, HfQuantizer)
3077
- and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
3225
+ hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable()
3078
3226
  )
3079
3227
 
3080
3228
  if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
@@ -3110,7 +3258,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3110
3258
 
3111
3259
  metadata = {}
3112
3260
  if hf_quantizer is not None:
3113
- state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
3261
+ state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self)
3114
3262
  metadata["format"] = "pt"
3115
3263
 
3116
3264
  # Only save the model itself if we are using distributed training
@@ -3163,29 +3311,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
3163
3311
  current_peft_config = self.peft_config[active_adapter]
3164
3312
  current_peft_config.save_pretrained(save_directory)
3165
3313
 
3166
- # for offloaded modules
3167
- module_map = {}
3168
-
3169
- # Save the model
3314
+ # Get the model state_dict
3170
3315
  if state_dict is None:
3171
- # if any model parameters are offloaded, make module map
3172
- if (
3173
- hasattr(self, "hf_device_map")
3174
- and len(set(self.hf_device_map.values())) > 1
3175
- and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
3176
- ):
3177
- warnings.warn(
3178
- "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
3179
- )
3180
- for name, module in model_to_save.named_modules():
3181
- if name == "":
3182
- continue
3183
- module_state_dict = module.state_dict()
3184
-
3185
- for key in module_state_dict:
3186
- module_map[name + f".{key}"] = module
3187
3316
  state_dict = model_to_save.state_dict()
3188
3317
 
3318
+ # if any model parameters are offloaded, we need to know it for later
3319
+ is_offloaded = False
3320
+ if (
3321
+ hasattr(self, "hf_device_map")
3322
+ and len(set(self.hf_device_map.values())) > 1
3323
+ and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
3324
+ ):
3325
+ is_offloaded = True
3326
+ warnings.warn(
3327
+ "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory "
3328
+ "exceeds the `shard_size` (50GB default)"
3329
+ )
3330
+
3189
3331
  # Translate state_dict from smp to hf if saving with smp >= 1.10
3190
3332
  if IS_SAGEMAKER_MP_POST_1_10:
3191
3333
  for smp_to_hf, _ in smp.state.module_manager.translate_functions:
@@ -3202,86 +3344,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  if self._tp_size is not None:
  state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)

- if safe_serialization:
- # TODO: fix safe_serialization for tied weights
- # Safetensors does not allow tensor aliasing.
- # We're going to remove aliases before saving
- ptrs = collections.defaultdict(list)
- for name, tensor in state_dict.items():
- if not isinstance(tensor, torch.Tensor):
- # Sometimes in the state_dict we have non-tensor objects.
- # e.g. in bitsandbytes we have some `str` objects in the state_dict
- # In the non-tensor case, fall back to the pointer of the object itself
- ptrs[id(tensor)].append(name)
-
- elif tensor.device.type == "meta":
- # In offloaded cases, there may be meta tensors in the state_dict.
- # For these cases, key by the pointer of the original tensor object
- # (state_dict tensors are detached and therefore no longer shared)
- tensor = self.get_parameter(name)
- ptrs[id(tensor)].append(name)
-
- else:
- ptrs[id_tensor_storage(tensor)].append(name)
-
- shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-
- # Recursively descend to find tied weight keys
- _tied_weights_keys = set(_get_tied_weight_keys(self))
- error_names = []
- to_delete_names = set()
- for names in shared_ptrs.values():
- # Removing the keys which are declared as known duplicates on
- # load. This allows to make sure the name which is kept is consistent.
- if _tied_weights_keys is not None:
- found = 0
- for name in sorted(names):
- matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
- if matches_pattern and name in state_dict:
- found += 1
- if found < len(names):
- to_delete_names.add(name)
- # We are entering a place where the weights and the transformers configuration do NOT match.
- shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
- # Those are actually tensor sharing but disjoint from each other, we can safely clone them
- # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
- for name in disjoint_names:
- state_dict[name] = state_dict[name].clone()
-
- # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
- # If the link between tensors was done at runtime then `from_pretrained` will not get
- # the key back leading to random tensor. A proper warning will be shown
- # during reload (if applicable), but since the file is not necessarily compatible with
- # the config, better show a proper warning.
- shared_names, identical_names = _find_identical(shared_names, state_dict)
- # delete tensors that have identical storage
- for inames in identical_names:
- known = inames.intersection(to_delete_names)
- for name in known:
- del state_dict[name]
- unknown = inames.difference(to_delete_names)
- if len(unknown) > 1:
- error_names.append(unknown)
-
- if shared_names:
- error_names.extend(shared_names)
-
- if len(error_names) > 0:
- raise RuntimeError(
- f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
- "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
- )
+ # Remove tied weights as safetensors do not handle them
+ state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)

  # Revert all renaming and/or weight operations
  if save_original_format:
- state_dict = revert_weight_conversion(self, state_dict)
+ state_dict = revert_weight_conversion(model_to_save, state_dict)

  # Shard the model if it is too big.
  if not _hf_peft_config_loaded:
- weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+ weights_name = SAFE_WEIGHTS_NAME
  weights_name = _add_variant(weights_name, variant)
  else:
- weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
+ weights_name = ADAPTER_SAFE_WEIGHTS_NAME

  filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
  state_dict_split = split_torch_state_dict_into_shards(
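The inline alias-detection logic above is collapsed into a single `remove_tied_weights_from_state_dict` call: safetensors refuses to serialize aliased tensors, so tied weights must be dropped before writing and re-tied at load time. As a rough illustration only (not the helper's actual implementation), shared storage in a state dict can be detected like this:

```python
import torch
from collections import defaultdict

def find_shared_tensor_names(state_dict: dict[str, torch.Tensor]) -> list[list[str]]:
    """Group parameter names that alias the same storage (a sketch, not the transformers helper)."""
    groups = defaultdict(list)
    for name, tensor in state_dict.items():
        if isinstance(tensor, torch.Tensor) and tensor.device.type != "meta":
            # Tensors sharing storage report the same (device, storage data pointer) pair
            key = (str(tensor.device), tensor.untyped_storage().data_ptr())
            groups[key].append(name)
    return [names for names in groups.values() if len(names) > 1]

# Tied embedding / LM head share one storage, so only one copy should end up in the file
emb = torch.nn.Embedding(10, 4)
lm_head = torch.nn.Linear(4, 10, bias=False)
lm_head.weight = emb.weight  # tie the weights
sd = {"embed.weight": emb.weight, "lm_head.weight": lm_head.weight}
print(find_shared_tensor_names(sd))  # [['embed.weight', 'lm_head.weight']]
```

Only one name per group needs to be kept in the file; the others are restored by weight tying when the model is reloaded.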
@@ -3314,57 +3389,45 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  and reg.fullmatch(filename_no_suffix) is not None
  ):
  os.remove(full_filename)
+
  # Save the model
- filename_to_tensors = state_dict_split.filename_to_tensors.items()
- if module_map:
- filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
- for shard_file, tensors in filename_to_tensors:
- shard = {}
- for tensor in tensors:
- if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
- full_tensor = state_dict[tensor].full_tensor()
+ for shard_file, tensor_names in logging.tqdm(
+ state_dict_split.filename_to_tensors.items(), desc="Writing model shards"
+ ):
+ filename = os.path.join(save_directory, shard_file)
+ shard_state_dict = {}
+ for tensor_name in tensor_names:
+ # Get the tensor, and remove it from state_dict to avoid keeping the ref
+ tensor = state_dict.pop(tensor_name)
+
+ # In case of TP, get the full parameter back
+ if _is_dtensor_available and isinstance(tensor, DTensor):
+ tensor = tensor.full_tensor()
  # to get the correctly ordered tensor we need to repack if packed
- if _get_parameter_tp_plan(tensor, self._tp_plan) == "local_packed_rowwise":
- full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
- shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
- else:
- shard[tensor] = state_dict[tensor].contiguous()
- # delete reference, see https://github.com/huggingface/transformers/pull/34890
- del state_dict[tensor]
-
- # remake shard with onloaded parameters if necessary
- if module_map:
- # init state_dict for this shard
- shard_state_dict = dict.fromkeys(shard, "")
- for module_name in shard:
- # note that get_state_dict_from_offload can update with meta tensors
- # if both a parent module and its descendant are offloaded
- tensor = shard_state_dict[module_name]
- if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
- # update state dict with onloaded parameters
- module = module_map[module_name]
- shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
-
- # assign shard to be the completed state dict
- shard = shard_state_dict
- del shard_state_dict
- gc.collect()
-
- if safe_serialization:
- # At some point we will need to deal better with save_function (used for TPU and other distributed
- # joyfulness), but for now this enough. # TODO: we should def parallelize this we are otherwise just waiting
- # too much before scheduling the next write when its in a different file
- safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
- else:
- save_function(shard, os.path.join(save_directory, shard_file))
+ if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
+ tensor = repack_weights(tensor, -1, self._tp_size, 2)
+
+ # If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
+ # but it would otherwise not be contained in the saved shard if we were to simply move the file
+ # or something
+ if is_offloaded and tensor.device.type == "meta":
+ tensor = load_offloaded_parameter(model_to_save, tensor_name)
+
+ # only do contiguous after it's permuted correctly in case of TP
+ shard_state_dict[tensor_name] = tensor.contiguous()

- del state_dict
+ # TODO: it would be very nice to do the writing concurrently, but safetensors never releases the GIL,
+ # so it's not possible for now....
+ # Write the shard to disk
+ safe_save_file(shard_state_dict, filename, metadata=metadata)
+ # Cleanup the data before next loop (important with offloading, so we don't blowup cpu RAM)
+ del shard_state_dict

  if index is None:
  path_to_weights = os.path.join(save_directory, weights_name)
  logger.info(f"Model weights saved in {path_to_weights}")
  else:
- save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+ save_index_file = SAFE_WEIGHTS_INDEX_NAME
  save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
  # Save the index as well
  with open(save_index_file, "w", encoding="utf-8") as f:
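The rewritten loop streams each shard straight to disk with safetensors, popping tensors from the state dict as it goes so memory is released shard by shard. A minimal standalone sketch of that flow using the public `huggingface_hub` and `safetensors` APIs (a toy `nn.Sequential` stands in for the model, and writing of the index file is omitted):

```python
import os
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

# Toy model standing in for a PreTrainedModel
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 256))
state_dict = model.state_dict()

save_directory = "toy_checkpoint"
os.makedirs(save_directory, exist_ok=True)

# Force sharding in this toy example by using a small max shard size (in bytes)
split = split_torch_state_dict_into_shards(
    state_dict, filename_pattern="model{suffix}.safetensors", max_shard_size=200_000
)
for shard_file, tensor_names in split.filename_to_tensors.items():
    # Pop tensors so the reference is released as soon as each shard is written
    shard = {name: state_dict.pop(name).contiguous() for name in tensor_names}
    save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
```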
@@ -3535,19 +3598,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  return super().float(*args)

  @classmethod
- def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
+ def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
+ # Need to instantiate with correct dtype
+ init_contexts = [local_torch_dtype(dtype, cls.__name__)]
  if is_deepspeed_zero3_enabled():
  import deepspeed

- init_contexts = [no_init_weights()]
  # We cannot initialize the model on meta device with deepspeed when not quantized
  if not is_quantized and not _is_ds_init_called:
  logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
- init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
+ init_contexts.extend(
+ [
+ init.no_init_weights(),
+ deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
+ set_zero3_state(),
+ ]
+ )
  elif is_quantized:
- init_contexts.extend([init_empty_weights(), set_quantized_state()])
+ init_contexts.extend([torch.device("meta"), set_quantized_state()])
  else:
- init_contexts = [no_init_weights(), init_empty_weights()]
+ init_contexts.append(torch.device("meta"))

  return init_contexts
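The init context now combines a default-dtype override with `torch.device("meta")`, so the skeleton is built with the right dtype but without allocating real storage. A small illustration of the idea; `default_dtype` below is only a stand-in for the internal `local_torch_dtype` helper, not its actual code:

```python
import torch
from contextlib import contextmanager

@contextmanager
def default_dtype(dtype: torch.dtype):
    # Rough stand-in for the internal `local_torch_dtype` helper: set, then restore, the default dtype
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)

# Building the module under torch.device("meta") allocates no real storage: parameters only
# carry shape, dtype and device, and are materialized later when the weights are loaded.
with default_dtype(torch.bfloat16), torch.device("meta"):
    skeleton = torch.nn.Linear(4096, 4096)

print(skeleton.weight.device, skeleton.weight.dtype)  # meta torch.bfloat16
```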
 
@@ -3572,7 +3642,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  # This is a context manager to override the default kernel mapping
  # We are calling kernelize inside this context manager using the use_kernels setter
- with use_kernel_mapping(kernel_config.kernel_mapping):
+ # Param inherit_mapping should be False to avoid still loading kernel from remote
+ inherit_mapping = not kernel_config.use_local_kernel
+ with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
  self.use_kernels = True
  # We use the default kernel mapping in .integrations.hub_kernels
  else:
@@ -3581,7 +3653,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  self.use_kernels = False

  @classmethod
- @restore_default_dtype
  def from_pretrained(
  cls: type[SpecificPreTrainedModelType],
  pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
@@ -3690,10 +3761,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  "org/model@main"
  "org/model:custom_kernel"
  "org/model@v1.2.3:custom_kernel"
+ experts_implementation (`str`, *optional*):
+ The experts implementation to use in the model (if relevant). Can be any of:
+
+ - `"eager"` (sequential implementation of the experts matrix multiplications).
+ - `"batched_mm"` (using [`torch.bmm`](https://pytorch.org/docs/stable/generated/torch.bmm.html)).
+ - `"grouped_mm"` (using [`torch._grouped_mm`](https://docs.pytorch.org/docs/main/generated/torch.nn.functional.grouped_mm.html)).
+
+ By default, if available, `grouped_mm` will be used for torch>=2.9.0. The default is otherwise the sequential `"eager"` implementation.

  > Parameters for big model inference

- dtype (`str` or `torch.dtype`, *optional*):
+ dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"`):
  Override the default `torch_dtype` and load the model under a specific `dtype`. The different options
  are:

@@ -3835,6 +3914,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  # For BC on torch_dtype argument
  if torch_dtype is not None:
  dtype = dtype if dtype is not None else torch_dtype
+ if dtype is None:
+ dtype = "auto"

  if is_offline_mode() and not local_files_only:
  local_files_only = True
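With `dtype` falling back to `"auto"` (keep the dtype stored in the checkpoint) and the `experts_implementation` switch documented above, a call might look like the following; the repo id is only a placeholder:

```python
from transformers import AutoModelForCausalLM

# "org/moe-model" is a placeholder repo id, not a real checkpoint
model = AutoModelForCausalLM.from_pretrained(
    "org/moe-model",
    dtype="auto",                         # the new default: keep the dtype stored in the checkpoint
    experts_implementation="grouped_mm",  # or "eager" / "batched_mm"; grouped_mm needs torch >= 2.9.0
    device_map="auto",
)
```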
@@ -3911,8 +3992,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  if "attn_implementation" in kwargs:
  config._attn_implementation = kwargs.pop("attn_implementation")

- hf_quantizer, config, dtype, device_map = get_hf_quantizer(
- config, quantization_config, dtype, device_map, weights_only, user_agent
+ if "experts_implementation" in kwargs:
+ config._experts_implementation = kwargs.pop("experts_implementation")
+
+ hf_quantizer, config, device_map = get_hf_quantizer(
+ config, quantization_config, device_map, weights_only, user_agent
  )

  if gguf_file:
@@ -3959,33 +4043,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  ]

  # Find the correct dtype based on current state
- config, dtype, dtype_orig = _get_dtype(
- cls, dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
+ config, dtype = _get_dtype(
+ dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only, hf_quantizer
  )

  config.name_or_path = pretrained_model_name_or_path
- model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
+ model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called)
  config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
  with ContextManagers(model_init_context):
  # Let's make sure we don't run the init function of buffer modules
  model = cls(config, *model_args, **model_kwargs)

+ if hf_quantizer is not None: # replace module with quantized modules (does not touch weights)
+ hf_quantizer.preprocess_model(
+ model=model,
+ dtype=dtype,
+ device_map=device_map,
+ checkpoint_files=checkpoint_files,
+ use_kernels=use_kernels,
+ )
+
  # Obtain the weight conversion mapping for this model if any are registered
  weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)

- # make sure we use the model's config since the __init__ call might have copied it
- config = model.config
-
- if hf_quantizer is not None: # replace module with quantized modules (does not touch weights)
- hf_quantizer.preprocess_model(
- model=model,
- device_map=device_map,
- keep_in_fp32_modules=model._keep_in_fp32_modules, # TODO prob no longer needed?
- config=config,
- checkpoint_files=checkpoint_files,
- use_kernels=use_kernels,
- )
-
  if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights
  model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)

@@ -3993,10 +4073,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  if device_map is not None:
  device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)

- # restore default dtype
- if dtype_orig is not None:
- torch.set_default_dtype(dtype_orig)
-
  # Finalize model weight initialization
  model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
  model,
@@ -4007,6 +4083,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  sharded_metadata=sharded_metadata,
  device_map=device_map,
  disk_offload_folder=offload_folder,
+ offload_buffers=offload_buffers,
  dtype=dtype,
  hf_quantizer=hf_quantizer,
  device_mesh=device_mesh,
@@ -4014,7 +4091,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  weight_mapping=weight_conversions,
  )

- model.eval() # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval() # Set model in evaluation mode to deactivate Dropout modules by default
  model.set_use_kernels(use_kernels, kernel_config)

  # If it is a model with generation capabilities, attempt to load generation files (generation config,
@@ -4030,16 +4107,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  **kwargs,
  )

- # for device_map="auto" : dispatch model with hooks on all devices if necessary
- if device_map is not None and device_mesh is None:
+ # If the device_map has more than 1 device: dispatch model with hooks on all devices
+ if device_map is not None and len(set(device_map.values())) > 1:
  accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)

  if hf_quantizer is not None:
  model.hf_quantizer = hf_quantizer
- hf_quantizer.postprocess_model(model, config=config) # usually a no-op but sometimes needed
+ hf_quantizer.postprocess_model(
+ model
+ ) # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing

  if _adapter_model_path is not None:
- adapter_kwargs["key_mapping"] = weight_conversions # TODO: Dynamic weight loader for adapters
+ adapter_kwargs["key_mapping"] = key_mapping
  model.load_adapter(
  _adapter_model_path,
  adapter_name=adapter_name,
@@ -4068,6 +4147,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  sharded_metadata: Optional[dict] = None,
  device_map: Optional[dict] = None,
  disk_offload_folder: Optional[str] = None,
+ offload_buffers: bool = False,
  dtype: Optional[torch.dtype] = None,
  hf_quantizer: Optional[HfQuantizer] = None,
  device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
@@ -4082,6 +4162,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  # Model's definition arriving here is final (TP hooks added, quantized layers replaces)
  expected_keys = list(model.state_dict().keys())
+
  if logger.level >= logging.WARNING:
  verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))

@@ -4090,10 +4171,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  # Prepare parameters offloading if needed
  if device_map is not None and "disk" in device_map.values():
  disk_offload_index = accelerate_disk_offload(
+ model,
  disk_offload_folder,
  checkpoint_files,
  device_map,
- expected_keys,
  sharded_metadata,
  dtype,
  weight_mapping,
@@ -4104,7 +4185,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  expanded_device_map = expand_device_map(device_map, expected_keys)
  caching_allocator_warmup(model, expanded_device_map, hf_quantizer)

- tp_plan = getattr(model, "_tp_plan", None)
  error_msgs = []

  if is_deepspeed_zero3_enabled() and not is_quantized:
@@ -4113,9 +4193,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  for ckpt_file in checkpoint_files:
  merged_state_dict.update(load_state_dict(ckpt_file, map_location="cpu", weights_only=weights_only))
  state_dict = merged_state_dict
- error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
+ error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict)
  # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
- missing_keys, unexpected_keys, mismatched_keys, misc = set(), set(), set(), set()
+ unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set()
  else:
  all_pointer = set()
  # Checkpoints are safetensors
@@ -4137,19 +4217,20 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  else:
  raise ValueError("Neither a state dict nor checkpoint files were found.")

- missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, misc = (
+ missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
  convert_and_load_state_dict_in_model(
- model,
- merged_state_dict,
- weight_mapping,
- tp_plan,
- hf_quantizer,
- dtype,
- device_map,
- model.dtype_plan,
- device_mesh,
- disk_offload_index,
- disk_offload_folder,
+ model=model,
+ state_dict=merged_state_dict,
+ weight_mapping=weight_mapping,
+ tp_plan=model._tp_plan,
+ hf_quantizer=hf_quantizer,
+ dtype=dtype,
+ device_map=device_map,
+ dtype_plan=model.dtype_plan,
+ device_mesh=device_mesh,
+ disk_offload_index=disk_offload_index,
+ disk_offload_folder=disk_offload_folder,
+ offload_buffers=offload_buffers,
  )
  )

@@ -4160,12 +4241,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
  model.mark_tied_weights_as_initialized()

- # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
- # loading the weights as they are not in the loaded state dict)
- miss_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
- model._move_missing_keys_from_meta_to_cpu(miss_and_mismatched, dtype, hf_quantizer)
+ # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
+ # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
+ missing_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
+ model._move_missing_keys_from_meta_to_device(missing_and_mismatched, device_map, device_mesh, hf_quantizer)

- # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialzed` flag)
+ # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
  model._initialize_missing_keys(is_quantized)

  # Tie the weights
@@ -4174,34 +4255,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  # Adjust missing and unexpected keys
  missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys)

- # Post-processing for tensor parallelism
- if device_mesh is not None:
- # When using TP, the device map is a single device for all parameters
- tp_device = list(device_map.values())[0]
- # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
- # not part of the state_dict (persistent=False)
- for buffer in model.buffers(): # TODO to avaoid this buffer could be added to the ckpt
- if buffer.device != tp_device:
- buffer.data = buffer.to(tp_device)
-
- # In this case, the top-most task module weights were not moved to device and parallelized as they
- # were not part of the loaded weights: do it now
- if missing_keys:
- state_dict = model.state_dict()
- for name in missing_keys:
- param = state_dict[name]
- # Shard the param
- shard_and_distribute_module(
- model,
- param.to(tp_device),
- param,
- name,
- None,
- False,
- device_mesh.get_local_rank(),
- device_mesh,
- )
-
  log_state_dict_report(
  model=model,
  pretrained_model_name_or_path=pretrained_model_name_or_path,
@@ -4211,7 +4264,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  missing_keys=missing_keys,
  mismatched_keys=mismatched_keys,
  mismatched_shapes=mismatched_keys,
- misc=misc,
+ conversion_errors=conversion_errors,
  ignore_mismatched_sizes=ignore_mismatched_sizes,
  )

@@ -4399,33 +4452,54 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  def is_backend_compatible(cls):
  return cls._supports_attention_backend

- def _move_missing_keys_from_meta_to_cpu(
- self, missing_keys: list[str], dtype: torch.dtype, hf_quantizer: Optional[HfQuantizer]
+ def _move_missing_keys_from_meta_to_device(
+ self,
+ missing_keys: list[str],
+ device_map: dict | None,
+ device_mesh: "torch.distributed.device_mesh.DeviceMesh | None",
+ hf_quantizer: HfQuantizer | None,
  ) -> None:
- """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts) back
- from meta device to cpu.
+ """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts)
+ back from meta device to their device according to the `device_map` if any, else cpu. Takes care of sharding those
+ missing parameters if `device_mesh` is provided, i.e. we are using TP.
+ All non-persistent buffers are also moved back to the correct device (they are not part of the state_dict, but are
+ not missing either).
  """
  is_quantized = hf_quantizer is not None
+ # This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
+ if is_deepspeed_zero3_enabled() and not is_quantized:
+ return

  # In this case we need to move everything back
  if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
- # We only do it for the parameters, as the buffers are not initialized on the meta device by default
  for key, param in self.named_parameters():
- value = torch.empty_like(param, dtype=dtype, device="cpu")
+ value = torch.empty_like(param, device="cpu")
+ _load_parameter_into_model(self, key, value)
+ for key, buffer in self.named_buffers():
+ value = torch.empty_like(buffer, device="cpu")
  _load_parameter_into_model(self, key, value)
  return

- model_state_dict = self.state_dict()
  # The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway)
  # This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they
  # will be re-initialized for nothing (which can be quite long)
  for key in missing_keys - self.all_tied_weights_keys.keys():
- param = model_state_dict[key]
- # Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
- if param.device == torch.device("meta"):
- value = torch.empty_like(param, dtype=dtype, device="cpu")
- if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
- _load_parameter_into_model(self, key, value)
+ param = self.get_parameter_or_buffer(key)
+ param_device = get_device(device_map, key, valid_torch_device=True)
+ value = torch.empty_like(param, device=param_device)
+ # For TP, we may need to shard the param
+ if device_mesh is not None:
+ shard_and_distribute_module(
+ self, value, param, key, None, False, device_mesh.get_local_rank(), device_mesh
+ )
+ # Otherwise, just move it to device
+ else:
+ _load_parameter_into_model(self, key, value)
+ # We need to move back non-persistent buffers as well, as they are not part of loaded weights anyway
+ for key, buffer in self.named_non_persistent_buffers():
+ buffer_device = get_device(device_map, key, valid_torch_device=True)
+ value = torch.empty_like(buffer, device=buffer_device)
+ _load_parameter_into_model(self, key, value)

  def _initialize_missing_keys(self, is_quantized: bool) -> None:
  """
@@ -4453,8 +4527,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
  ) -> tuple[set[str], set[str]]:
  """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
  raising unneeded warnings/errors.
- Also, set the `_is_hf_initialized` on tied weight keys, to avoid initializing them as they are going to
- be tied anyway.
  """
  # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
  # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
@@ -4513,6 +4585,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

  raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")

+ def named_non_persistent_buffers(
+ self, recurse: bool = True, remove_duplicate: bool = True
+ ) -> Iterator[tuple[str, torch.Tensor]]:
+ """Similar to `named_buffers`, but only yield non-persistent ones. It is handy as it's not perfectly straightforward
+ to know if they are persistent or not"""
+ for name, tensor in self.named_buffers(recurse=recurse, remove_duplicate=remove_duplicate):
+ # We have to grab the parent here, as the attribute `_non_persistent_buffers_set` is on the immediate
+ # parent only
+ parent, buf_name = name.rsplit(".", 1) if "." in name else ("", name)
+ parent = self.get_submodule(parent)
+ if buf_name in parent._non_persistent_buffers_set:
+ yield name, tensor
+
  def train(self, mode: bool = True):
  out = super().train(mode)
  if self.use_kernels:
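`named_non_persistent_buffers` relies on the fact that PyTorch records non-persistent buffers only in the owning module's private `_non_persistent_buffers_set` and excludes them from the state dict. A quick standalone check of that mechanism:

```python
import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: serialized in the state_dict
        self.register_buffer("running_mean", torch.zeros(4))
        # Non-persistent buffer: excluded from the state_dict (e.g. a rotary inv_freq table)
        self.register_buffer("inv_freq", torch.ones(4), persistent=False)

block = Block()
print("inv_freq" in block.state_dict())                     # False
print("inv_freq" in block._non_persistent_buffers_set)      # True
print("running_mean" in block._non_persistent_buffers_set)  # False
```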
@@ -4565,6 +4650,40 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
  return torch.device(device).type not in ["meta", "cpu"]


+ def get_total_byte_count(
+ model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: Optional[HfQuantizer] = None
+ ):
+ """
+ This utility function calculates the total bytes count needed to load the model on each device.
+ This is useful for caching_allocator_warmup as we want to know how much cache we need to pre-allocate.
+ """
+
+ total_byte_count = defaultdict(lambda: 0)
+ tied_param_names = model.all_tied_weights_keys.keys()
+ tp_plan = model._tp_plan if torch.distributed.is_available() and torch.distributed.is_initialized() else []
+
+ for param_name, device in accelerator_device_map.items():
+ # Skip if the parameter has already been accounted for (tied weights)
+ if param_name in tied_param_names:
+ continue
+
+ param = model.get_parameter_or_buffer(param_name)
+
+ if hf_quantizer is not None:
+ dtype_size = hf_quantizer.param_element_size(model, param_name, param)
+ else:
+ dtype_size = param.element_size()
+
+ param_byte_count = param.numel() * dtype_size
+
+ if len(tp_plan) > 0:
+ is_part_of_plan = _get_parameter_tp_plan(param_name, tp_plan, is_weight=True) is not None
+ param_byte_count //= torch.distributed.get_world_size() if is_part_of_plan else 1
+
+ total_byte_count[device] += param_byte_count
+ return total_byte_count
+
+
  def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: Optional[HfQuantizer]):
  """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
  device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
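The per-device byte count used by the warmup is essentially `numel() * element_size()` summed over the parameters assigned to each device. A toy version of the same bookkeeping, without the tied-weight, quantizer, and tensor-parallel adjustments handled by `get_total_byte_count`:

```python
import torch
from collections import defaultdict

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 10))
# Hypothetical device assignment, in the spirit of an expanded device_map
device_map = {"0.weight": "cuda:0", "0.bias": "cuda:0", "1.weight": "cuda:1", "1.bias": "cuda:1"}

total_byte_count = defaultdict(int)
params = dict(model.named_parameters())
for name, device in device_map.items():
    p = params[name]
    total_byte_count[device] += p.numel() * p.element_size()

print(dict(total_byte_count))  # {'cuda:0': 4198400, 'cuda:1': 41000} with float32 parameters
```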
@@ -4584,8 +4703,6 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
  - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
  However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
  """
- factor = 2 if hf_quantizer is None else hf_quantizer.get_accelerator_warm_up_factor()
-
  # Remove disk, cpu and meta devices, and cast to proper torch.device
  accelerator_device_map = {
  param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
@@ -4593,40 +4710,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
  if not accelerator_device_map:
  return

- tp_plan = getattr(model, "_tp_plan", []) or []
- tp_plan_regex = (
- re.compile("|".join([re.escape(plan) for plan in tp_plan]))
- if _torch_distributed_available and torch.distributed.is_initialized()
- else None
- )
- total_byte_count = defaultdict(lambda: 0)
- tied_param_names = model.all_tied_weights_keys.keys()
- for param_name, device in accelerator_device_map.items():
- # Skip if the parameter has already been accounted for (tied weights)
- if param_name in tied_param_names:
- continue
-
- # For example in the case of MXFP4 quantization, we need to update the param name to the original param name
- # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
- if hf_quantizer is not None:
- param_name = hf_quantizer.get_param_name(param_name)
-
- try:
- param = model.get_parameter_or_buffer(param_name)
- except AttributeError:
- # TODO: for now let's skip if we can't find the parameters
- if hf_quantizer is not None:
- continue
- raise AttributeError(f"Parameter {param_name} not found in model")
-
- # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
- param_byte_count = param.numel() * param.element_size()
-
- if tp_plan_regex is not None:
- generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
- param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
-
- total_byte_count[device] += param_byte_count
+ total_byte_count = get_total_byte_count(model, accelerator_device_map, hf_quantizer)

  # This will kick off the caching allocator to avoid having to Malloc afterwards
  for device, byte_count in total_byte_count.items():
@@ -4646,9 +4730,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
  unused_memory = torch_accelerator_module.memory_reserved(
  index
  ) - torch_accelerator_module.memory_allocated(index)
- byte_count = max(0, byte_count - unused_memory)
- # Allocate memory
- _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
+ byte_count = int(max(0, byte_count - unused_memory))
+ # We divide by 2 here as we allocate in fp16
+ _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)


  class AttentionInterface(GeneralInterface):
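The warmup allocates `byte_count // 2` fp16 elements (2 bytes each), which reserves roughly `byte_count` bytes from the caching allocator in one request; once the temporary tensor is dropped, the memory stays reserved and is reused when the real weights are copied in. A sketch of the idea, guarded so it only runs when CUDA is available:

```python
import torch

def warmup_caching_allocator(byte_count: int, device: str = "cuda:0") -> None:
    # fp16 elements are 2 bytes each, so byte_count // 2 elements reserve roughly byte_count bytes
    tmp = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
    del tmp  # the block stays reserved in the caching allocator and is reused for the real weights

if torch.cuda.is_available():
    before = torch.cuda.memory_reserved(0)
    warmup_caching_allocator(2 * 1024**3)  # pre-reserve ~2 GiB
    print(f"newly reserved: {torch.cuda.memory_reserved(0) - before} bytes")
```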