transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/modeling_utils.py
CHANGED
@@ -16,7 +16,6 @@
 import collections
 import copy
 import functools
-import gc
 import importlib.metadata
 import inspect
 import json
@@ -26,7 +25,7 @@ import sys
 import warnings
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial, wraps
@@ -36,7 +35,7 @@ from typing import Optional, TypeVar, Union, get_type_hints
 from zipfile import is_zipfile

 import torch
-from huggingface_hub import create_repo, split_torch_state_dict_into_shards
+from huggingface_hub import create_repo, is_offline_mode, split_torch_state_dict_into_shards
 from packaging import version
 from safetensors import safe_open
 from safetensors.torch import save_file as safe_save_file
@@ -63,7 +62,8 @@ from .integrations.accelerate import (
     accelerate_dispatch,
     check_and_set_device_map,
     expand_device_map,
-
+    get_device,
+    load_offloaded_parameter,
 )
 from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.eager_paged import eager_paged_attention_forward
@@ -85,7 +85,8 @@ from .integrations.tensor_parallel import (
     verify_tp_plan,
 )
 from .loss.loss_utils import LOSS_MAPPING
-from .modeling_flash_attention_utils import lazy_import_flash_attention
+from .modeling_flash_attention_utils import lazy_import_flash_attention, lazy_import_paged_flash_attention
+from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from .pytorch_utils import id_tensor_storage
 from .quantizers import HfQuantizer
 from .quantizers.auto import get_hf_quantizer
@@ -93,7 +94,6 @@ from .quantizers.quantizers_utils import get_module_from_name
 from .safetensors_conversion import auto_conversion
 from .utils import (
     ADAPTER_SAFE_WEIGHTS_NAME,
-    ADAPTER_WEIGHTS_NAME,
     DUMMY_INPUTS,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -109,8 +109,8 @@ from .utils import (
     is_accelerate_available,
     is_flash_attn_2_available,
     is_flash_attn_3_available,
+    is_grouped_mm_available,
     is_kernels_available,
-    is_offline_mode,
     is_torch_flex_attn_available,
     is_torch_greater_or_equal,
     is_torch_mlu_available,
@@ -132,7 +132,6 @@ from .utils.quantization_config import QuantizationMethod
 if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import extract_model_from_parallel
-    from accelerate.utils.modeling import get_state_dict_from_offload


 _torch_distributed_available = torch.distributed.is_available()
@@ -154,10 +153,15 @@ logger = logging.get_logger(__name__)
 XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
 XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
 SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
-_init_weights = True
 _is_quantized = False
 _is_ds_init_called = False

+# Mapping from flash attention implementations to their kernel fallback repositories
+FLASH_ATTN_KERNEL_FALLBACK = {
+    "flash_attention_2": "kernels-community/flash-attn2",
+    "flash_attention_3": "kernels-community/vllm-flash-attn3",
+}
+

 def is_local_dist_rank_0():
     return (
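The new `FLASH_ATTN_KERNEL_FALLBACK` table gives the loading path a Hub repository to fall back to when the requested flash-attention implementation is not installed locally (the file list above shows `transformers/integrations/hub_kernels.py` changing alongside it, which is presumably where the kernels are actually fetched). A minimal sketch of how such a table could be consulted; `resolve_attention_backend` is a hypothetical helper for illustration, not a transformers API:

FLASH_ATTN_KERNEL_FALLBACK = {
    "flash_attention_2": "kernels-community/flash-attn2",
    "flash_attention_3": "kernels-community/vllm-flash-attn3",
}

def resolve_attention_backend(requested: str, flash_attn_installed: bool) -> str:
    # Keep the request as-is when the native package is available or no fallback exists;
    # otherwise return the Hub repository hosting a pre-built kernel.
    if flash_attn_installed or requested not in FLASH_ATTN_KERNEL_FALLBACK:
        return requested
    return FLASH_ATTN_KERNEL_FALLBACK[requested]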
@@ -167,51 +171,6 @@ def is_local_dist_rank_0():
     )


-TORCH_INIT_FUNCTIONS = {
-    "uniform_": nn.init.uniform_,
-    "normal_": nn.init.normal_,
-    "trunc_normal_": nn.init.trunc_normal_,
-    "constant_": nn.init.constant_,
-    "xavier_uniform_": nn.init.xavier_uniform_,
-    "xavier_normal_": nn.init.xavier_normal_,
-    "kaiming_uniform_": nn.init.kaiming_uniform_,
-    "kaiming_normal_": nn.init.kaiming_normal_,
-    "uniform": nn.init.uniform,
-    "normal": nn.init.normal,
-    "xavier_uniform": nn.init.xavier_uniform,
-    "xavier_normal": nn.init.xavier_normal,
-    "kaiming_uniform": nn.init.kaiming_uniform,
-    "kaiming_normal": nn.init.kaiming_normal,
-    "orthogonal_": nn.init.orthogonal_,
-}
-
-
-@contextmanager
-def no_init_weights():
-    """
-    Context manager to globally disable weight initialization to speed up loading large models.
-    """
-    global _init_weights
-    old_init_weights = _init_weights
-
-    _init_weights = False
-
-    def _skip_init(*args, **kwargs):
-        pass
-
-    # Save the original initialization functions
-    for name, init_func in TORCH_INIT_FUNCTIONS.items():
-        setattr(torch.nn.init, name, _skip_init)
-
-    try:
-        yield
-    finally:
-        _init_weights = old_init_weights
-        # Restore the original initialization functions
-        for name, init_func in TORCH_INIT_FUNCTIONS.items():
-            setattr(torch.nn.init, name, init_func)
-
-
 @contextmanager
 def set_quantized_state():
     global _is_quantized
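For context on the deletion above: `no_init_weights` sped up model construction in `from_pretrained` by swapping every function in `TORCH_INIT_FUNCTIONS` for a no-op, since random initialization is wasted work when checkpoint weights overwrite it immediately (note the new `transformers/initialization.py` module appearing in the file list for this release). A condensed, standalone sketch of that skip-init pattern, assuming only `torch`:

from contextlib import contextmanager

import torch
import torch.nn as nn

@contextmanager
def skip_torch_init(names=("uniform_", "normal_", "kaiming_uniform_")):
    # Temporarily replace selected torch.nn.init functions with no-ops, then restore them.
    saved = {name: getattr(torch.nn.init, name) for name in names}
    try:
        for name in names:
            setattr(torch.nn.init, name, lambda *args, **kwargs: None)
        yield
    finally:
        for name, fn in saved.items():
            setattr(torch.nn.init, name, fn)

with skip_torch_init():
    layer = nn.Linear(4096, 4096)  # parameters are allocated but never re-randomized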
@@ -235,23 +194,28 @@ def set_zero3_state():
     _is_ds_init_called = False
 
 
-def restore_default_dtype(func):
+@contextmanager
+def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
     """
-    Decorator to restore the default torch dtype
-    at the end of the function. Serves
-    as a backup in case calling the function raises
-    an error after the function has changed the default dtype but before it could restore it.
+    Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
+    If `model_class_name` is provided, it's used to provide a more helpful error message if `dtype` is not valid.
     """
+    # Just a more helpful error before we call `torch.set_default_dtype` later on, which would crash in this case
+    if not dtype.is_floating_point:
+        if model_class_name is not None:
+            error_message = (
+                f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
+            )
+        else:
+            error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
+        raise ValueError(error_message)
 
-    @functools.wraps(func)
-    def _wrapper(*args, **kwargs):
-        old_dtype = torch.get_default_dtype()
-        try:
-            return func(*args, **kwargs)
-        finally:
-            torch.set_default_dtype(old_dtype)
-
-    return _wrapper
+    original_dtype = torch.get_default_dtype()
+    try:
+        torch.set_default_dtype(dtype)
+        yield
+    finally:
+        torch.set_default_dtype(original_dtype)
 
 
 def get_torch_context_manager_or_global_device():
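The hunk above swaps the old decorator-based dtype restoration for a reentrant context manager. As a minimal standalone sketch of the same pattern (the helper name `local_dtype` is illustrative, not the library code itself):

    import torch
    from contextlib import contextmanager

    @contextmanager
    def local_dtype(dtype: torch.dtype):
        # mirror of the guard in the hunk: only floating dtypes are valid torch defaults
        if not dtype.is_floating_point:
            raise ValueError(f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype")
        original = torch.get_default_dtype()
        try:
            torch.set_default_dtype(dtype)
            yield
        finally:
            torch.set_default_dtype(original)

    with local_dtype(torch.bfloat16):
        layer = torch.nn.Linear(4, 4)  # parameters created in bf16
    assert torch.get_default_dtype() == torch.float32  # restored even if the block raised

The try/finally guarantees restoration on exceptions raised inside the block, which the removed decorator could only provide around whole function calls.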
@@ -279,7 +243,9 @@ def get_state_dict_dtype(state_dict):
             return t.dtype
 
     # if no floating dtype was found return whatever the first dtype is
-    return next(iter(state_dict.values())).dtype
+    if len(state_dict) == 0:
+        return torch.float32
+    return next(iter(state_dict.values())).dtype
 
 
 str_to_torch_dtype = {
@@ -405,11 +371,94 @@ def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]):
     return shared_tensors, identical
 
 
+def remove_tied_weights_from_state_dict(
+    state_dict: dict[str, torch.Tensor], model: "PreTrainedModel"
+) -> dict[str, torch.Tensor]:
+    """
+    Remove all tied weights from the given `state_dict`, making sure to keep only the main weight that `model`
+    will expect when reloading (even if we know tie weights symmetrically, it's better to keep the intended one).
+    This is because `safetensors` does not allow tensor aliasing - so we're going to remove aliases before saving.
+    """
+    # To avoid any potential mistakes and mismatches between config and actual tied weights, here we check the pointers
+    # of the Tensors themselves -> we are guaranteed to find all the actual tied weights
+    ptrs = collections.defaultdict(list)
+    for name, tensor in state_dict.items():
+        if not isinstance(tensor, torch.Tensor):
+            # Sometimes in the state_dict we have non-tensor objects.
+            # e.g. in bitsandbytes we have some `str` objects in the state_dict
+            # In the non-tensor case, fall back to the pointer of the object itself
+            ptrs[id(tensor)].append(name)
+
+        elif tensor.device.type == "meta":
+            # In offloaded cases, there may be meta tensors in the state_dict.
+            # For these cases, key by the pointer of the original tensor object
+            # (state_dict tensors are detached and therefore no longer shared)
+            tensor = model.get_parameter(name)
+            ptrs[id(tensor)].append(name)
+
+        else:
+            ptrs[id_tensor_storage(tensor)].append(name)
+
+    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
+    # Recursively descend to find tied weight keys
+    all_potential_tied_weights_keys = set(_get_tied_weight_keys(model))
+    error_names = []
+    to_delete_names = set()
+    # Removing the keys which are declared as known duplicates on load. This allows to make sure the name which is
+    # kept is consistent
+    if all_potential_tied_weights_keys is not None:
+        for names in shared_ptrs.values():
+            found = 0
+            for name in sorted(names):
+                matches_pattern = any(re.search(pat, name) for pat in all_potential_tied_weights_keys)
+                if matches_pattern and name in state_dict:
+                    found += 1
+                    if found < len(names):
+                        to_delete_names.add(name)
+    # We are entering a place where the weights and the transformers configuration do NOT match.
+    shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+    # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+    # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+    for name in disjoint_names:
+        state_dict[name] = state_dict[name].clone()
+
+    # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+    # If the link between tensors was done at runtime then `from_pretrained` will not get
+    # the key back leading to random tensor. A proper warning will be shown
+    # during reload (if applicable), but since the file is not necessarily compatible with
+    # the config, better show a proper warning.
+    shared_names, identical_names = _find_identical(shared_names, state_dict)
+    # delete tensors that have identical storage
+    for inames in identical_names:
+        known = inames.intersection(to_delete_names)
+        for name in known:
+            del state_dict[name]
+        unknown = inames.difference(to_delete_names)
+        if len(unknown) > 1:
+            error_names.append(unknown)
+
+    if shared_names:
+        error_names.extend(shared_names)
+
+    if len(error_names) > 0:
+        raise RuntimeError(
+            f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. "
+            f"We found all the potential target tied weights keys to be: {all_potential_tied_weights_keys}.\n"
+            "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
+        )
+
+    return state_dict
+
+
 def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
-    """Cast a single parameter `param_name` into the `model`, with value `tensor`."""
-    module, param_type = get_module_from_name(model, param_name)
-    # This will check potential shape mismatch if skipped before
-    module.load_state_dict({param_type: tensor}, strict=False, assign=True)
+    """Cast a single parameter or buffer `param_name` into the `model`, with value `tensor`."""
+    parent, param_type = get_module_from_name(model, param_name)
+    if param_type in parent._parameters and not isinstance(tensor, nn.Parameter):
+        tensor = nn.Parameter(tensor, requires_grad=tensor.is_floating_point())
+    # We need to use setattr here, as we set non-persistent buffers as well with this function (`load_state_dict`
+    # does not allow to do it)
+    setattr(parent, param_type, tensor)
 
 
 def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
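The core idea of the new function above is grouping tensors by storage pointer rather than trusting the config. A hedged sketch of just that detection step (`find_shared_tensor_names` is a hypothetical helper, not part of the library):

    import collections
    import torch

    def find_shared_tensor_names(state_dict: dict[str, torch.Tensor]) -> list[list[str]]:
        # tensors that alias the same storage end up in the same bucket,
        # so tied weights are found even when the config does not declare them
        ptrs = collections.defaultdict(list)
        for name, tensor in state_dict.items():
            ptrs[tensor.untyped_storage().data_ptr()].append(name)
        return [names for names in ptrs.values() if len(names) > 1]

    embed = torch.nn.Embedding(10, 4)
    sd = {"embed_tokens.weight": embed.weight, "lm_head.weight": embed.weight}  # tied
    print(find_shared_tensor_names(sd))  # [['embed_tokens.weight', 'lm_head.weight']]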
@@ -552,8 +601,7 @@ def _get_resolved_checkpoint_files(
                 raise OSError(
                     f"{pretrained_model_name_or_path} does not appear to have a file named"
                     f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
-                    "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
-                    "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+                    "and thus cannot be loaded with `safetensors`. Please do not set `use_safetensors=True`."
                 )
             else:
                 # This repo has no safetensors file of any kind, we switch to PyTorch.
@@ -697,23 +745,22 @@ def _get_resolved_checkpoint_files(
 
 
 def _get_dtype(
-    cls,
     dtype: Optional[Union[str, torch.dtype, dict]],
     checkpoint_files: Optional[list[str]],
     config: PreTrainedConfig,
     sharded_metadata: Optional[dict],
     state_dict: Optional[dict],
     weights_only: bool,
-) -> tuple[PreTrainedConfig, torch.dtype, Optional[torch.dtype]]:
+    hf_quantizer: Optional[HfQuantizer] = None,
+) -> tuple[PreTrainedConfig, torch.dtype]:
     """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
     inferred dtype. We do the following:
-    1. If dtype is
-
-
-    we also may have config.dtype available, but we won't rely on it till v5
+    1. If dtype is "auto", we try to read the config, else auto-detect dtype from the loaded state_dict, by checking
+    its first weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
+    2. Else, use the dtype provided as a dict or str
     """
-    dtype_orig = None
     is_sharded = sharded_metadata is not None
+    asked_dtype = dtype
 
     if dtype is not None:
         if isinstance(dtype, str):
@@ -737,43 +784,49 @@ def _get_dtype(
                 )
             elif hasattr(torch, dtype):
                 dtype = getattr(torch, dtype)
-                config.dtype = dtype
-                for sub_config_key in config.sub_configs:
-                    sub_config = getattr(config, sub_config_key)
-                    sub_config.dtype = dtype
-        elif isinstance(dtype, torch.dtype):
-            config.dtype = dtype
-            for sub_config_key in config.sub_configs:
-                sub_config = getattr(config, sub_config_key)
-                sub_config.dtype = dtype
-        elif isinstance(dtype, dict):
-            for key, curr_dtype in dtype.items():
-                if hasattr(config, key):
-                    value = getattr(config, key)
-                    curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
-                    value.dtype = curr_dtype
-            # main torch dtype for modules that aren't part of any sub-config
-            dtype = dtype.get("")
-            dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
-            config.dtype = dtype
-            if dtype is None:
-                dtype = torch.float32
-        else:
+            else:
+                raise ValueError(
+                    "`dtype` provided as a `str` can only be `'auto'`, or a string representation of a valid `torch.dtype`"
+                )
+
+            # cast it to a proper `torch.dtype` object
+            dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
+        elif not isinstance(dtype, (dict, torch.dtype)):
             raise ValueError(
                 f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
                 f"for each sub-config in composite configs, but received {dtype}"
             )
+    else:
+        # set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
+        dtype = torch.get_default_dtype()
 
-        dtype_orig = cls._set_default_dtype(dtype)
+    if hf_quantizer is not None:
+        hf_quantizer.update_dtype(dtype)
+
+    # Get the main dtype
+    if isinstance(dtype, dict):
+        main_dtype = dtype.get("", torch.get_default_dtype())
+        main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
     else:
-
-
-
-
-
-
+        main_dtype = dtype
+
+    # Set it on the config and subconfigs
+    config.dtype = main_dtype
+    for sub_config_key in config.sub_configs:
+        if (sub_config := getattr(config, sub_config_key)) is not None:
+            # The dtype was "auto" -> try to read the subconfig dtype value if any
+            if asked_dtype == "auto":
+                sub_dtype = getattr(sub_config, "dtype", main_dtype)
+                sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
+            # The dtype was provided as a dict, try to see if we match the subconfig name
+            elif isinstance(dtype, dict):
+                sub_dtype = dtype.get(sub_config_key, main_dtype)
+                sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
+            else:
+                sub_dtype = main_dtype
+            sub_config.dtype = sub_dtype
 
-    return config, dtype, dtype_orig
+    return config, main_dtype
 
 
 class PipelineParallel(Enum):
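The resolution order implemented by the new `_get_dtype` is easiest to see on a toy example. A hedged sketch (`resolve_dtype` is a hypothetical helper; the real code also handles "auto" from the checkpoint):

    import torch

    def resolve_dtype(dtype, sub_config_keys):
        # "" keys the main dtype when a dict is passed, mirroring the hunk above
        main = dtype.get("", torch.get_default_dtype()) if isinstance(dtype, dict) else dtype
        main = getattr(torch, main) if isinstance(main, str) else main
        per_sub = {}
        for key in sub_config_keys:
            sub = dtype.get(key, main) if isinstance(dtype, dict) else main
            per_sub[key] = getattr(torch, sub) if isinstance(sub, str) else sub
        return main, per_sub

    print(resolve_dtype({"": "bfloat16", "vision_config": "float16"}, ["vision_config", "text_config"]))
    # (torch.bfloat16, {'vision_config': torch.float16, 'text_config': torch.bfloat16})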
@@ -969,54 +1022,52 @@ class EmbeddingAccessMixin:
             `nn.Module`: A torch module mapping vocabulary to hidden states.
         """
 
-        # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
-        # for most NLP models), and if so, return it.
-
         name = getattr(self, "_input_embed_layer", "embed_tokens")
 
+        # 1) Direct attribute (most NLP models).
         if (default_embedding := getattr(self, name, None)) is not None:
             return default_embedding
-        # 2)
+        # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision/audio models).
+        if hasattr(self, "embeddings") and hasattr(self.embeddings, name):
+            return getattr(self.embeddings, name)
+        # 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides).
+        if hasattr(self, "model") and hasattr(self.model, name):
+            return getattr(self.model, name)
 
-        if hasattr(self, "
-
+        if hasattr(self, "base_model"):
+            base_model = self.base_model
+            if base_model is not None and base_model is not self:
+                return base_model.get_input_embeddings()
 
-
-
-
-        else:
-            base_model = getattr(self, "base_model_prefix", None)
-            if base_model is not None:
-                base_model = getattr(self, base_model, None)
-            if base_model is not None and base_model is not self:
-                return base_model.get_input_embeddings()
-        raise NotImplementedError(
-            f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
-            "please override in the subclass."
-        )
+        raise NotImplementedError(
+            f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
+        )
 
     def set_input_embeddings(self, value: nn.Module):
         """Fallback setter that handles **~70%** of models in the code-base.
 
         Order of attempts:
-        1. `self
-        2. `self.
-        3.
-        4.
+        1. `self.<_input_embed_layer>` (direct attribute)
+        2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models)
+        3. `self.model.<_input_embed_layer>` (encoder/decoder models)
+        4. delegate to the *base model* if one exists
+        5. otherwise raise `NotImplementedError` so subclasses still can (and
            should) override for exotic layouts.
         """
 
-        # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
         name = getattr(self, "_input_embed_layer", "embed_tokens")
-        if hasattr(self, "model") and hasattr(self.model, name):
-            setattr(self.model, name, value)
-        # 2) as well as vanilla decoder‑only architectures
-        elif hasattr(self, name):
+        # 1) Direct attribute (most NLP models)
+        if hasattr(self, name):
             setattr(self, name, value)
-        #
-        elif
-
-
+        # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models)
+        elif hasattr(self, "embeddings") and hasattr(self.embeddings, name):
+            setattr(self.embeddings, name, value)
+        # 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
+        elif hasattr(self, "model") and hasattr(self.model, name):
+            setattr(self.model, name, value)
+        # 4) recurse once into the registered *base* model (e.g. for encoder/decoder)
+        elif hasattr(self, "base_model") and self.base_model is not self:
+            self.base_model.set_input_embeddings(value)
         else:
             raise NotImplementedError(
                 f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
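The new lookup chain (direct attribute, then `embeddings.<name>`, then `model.<name>`, then the base model) can be exercised with a toy module. A hedged sketch, with `TinyModel` entirely hypothetical:

    import torch.nn as nn

    class TinyModel(nn.Module):
        _input_embed_layer = "embed_tokens"

        def __init__(self):
            super().__init__()
            self.embed_tokens = nn.Embedding(10, 4)

        def get_input_embeddings(self):
            # first two steps of the chain from the hunk above
            name = getattr(self, "_input_embed_layer", "embed_tokens")
            if (emb := getattr(self, name, None)) is not None:
                return emb
            if hasattr(self, "model") and hasattr(self.model, name):
                return getattr(self.model, name)
            raise NotImplementedError

    print(type(TinyModel().get_input_embeddings()))  # <class 'torch.nn.modules.sparse.Embedding'>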
@@ -1228,6 +1279,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
             self.config._attn_implementation, is_init_check=True
         )
+        # Check the experts implementation is supported, or set it if not yet set (on the internal attr, to avoid
+        # setting it recursively)
+        self.config._experts_implementation_internal = self._check_and_adjust_experts_implementation(
+            self.config._experts_implementation
+        )
         if self.can_generate():
             self.generation_config = GenerationConfig.from_model_config(config)
 
@@ -1343,7 +1399,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     def pp_plan(self, plan: dict[str, tuple[str, str]]):
         self._pp_plan = plan
 
-    def dequantize(self):
+    def dequantize(self, dtype=None):
         """
         Potentially dequantize the model in case it has been quantized by a quantization method that support
         dequantization.
@@ -1353,7 +1409,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if hf_quantizer is None:
             raise ValueError("You need to first quantize your model in order to dequantize it")
 
-        return hf_quantizer.dequantize(self)
+        return hf_quantizer.dequantize(self, dtype=dtype)
 
     def _backward_compatibility_gradient_checkpointing(self):
         if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
@@ -1394,7 +1450,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         self.model_tags.append(tag)
 
     @classmethod
-    @restore_default_dtype
     def _from_config(cls, config, **kwargs):
         """
         All context managers that the model should be initialized under go here.
@@ -1403,9 +1458,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             dtype (`torch.dtype`, *optional*):
                 Override the default `dtype` and load the model under this dtype.
         """
-        # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
-        # a warning is raised that dtype should be fp16. Since we never pass dtype from within
-        # modeling code, we can try to infer it here same way as done in `from_pretrained`
         # For BC on the old `torch_dtype`
         dtype = kwargs.pop("dtype", config.dtype)
         if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
@@ -1415,61 +1467,32 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if isinstance(dtype, str):
             dtype = getattr(torch, dtype)
 
-        # override default dtype if needed
-        dtype_orig = None
-        if dtype is not None:
-            dtype_orig = cls._set_default_dtype(dtype)
-
         # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
         if "attn_implementation" in kwargs:
             config._attn_implementation = kwargs.pop("attn_implementation")
 
+        # If passing `experts_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
+        if "experts_implementation" in kwargs:
+            config._experts_implementation = kwargs.pop("experts_implementation")
+
+        init_contexts = []
+        if dtype is not None:
+            init_contexts.append(local_torch_dtype(dtype, cls.__name__))
+
         if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
             # this immediately partitions the model across all gpus, to avoid the overhead in time
             # and memory copying it on CPU or each GPU first
             import deepspeed
 
-            init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
-            with ContextManagers(init_contexts):
-                model = cls(config, **kwargs)
+            init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
 
-        else:
+        # Instantiate the model
+        with ContextManagers(init_contexts):
             model = cls(config, **kwargs)
 
-        # restore default dtype if it was modified
-        if dtype_orig is not None:
-            torch.set_default_dtype(dtype_orig)
-
         return model
 
-    @classmethod
-    def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
-        """
-        Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
-        under specific dtype.
-
-        Args:
-            dtype (`torch.dtype`):
-                a floating dtype to set to.
-
-        Returns:
-            `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
-            modified. If it wasn't, returns `None`.
-
-        Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
-        `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
-        """
-        if not dtype.is_floating_point:
-            raise ValueError(
-                f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
-            )
-
-        logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
-        dtype_orig = torch.get_default_dtype()
-        torch.set_default_dtype(dtype)
-        return dtype_orig
-
     @property
     def base_model(self) -> nn.Module:
         """
@@ -1546,7 +1569,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             return True
 
         if is_torch_xpu_available():
-            logger.info(
+            logger.info(
+                f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU."
+            )
             return True
 
         if importlib.util.find_spec("flash_attn") is None:
@@ -1715,6 +1740,22 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         return True
 
+    def _grouped_mm_can_dispatch(self) -> bool:
+        """
+        Check the availability of Grouped MM for a given model.
+        """
+
+        if not self._can_set_experts_implementation():
+            raise ValueError(f"{self.__class__.__name__} does not support setting experts implementation.")
+
+        if not is_grouped_mm_available():
+            raise ImportError(
+                "PyTorch Grouped MM requirements in Transformers are not met. Please install torch>=2.9.0."
+            )
+
+        # If no error raised by this point, we can return `True`
+        return True
+
     def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
         """
         Check the availability of Flex Attention for a given model.
@@ -1764,9 +1805,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         """
         applicable_attn_implementation = attn_implementation
 
+        is_paged = attn_implementation is not None and attn_implementation.startswith("paged|")
+
         # If FA not installed, do not fail but use kernels instead
         requested_original_flash_attn = attn_implementation is not None and (
-            attn_implementation
+            attn_implementation.removeprefix("paged|") == "flash_attention_2"
+            or attn_implementation.removeprefix("paged|") == "flash_attention_3"
         )
         if (
             requested_original_flash_attn
@@ -1775,19 +1819,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             and is_kernels_available()
             and not is_torch_npu_available()
         ):
-
-
-
-
-
-
-
-
+            applicable_attn_implementation = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]
+
+            if is_torch_xpu_available() and attn_implementation.removeprefix("paged|") == "flash_attention_2":
+                # On XPU, kernels library is the native implementation
+                # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
+                requested_original_flash_attn = False
+
+            if is_paged:
+                applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
 
         if is_kernel(applicable_attn_implementation):
             try:
                 # preload flash attention here to allow compile with fullgraph
-                lazy_import_flash_attention(applicable_attn_implementation)
+                if is_paged:
+                    lazy_import_paged_flash_attention(applicable_attn_implementation)
+                else:
+                    lazy_import_flash_attention(applicable_attn_implementation)
 
                 # log that we used kernel fallback if successful
                 if requested_original_flash_attn:
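The `paged|` prefix handling in the hunk above is a strip-then-reapply pattern. A minimal sketch (the `resolve_kernel` helper is hypothetical; the mapping values mirror the module-level dict added earlier in this diff):

    FLASH_ATTN_KERNEL_FALLBACK = {
        "flash_attention_2": "kernels-community/flash-attn2",
        "flash_attention_3": "kernels-community/vllm-flash-attn3",
    }

    def resolve_kernel(attn_implementation: str) -> str:
        # strip the prefix for the lookup, re-apply it afterwards
        is_paged = attn_implementation.startswith("paged|")
        kernel = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]
        return f"paged|{kernel}" if is_paged else kernel

    print(resolve_kernel("paged|flash_attention_2"))  # paged|kernels-community/flash-attn2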
@@ -1816,6 +1864,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         return applicable_attn_implementation
 
+    def _check_and_adjust_experts_implementation(self, experts_implementation: Optional[str]) -> str:
+        """
+        Check that the `experts_implementation` exists and is supported by the model.
+
+        Args:
+            experts_implementation (`str` or `None`):
+                The experts implementation to check for existence/validity.
+        Returns:
+            `str`: The final experts implementation to use.
+        """
+        applicable_experts_implementation = self.get_correct_experts_implementation(experts_implementation)
+        return applicable_experts_implementation
+
     def get_correct_attn_implementation(self, requested_attention: Optional[str], is_init_check: bool = False) -> str:
         applicable_attention = "sdpa" if requested_attention is None else requested_attention
         if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
@@ -1850,6 +1911,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         return applicable_attention
 
+    def get_correct_experts_implementation(self, requested_experts: Optional[str]) -> str:
+        applicable_experts = "grouped_mm" if requested_experts is None else requested_experts
+        if applicable_experts not in ["eager", "grouped_mm", "batched_mm"]:
+            message = (
+                f'Specified `experts_implementation="{applicable_experts}"` is not supported. The only possible arguments are '
+                '`experts_implementation="eager"`, `experts_implementation="grouped_mm"` and `experts_implementation="batched_mm"`.'
+            )
+            raise ValueError(message)
+
+        # Perform relevant checks
+        if applicable_experts == "grouped_mm":
+            try:
+                self._grouped_mm_can_dispatch()
+            except (ValueError, ImportError) as e:
+                if requested_experts == "grouped_mm":
+                    raise e
+                applicable_experts = "eager"
+
+        return applicable_experts
+
     @classmethod
     def _can_set_attn_implementation(cls) -> bool:
         """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
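The fallback policy above is worth spelling out: an explicit request propagates the error, while the implicit default silently degrades. A hedged standalone sketch (`pick_experts_implementation` is hypothetical):

    def pick_experts_implementation(requested, grouped_mm_ok: bool) -> str:
        applicable = "grouped_mm" if requested is None else requested
        if applicable == "grouped_mm" and not grouped_mm_ok:
            if requested == "grouped_mm":  # user explicitly asked for it -> surface the error
                raise ImportError("grouped_mm requires torch>=2.9.0")
            applicable = "eager"  # silent fallback for the implicit default
        return applicable

    print(pick_experts_implementation(None, grouped_mm_ok=False))          # eager
    print(pick_experts_implementation("batched_mm", grouped_mm_ok=False))  # batched_mm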
@@ -1868,6 +1949,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
         return True
 
+    @classmethod
+    def _can_set_experts_implementation(cls) -> bool:
+        """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
+        opening the file, but avoids maintaining yet another property flag.
+        """
+        class_file = sys.modules[cls.__module__].__file__
+        with open(class_file, "r") as f:
+            code = f.read()
+        # heuristic -> if the `use_experts_implementation` decorator is used, then we can set it
+        return "@use_experts_implementation" in code
+
     def set_attn_implementation(self, attn_implementation: Union[str, dict]):
         """
         Set the requested `attn_implementation` for this model.
@@ -1967,6 +2059,50 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             if hasattr(subconfig, "_attn_was_changed"):
                 del subconfig._attn_was_changed
 
+    def set_experts_implementation(self, experts_implementation: Union[str, dict]):
+        """
+        Set the requested `experts_implementation` for this model.
+
+        Args:
+            experts_implementation (`str` or `dict`):
+                The experts implementation to set for this model. It can be either a `str`, in which case it will be
+                dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
+                submodel will dispatch the corresponding value.
+        """
+        requested_implementation = (
+            experts_implementation
+            if not isinstance(experts_implementation, dict)
+            else experts_implementation.get("", self.config._experts_implementation)
+        )
+
+        if requested_implementation != self.config._experts_implementation:
+            requested_implementation = self._check_and_adjust_experts_implementation(requested_implementation)
+            # Apply the change (on the internal attr, to avoid setting it recursively)
+            self.config._experts_implementation_internal = requested_implementation
+
+        # Apply it to all submodels as well
+        for submodule in self.modules():
+            # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
+            # e.g. ForCausalLM has a Model inside, but no need to check it again)
+            if (
+                submodule is not self
+                and isinstance(submodule, PreTrainedModel)
+                and submodule.config.__class__ != self.config.__class__
+            ):
+                # Set the experts on the submodule
+                sub_implementation = requested_implementation
+                if isinstance(experts_implementation, dict):
+                    for subconfig_key in self.config.sub_configs:
+                        # We need to check for exact object match here, with `is`
+                        if getattr(self.config, subconfig_key) is submodule.config:
+                            sub_implementation = experts_implementation.get(
+                                subconfig_key, submodule.config._experts_implementation
+                            )
+                            break
+                # Check the submodule can use it correctly; an error is raised if the requested experts
+                # implementation can't be set for the submodule
+                sub_implementation = submodule.get_correct_experts_implementation(sub_implementation)
+                submodule.config._experts_implementation_internal = sub_implementation
+
     def enable_input_require_grads(self):
         """
         Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
@@ -1978,14 +2114,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         hooks = []
         seen_modules = set()
+        found_embeddings = False
 
         for module in self.modules():
            if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
                continue
 
-            input_embeddings = module.get_input_embeddings()
+            try:
+                input_embeddings = module.get_input_embeddings()
+            except NotImplementedError:
+                continue
 
-            if input_embeddings is None:
+            if input_embeddings is None or not hasattr(input_embeddings, "register_forward_hook"):
                 continue
 
             embedding_id = id(input_embeddings)
@@ -1994,11 +2134,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
             seen_modules.add(embedding_id)
             hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))
+            found_embeddings = True
 
         self._require_grads_hooks = hooks
         if hooks:
             # for BC
             self._require_grads_hook = hooks[0]
+        if not found_embeddings:
+            logger.warning_once(
+                f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
+                "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
+                "support those features, or set `_input_embed_layer` to the attribute name that holds the embeddings."
+            )
 
     def disable_input_require_grads(self):
         """
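The hook registered above is a small but non-obvious trick: it forces the embedding *output* to require grad so gradients can flow through frozen embeddings (PEFT plus gradient checkpointing). A hedged standalone sketch:

    import torch
    import torch.nn as nn

    def make_inputs_require_grads(module, inputs, output):
        # mark the embedding output so autograd keeps the graph alive downstream
        output.requires_grad_(True)

    embed = nn.Embedding(10, 4)
    embed.weight.requires_grad_(False)  # frozen, as under PEFT
    hook = embed.register_forward_hook(make_inputs_require_grads)
    out = embed(torch.tensor([1, 2, 3]))
    print(out.requires_grad)  # True, despite the frozen weight
    hook.remove()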
@@ -2104,7 +2251,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         possible_module_names = ["language_model", "text_model", "decoder"]
         for name in possible_module_names:
             if hasattr(self, name):
-                print(name)
                 setattr(self, name, decoder)
                 return
 
@@ -2134,14 +2280,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
             if getattr(module, "weight", None) is not None:
                 init.normal_(module.weight, mean=0.0, std=std)
-            if
+            if module.bias is not None:
                 init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
-
-
-
-
-            init.zeros_(module.weight[module.padding_idx])
+            init.normal_(module.weight, mean=0.0, std=std)
+            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
+            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
+                init.zeros_(module.weight[module.padding_idx])
         elif isinstance(module, nn.MultiheadAttention):
             # This uses torch's original init
             module._reset_parameters()
@@ -2153,10 +2298,25 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             or "RMSNorm" in module.__class__.__name__
         ):
             # Norms can exist without weights (in which case they are None from torch primitives)
-            if
+            if getattr(module, "weight", None) is not None:
                 init.ones_(module.weight)
-            if
+            if getattr(module, "bias", None) is not None:
                 init.zeros_(module.bias)
+            # And the potential buffers for the BatchNorms
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
+        # This matches all the usual RotaryEmbeddings modules
+        elif "RotaryEmbedding" in module.__class__.__name__ and hasattr(module, "original_inv_freq"):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)
 
     def _initialize_weights(self, module):
         """
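The norm branch above guards every access because torch norms can be constructed without affine weights or running stats. A hedged sketch of the same guards using plain `torch.nn.init` (the diff's `init` module also provides a custom `copy_`, which stock torch does not):

    import torch.nn as nn
    from torch.nn import init

    def init_norm(module: nn.Module) -> None:
        if getattr(module, "weight", None) is not None:
            init.ones_(module.weight)
        if getattr(module, "bias", None) is not None:
            init.zeros_(module.bias)
        if getattr(module, "running_mean", None) is not None:  # BatchNorm buffers
            init.zeros_(module.running_mean)
            init.ones_(module.running_var)
            init.zeros_(module.num_batches_tracked)

    init_norm(nn.BatchNorm1d(8))                       # affine + buffers
    init_norm(nn.LayerNorm(8, elementwise_affine=False))  # no weight/bias: guards skip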
@@ -2261,7 +2421,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         tied_mapping = self._tied_weights_keys
         # If the config does not specify any tying, return empty dict
-        if not self.config.tie_word_embeddings
+        if not self.config.tie_word_embeddings:
             return {}
         # If None, return empty dict
         elif tied_mapping is None:
@@ -2327,30 +2487,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         tied_keys = list(tied_keys.items())
         for i, (target_param_name, source_param_name) in enumerate(tied_keys):
-            # Usually we tie a single target to a single source, but when both are missing we may later tie
-            # both the source and target to a third "backup" parameter that is present in the checkpoint, so we use
-            # a list here
-            target_param_names = [target_param_name]
-
             # This is `from_pretrained` -> let's check symmetrically in case the source key is not present
             if missing_keys is not None:
                 remove_from_missing = True
                 source_is_there = source_param_name not in missing_keys
                 target_is_there = target_param_name not in missing_keys
                 # Both are already present -> it means the config is wrong and do not reflect the actual
-                # checkpoint -> let's raise a warning and
+                # checkpoint -> let's raise a warning and NOT tie them
                 if source_is_there and target_is_there:
                     logger.warning(
                         f"The tied weights mapping and config for this model specifies to tie {source_param_name} to "
                         f"{target_param_name}, but both are present in the checkpoints, so we will NOT tie them. "
                         "You should update the config with `tie_word_embeddings=False` to silence this warning"
                     )
+                    # Remove from internal attribute to correctly reflect actual tied weights
+                    self.all_tied_weights_keys.pop(target_param_name)
                     # Skip to next iteration
                     continue
                 # We're missing the source but we have the target -> we swap them, tying the parameter that exists
                 elif not source_is_there and target_is_there:
                     target_param_name, source_param_name = source_param_name, target_param_name
-                    target_param_names = [target_param_name]
                 # Both are missing -> check other keys in case more than 2 keys are tied to the same weight
                 elif not source_is_there and not target_is_there:
                     for target_backup, source_backup in tied_keys[i + 1 :]:
@@ -2359,10 +2515,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                         if source_backup == source_param_name:
                             target_backup_is_there = target_backup not in missing_keys
                             # If the target is present, we found the correct weight to tie into (we know the source is missing)
+                            # Note here that we do not tie the missing source right now as well, as it will be done anyway when
+                            # the pair (target_backup, source_backup) becomes the main pair (target_param_name, source_param_name)
                             if target_backup_is_there:
                                 source_param_name = target_backup
-                                # Append the source as well, since both are missing we'll tie both
-                                target_param_names.append(source_param_name)
                                 break
                     # If we did not break from the loop, it was impossible to find a source key -> let's raise
                     else:
@@ -2378,19 +2534,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
             # Perform the actual tying
             source_param = self.get_parameter_or_buffer(source_param_name)
-            for target_name in target_param_names:
-
-
-
-
-
-
-
-
-
-
-
-                missing_keys.discard(target_param_name)
+            if "." in target_param_name:
+                parent_name, name = target_param_name.rsplit(".", 1)
+                parent = self.get_submodule(parent_name)
+            else:
+                name = target_param_name
+                parent = self
+            # Tie the weights
+            setattr(parent, name, source_param)
+            self._adjust_bias(parent, source_param)
+            # Remove from missing if necessary
+            if missing_keys is not None and remove_from_missing:
+                missing_keys.discard(target_param_name)
 
     def _adjust_bias(self, output_embeddings, input_embeddings):
         if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"):
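The new tying step is plain attribute assignment on the resolved parent module. A hedged standalone sketch (`tie_parameter` is a hypothetical helper mirroring the added lines):

    import torch.nn as nn

    def tie_parameter(model: nn.Module, target_name: str, source_param: nn.Parameter) -> None:
        # resolve the parent module of the dotted target key, then assign,
        # so both names point at the very same tensor
        if "." in target_name:
            parent_name, name = target_name.rsplit(".", 1)
            parent = model.get_submodule(parent_name)
        else:
            parent, name = model, target_name
        setattr(parent, name, source_param)

    model = nn.ModuleDict({"embed": nn.Embedding(10, 4), "head": nn.Linear(4, 10, bias=False)})
    tie_parameter(model, "head.weight", model["embed"].weight)
    assert model["head"].weight is model["embed"].weight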
@@ -2903,7 +3058,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         Maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
         initialization logic in `_init_weights`.
         """
-        if _init_weights:
+        # If we are initializing on meta device, there is no point in trying to run inits
+        if get_torch_context_manager_or_global_device() != torch.device("meta"):
             # Initialize weights
             self.initialize_weights()
             # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
@@ -2941,7 +3097,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
             )
 
-        if getattr(self, "_hf_peft_config_loaded", False):
+        needs_embedding_grads = self.main_input_name == "input_ids"
+        # we use that also to detect whether or not we have to raise if embeddings are missing (the submodel might not have embeddings at all)
+        enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False)
+        if enable_input_grads:
             # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
             # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
             # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
@@ -3002,10 +3161,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
         state_dict: Optional[dict] = None,
-        save_function: Callable = torch.save,
         push_to_hub: bool = False,
-        max_shard_size: Union[int, str] = "5GB",
-        safe_serialization: bool = True,
+        max_shard_size: Union[int, str] = "50GB",
        variant: Optional[str] = None,
         token: Optional[Union[str, bool]] = None,
         save_peft_format: bool = True,
@@ -3027,18 +3184,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
                 save parts of the model or if special precautions need to be taken when recovering the state dictionary
                 of a model (like when using model parallelism).
-            save_function (`Callable`):
-                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-                need to replace `torch.save` by another method.
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
-            max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
+            max_shard_size (`int` or `str`, *optional*, defaults to `"50GB"`):
                 The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
                 lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
-                We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
-                without CPU OOM issues.
 
             <Tip warning={true}>
@@ -3047,10 +3199,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
             </Tip>
 
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
             variant (`str`, *optional*):
-                If specified, weights are saved in the format
+                If specified, weights are saved in the format model.<variant>.safetensors.
             token (`str` or `bool`, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                 the token generated when running `hf auth login` (stored in `~/.huggingface`).
@@ -3072,9 +3222,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         hf_quantizer = getattr(self, "hf_quantizer", None)
         quantization_serializable = (
-            hf_quantizer is not None
-            and isinstance(hf_quantizer, HfQuantizer)
-            and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
+            hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable()
         )
 
         if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
@@ -3110,7 +3258,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         metadata = {}
         if hf_quantizer is not None:
-            state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
+            state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self)
         metadata["format"] = "pt"
 
         # Only save the model itself if we are using distributed training
@@ -3163,29 +3311,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             current_peft_config = self.peft_config[active_adapter]
             current_peft_config.save_pretrained(save_directory)
 
-        #
-        module_map = {}
-
-        # Save the model
+        # Get the model state_dict
         if state_dict is None:
-            # if any model parameters are offloaded, make module map
-            if (
-                hasattr(self, "hf_device_map")
-                and len(set(self.hf_device_map.values())) > 1
-                and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
-            ):
-                warnings.warn(
-                    "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
-                )
-                for name, module in model_to_save.named_modules():
-                    if name == "":
-                        continue
-                    module_state_dict = module.state_dict()
-
-                    for key in module_state_dict:
-                        module_map[name + f".{key}"] = module
             state_dict = model_to_save.state_dict()
 
+        # if any model parameters are offloaded, we need to know it for later
+        is_offloaded = False
+        if (
+            hasattr(self, "hf_device_map")
+            and len(set(self.hf_device_map.values())) > 1
+            and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
+        ):
+            is_offloaded = True
+            warnings.warn(
+                "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory "
+                "exceeds the `shard_size` (50GB default)"
+            )
+
         # Translate state_dict from smp to hf if saving with smp >= 1.10
         if IS_SAGEMAKER_MP_POST_1_10:
             for smp_to_hf, _ in smp.state.module_manager.translate_functions:
@@ -3202,86 +3344,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if self._tp_size is not None:
             state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
 
-
-
-        # Safetensors does not allow tensor aliasing.
-        # We're going to remove aliases before saving
-        ptrs = collections.defaultdict(list)
-        for name, tensor in state_dict.items():
-            if not isinstance(tensor, torch.Tensor):
-                # Sometimes in the state_dict we have non-tensor objects.
-                # e.g. in bitsandbytes we have some `str` objects in the state_dict
-                # In the non-tensor case, fall back to the pointer of the object itself
-                ptrs[id(tensor)].append(name)
-
-            elif tensor.device.type == "meta":
-                # In offloaded cases, there may be meta tensors in the state_dict.
-                # For these cases, key by the pointer of the original tensor object
-                # (state_dict tensors are detached and therefore no longer shared)
-                tensor = self.get_parameter(name)
-                ptrs[id(tensor)].append(name)
-
-            else:
-                ptrs[id_tensor_storage(tensor)].append(name)
-
-        shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-
-        # Recursively descend to find tied weight keys
-        _tied_weights_keys = set(_get_tied_weight_keys(self))
-        error_names = []
-        to_delete_names = set()
-        for names in shared_ptrs.values():
-            # Removing the keys which are declared as known duplicates on
-            # load. This allows to make sure the name which is kept is consistent.
-            if _tied_weights_keys is not None:
-                found = 0
-                for name in sorted(names):
-                    matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
-                    if matches_pattern and name in state_dict:
-                        found += 1
-                        if found < len(names):
-                            to_delete_names.add(name)
-        # We are entering a place where the weights and the transformers configuration do NOT match.
-        shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
-        # Those are actually tensor sharing but disjoint from each other, we can safely clone them
-        # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
-        for name in disjoint_names:
-            state_dict[name] = state_dict[name].clone()
-
-        # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-        # If the link between tensors was done at runtime then `from_pretrained` will not get
-        # the key back leading to random tensor. A proper warning will be shown
-        # during reload (if applicable), but since the file is not necessarily compatible with
-        # the config, better show a proper warning.
-        shared_names, identical_names = _find_identical(shared_names, state_dict)
-        # delete tensors that have identical storage
-        for inames in identical_names:
-            known = inames.intersection(to_delete_names)
-            for name in known:
-                del state_dict[name]
-            unknown = inames.difference(to_delete_names)
-            if len(unknown) > 1:
-                error_names.append(unknown)
-
-        if shared_names:
-            error_names.extend(shared_names)
-
-        if len(error_names) > 0:
-            raise RuntimeError(
-                f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
-                "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
-            )
+        # Remove tied weights as safetensors do not handle them
+        state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)
 
         # Revert all renaming and/or weight operations
         if save_original_format:
-            state_dict = revert_weight_conversion(self, state_dict)
+            state_dict = revert_weight_conversion(model_to_save, state_dict)
 
         # Shard the model if it is too big.
         if not _hf_peft_config_loaded:
-            weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+            weights_name = SAFE_WEIGHTS_NAME
             weights_name = _add_variant(weights_name, variant)
         else:
-            weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
+            weights_name = ADAPTER_SAFE_WEIGHTS_NAME
 
         filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
         state_dict_split = split_torch_state_dict_into_shards(
@@ -3314,57 +3389,45 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3314
3389
|
and reg.fullmatch(filename_no_suffix) is not None
|
|
3315
3390
|
):
|
|
3316
3391
|
os.remove(full_filename)
|
|
3392
|
+
|
|
3317
3393
|
# Save the model
|
|
3318
|
-
|
|
3319
|
-
|
|
3320
|
-
|
|
3321
|
-
|
|
3322
|
-
|
|
3323
|
-
for
|
|
3324
|
-
|
|
3325
|
-
|
|
3394
|
+
for shard_file, tensor_names in logging.tqdm(
|
|
3395
|
+
state_dict_split.filename_to_tensors.items(), desc="Writing model shards"
|
|
3396
|
+
):
|
|
3397
|
+
filename = os.path.join(save_directory, shard_file)
|
|
3398
|
+
shard_state_dict = {}
|
|
3399
|
+
for tensor_name in tensor_names:
|
|
3400
|
+
# Get the tensor, and remove it from state_dict to avoid keeping the ref
|
|
3401
|
+
tensor = state_dict.pop(tensor_name)
|
|
3402
|
+
|
|
3403
|
+
# In case of TP, get the full parameter back
|
|
3404
|
+
if _is_dtensor_available and isinstance(tensor, DTensor):
|
|
3405
|
+
tensor = tensor.full_tensor()
|
|
3326
3406
|
# to get the correctly ordered tensor we need to repack if packed
|
|
3327
|
-
if _get_parameter_tp_plan(
|
|
3328
|
-
|
|
3329
|
-
|
|
3330
|
-
|
|
3331
|
-
|
|
3332
|
-
#
|
|
3333
|
-
|
|
3334
|
-
|
|
3335
|
-
|
|
3336
|
-
|
|
3337
|
-
|
|
3338
|
-
shard_state_dict = dict.fromkeys(shard, "")
|
|
3339
|
-
for module_name in shard:
|
|
3340
|
-
# note that get_state_dict_from_offload can update with meta tensors
|
|
3341
|
-
# if both a parent module and its descendant are offloaded
|
|
3342
|
-
tensor = shard_state_dict[module_name]
|
|
3343
|
-
if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
|
|
3344
|
-
# update state dict with onloaded parameters
|
|
3345
|
-
module = module_map[module_name]
|
|
3346
|
-
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
|
|
3347
|
-
|
|
3348
|
-
# assign shard to be the completed state dict
|
|
3349
|
-
shard = shard_state_dict
|
|
3350
|
-
del shard_state_dict
|
|
3351
|
-
gc.collect()
|
|
3352
|
-
|
|
3353
|
-
if safe_serialization:
|
|
3354
|
-
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
|
3355
|
-
# joyfulness), but for now this enough. # TODO: we should def parallelize this we are otherwise just waiting
|
|
3356
|
-
# too much before scheduling the next write when its in a different file
|
|
3357
|
-
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
|
|
3358
|
-
else:
|
|
3359
|
-
save_function(shard, os.path.join(save_directory, shard_file))
|
|
3407
|
+
if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
|
|
3408
|
+
tensor = repack_weights(tensor, -1, self._tp_size, 2)
|
|
3409
|
+
|
|
3410
|
+
# If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
|
|
3411
|
+
# but it would otherwise not be contained in the saved shard if we were to simply move the file
|
|
3412
|
+
# or something
|
|
3413
|
+
if is_offloaded and tensor.device.type == "meta":
|
|
3414
|
+
tensor = load_offloaded_parameter(model_to_save, tensor_name)
|
|
3415
|
+
|
|
3416
|
+
# only do contiguous after it's permuted correctly in case of TP
|
|
3417
|
+
shard_state_dict[tensor_name] = tensor.contiguous()
|
|
3360
3418
|
|
|
3361
|
-
|
|
3419
|
+
# TODO: it would be very nice to do the writing concurrently, but safetensors never releases the GIL,
|
|
3420
|
+
# so it's not possible for now....
|
|
3421
|
+
# Write the shard to disk
|
|
3422
|
+
safe_save_file(shard_state_dict, filename, metadata=metadata)
|
|
3423
|
+
# Cleanup the data before next loop (important with offloading, so we don't blowup cpu RAM)
|
|
3424
|
+
del shard_state_dict
|
|
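The new loop follows the standard shard-by-shard pattern of the `huggingface_hub` split API: build one shard at a time, write it, release it. A minimal, self-contained sketch of that pattern using only public APIs (TP repacking, offload reloading and error handling omitted):

    import os
    import torch
    from huggingface_hub import split_torch_state_dict_into_shards
    from safetensors.torch import save_file

    def save_sharded(state_dict: dict[str, torch.Tensor], save_directory: str) -> None:
        split = split_torch_state_dict_into_shards(state_dict, max_shard_size="5GB")
        for shard_file, tensor_names in split.filename_to_tensors.items():
            # pop() drops each reference from the full state dict, so peak memory
            # stays at roughly one shard
            shard = {name: state_dict.pop(name).contiguous() for name in tensor_names}
            save_file(shard, os.path.join(save_directory, shard_file))
            del shard  # release before building the next shard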
3362 3425 |
3363 3426 |         if index is None:
3364 3427 |             path_to_weights = os.path.join(save_directory, weights_name)
3365 3428 |             logger.info(f"Model weights saved in {path_to_weights}")
3366 3429 |         else:
3367      | -             save_index_file = SAFE_WEIGHTS_INDEX_NAME
     3430 | +             save_index_file = SAFE_WEIGHTS_INDEX_NAME
3368 3431 |             save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
3369 3432 |             # Save the index as well
3370 3433 |             with open(save_index_file, "w", encoding="utf-8") as f:
@@ -3535,19 +3598,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

3535 3598 |         return super().float(*args)
3536 3599 |
3537 3600 |     @classmethod
3538      | -     def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
     3601 | +     def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
     3602 | +         # Need to instantiate with correct dtype
     3603 | +         init_contexts = [local_torch_dtype(dtype, cls.__name__)]
3539 3604 |         if is_deepspeed_zero3_enabled():
3540 3605 |             import deepspeed
3541 3606 |
3542      | -             init_contexts = [no_init_weights()]
3543 3607 |             # We cannot initialize the model on meta device with deepspeed when not quantized
3544 3608 |             if not is_quantized and not _is_ds_init_called:
3545 3609 |                 logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
3546      | -                 init_contexts.extend(
     3610 | +                 init_contexts.extend(
     3611 | +                     [
     3612 | +                         init.no_init_weights(),
     3613 | +                         deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
     3614 | +                         set_zero3_state(),
     3615 | +                     ]
     3616 | +                 )
3547 3617 |         elif is_quantized:
3548      | -             init_contexts.extend([
     3618 | +             init_contexts.extend([torch.device("meta"), set_quantized_state()])
3549 3619 |         else:
3550      | -             init_contexts
     3620 | +             init_contexts.append(torch.device("meta"))
3551 3621 |
3552 3622 |         return init_contexts
3553 3623 |
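The list returned by `get_init_context` is entered as one combined context around the model constructor. A sketch of how such a list composes, assuming a plain `contextlib.ExitStack` (transformers' own `ContextManagers` helper behaves like this):

    from contextlib import ExitStack

    import torch
    import torch.nn as nn

    init_contexts = [torch.device("meta")]  # e.g. the non-quantized default above
    with ExitStack() as stack:
        for ctx in init_contexts:
            stack.enter_context(ctx)
        layer = nn.Linear(1024, 1024)  # created on the meta device: no real memory allocated

    print(layer.weight.device)  # meta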
@@ -3572,7 +3642,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

3572 3642 |
3573 3643 |             # This is a context manager to override the default kernel mapping
3574 3644 |             # We are calling kernelize inside this context manager using the use_kernels setter
3575      | -
     3645 | +             # Param inherit_mapping should be False to avoid still loading kernel from remote
     3646 | +             inherit_mapping = not kernel_config.use_local_kernel
     3647 | +             with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
3576 3648 |                 self.use_kernels = True
3577 3649 |         # We use the default kernel mapping in .integrations.hub_kernels
3578 3650 |         else:

@@ -3581,7 +3653,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

3581 3653 |             self.use_kernels = False
3582 3654 |
3583 3655 |     @classmethod
3584      | -     @restore_default_dtype
3585 3656 |     def from_pretrained(
3586 3657 |         cls: type[SpecificPreTrainedModelType],
3587 3658 |         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
@@ -3690,10 +3761,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

3690 3761 |                 "org/model@main"
3691 3762 |                 "org/model:custom_kernel"
3692 3763 |                 "org/model@v1.2.3:custom_kernel"
     3764 | +             experts_implementation (`str`, *optional*):
     3765 | +                 The experts implementation to use in the model (if relevant). Can be any of:
     3766 | +
     3767 | +                 - `"eager"` (sequential implementation of the experts matrix multiplications).
     3768 | +                 - `"batched_mm"` (using [`torch.bmm`](https://pytorch.org/docs/stable/generated/torch.bmm.html)).
     3769 | +                 - `"grouped_mm"` (using [`torch._grouped_mm`](https://docs.pytorch.org/docs/main/generated/torch.nn.functional.grouped_mm.html)).
     3770 | +
     3771 | +                 By default, if available, `grouped_mm` will be used for torch>=2.9.0. The default is otherwise the sequential `"eager"` implementation.
3693 3772 |
3694 3773 |             > Parameters for big model inference
3695 3774 |
3696      | -             dtype (`str` or `torch.dtype`, *optional
     3775 | +             dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"`):
3697 3776 |                 Override the default `torch_dtype` and load the model under a specific `dtype`. The different options
3698 3777 |                 are:
3699 3778 |
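An illustrative call using the two arguments documented above; the checkpoint name is a placeholder for any MoE model:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "org/some-moe-model",                 # placeholder MoE checkpoint
        experts_implementation="batched_mm",  # or "eager" / "grouped_mm" (torch>=2.9.0)
        dtype="auto",                         # the new default documented above
    )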
@@ -3835,6 +3914,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

3835 3914 |         # For BC on torch_dtype argument
3836 3915 |         if torch_dtype is not None:
3837 3916 |             dtype = dtype if dtype is not None else torch_dtype
     3917 | +         if dtype is None:
     3918 | +             dtype = "auto"
3838 3919 |
3839 3920 |         if is_offline_mode() and not local_files_only:
3840 3921 |             local_files_only = True
@@ -3911,8 +3992,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

3911 3992 |         if "attn_implementation" in kwargs:
3912 3993 |             config._attn_implementation = kwargs.pop("attn_implementation")
3913 3994 |
3914      | -
3915      | -         config
     3995 | +         if "experts_implementation" in kwargs:
     3996 | +             config._experts_implementation = kwargs.pop("experts_implementation")
     3997 | +
     3998 | +         hf_quantizer, config, device_map = get_hf_quantizer(
     3999 | +             config, quantization_config, device_map, weights_only, user_agent
3916 4000 |         )
3917 4001 |
3918 4002 |         if gguf_file:
@@ -3959,33 +4043,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

3959 4043 |         ]
3960 4044 |
3961 4045 |         # Find the correct dtype based on current state
3962      | -         config, dtype
3963      | -
     4046 | +         config, dtype = _get_dtype(
     4047 | +             dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only, hf_quantizer
3964 4048 |         )
3965 4049 |
3966 4050 |         config.name_or_path = pretrained_model_name_or_path
3967      | -         model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
     4051 | +         model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called)
3968 4052 |         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
3969 4053 |         with ContextManagers(model_init_context):
3970 4054 |             # Let's make sure we don't run the init function of buffer modules
3971 4055 |             model = cls(config, *model_args, **model_kwargs)
3972 4056 |
     4057 | +         if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
     4058 | +             hf_quantizer.preprocess_model(
     4059 | +                 model=model,
     4060 | +                 dtype=dtype,
     4061 | +                 device_map=device_map,
     4062 | +                 checkpoint_files=checkpoint_files,
     4063 | +                 use_kernels=use_kernels,
     4064 | +             )
     4065 | +
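The quantizer hooks bracket the loading flow: `preprocess_model` swaps modules in the still-empty model before any weights arrive, and `postprocess_model` (called further down) runs once after loading. A schematic sketch of that two-phase shape, with simplified signatures and not the full HfQuantizer interface:

    import torch.nn as nn

    class SketchQuantizer:
        def preprocess_model(self, model: nn.Module, **kwargs) -> None:
            # Replace eligible modules (e.g. nn.Linear) with quantized variants;
            # weights are untouched here, the model is still empty / on meta
            ...

        def postprocess_model(self, model: nn.Module) -> None:
            # Usually a no-op; e.g. drop the quant config when dequantizing
            ...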
3973 4066 |         # Obtain the weight conversion mapping for this model if any are registered
3974 4067 |         weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
3975 4068 |
3976      | -         # make sure we use the model's config since the __init__ call might have copied it
3977      | -         config = model.config
3978      | -
3979      | -         if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
3980      | -             hf_quantizer.preprocess_model(
3981      | -                 model=model,
3982      | -                 device_map=device_map,
3983      | -                 keep_in_fp32_modules=model._keep_in_fp32_modules,  # TODO prob no longer needed?
3984      | -                 config=config,
3985      | -                 checkpoint_files=checkpoint_files,
3986      | -                 use_kernels=use_kernels,
3987      | -             )
3988      | -
3989 4069 |         if _torch_distributed_available and device_mesh is not None:  # add hooks to nn.Modules: no weights
3990 4070 |             model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
3991 4071 |
@@ -3993,10 +4073,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

3993 4073 |         if device_map is not None:
3994 4074 |             device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
3995 4075 |
3996      | -         # restore default dtype
3997      | -         if dtype_orig is not None:
3998      | -             torch.set_default_dtype(dtype_orig)
3999      | -
4000 4076 |         # Finalize model weight initialization
4001 4077 |         model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
4002 4078 |             model,
@@ -4007,6 +4083,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4007 4083 |             sharded_metadata=sharded_metadata,
4008 4084 |             device_map=device_map,
4009 4085 |             disk_offload_folder=offload_folder,
     4086 | +             offload_buffers=offload_buffers,
4010 4087 |             dtype=dtype,
4011 4088 |             hf_quantizer=hf_quantizer,
4012 4089 |             device_mesh=device_mesh,
@@ -4014,7 +4091,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4014 4091 |             weight_mapping=weight_conversions,
4015 4092 |         )
4016 4093 |
4017      | -         model.eval()  # Set model in evaluation mode to deactivate
     4094 | +         model.eval()  # Set model in evaluation mode to deactivate Dropout modules by default
4018 4095 |         model.set_use_kernels(use_kernels, kernel_config)
4019 4096 |
4020 4097 |         # If it is a model with generation capabilities, attempt to load generation files (generation config,
@@ -4030,16 +4107,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4030 4107 |             **kwargs,
4031 4108 |         )
4032 4109 |
4033      | -         #
4034      | -         if device_map is not None and
     4110 | +         # If the device_map has more than 1 device: dispatch model with hooks on all devices
     4111 | +         if device_map is not None and len(set(device_map.values())) > 1:
4035 4112 |             accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)
4036 4113 |
4114
|
if hf_quantizer is not None:
|
|
4038
4115
|
model.hf_quantizer = hf_quantizer
|
|
4039
|
-
hf_quantizer.postprocess_model(
|
|
4116
|
+
hf_quantizer.postprocess_model(
|
|
4117
|
+
model
|
|
4118
|
+
) # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
|
|
4040
4119
|
|
|
4041
4120
|
if _adapter_model_path is not None:
|
|
4042
|
-
adapter_kwargs["key_mapping"] =
|
|
4121
|
+
adapter_kwargs["key_mapping"] = key_mapping
|
|
4043
4122
|
model.load_adapter(
|
|
4044
4123
|
_adapter_model_path,
|
|
4045
4124
|
adapter_name=adapter_name,
|
|
@@ -4068,6 +4147,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4068 4147 |         sharded_metadata: Optional[dict] = None,
4069 4148 |         device_map: Optional[dict] = None,
4070 4149 |         disk_offload_folder: Optional[str] = None,
     4150 | +         offload_buffers: bool = False,
4071 4151 |         dtype: Optional[torch.dtype] = None,
4072 4152 |         hf_quantizer: Optional[HfQuantizer] = None,
4073 4153 |         device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
@@ -4082,6 +4162,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4082 4162 |
4083 4163 |         # Model's definition arriving here is final (TP hooks added, quantized layers replaces)
4084 4164 |         expected_keys = list(model.state_dict().keys())
     4165 | +
4085 4166 |         if logger.level >= logging.WARNING:
4086 4167 |             verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))
4087 4168 |
@@ -4090,10 +4171,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4090 4171 |         # Prepare parameters offloading if needed
4091 4172 |         if device_map is not None and "disk" in device_map.values():
4092 4173 |             disk_offload_index = accelerate_disk_offload(
     4174 | +                 model,
4093 4175 |                 disk_offload_folder,
4094 4176 |                 checkpoint_files,
4095 4177 |                 device_map,
4096      | -                 expected_keys,
4097 4178 |                 sharded_metadata,
4098 4179 |                 dtype,
4099 4180 |                 weight_mapping,
@@ -4104,7 +4185,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4104 4185 |         expanded_device_map = expand_device_map(device_map, expected_keys)
4105 4186 |         caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
4106 4187 |
4107      | -         tp_plan = getattr(model, "_tp_plan", None)
4108 4188 |         error_msgs = []
4109 4189 |
4110 4190 |         if is_deepspeed_zero3_enabled() and not is_quantized:
@@ -4113,9 +4193,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4113 4193 |             for ckpt_file in checkpoint_files:
4114 4194 |                 merged_state_dict.update(load_state_dict(ckpt_file, map_location="cpu", weights_only=weights_only))
4115 4195 |             state_dict = merged_state_dict
4116      | -             error_msgs
     4196 | +             error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict)
4117 4197 |             # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
4118      | -
     4198 | +             unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set()
4119 4199 |         else:
4120 4200 |             all_pointer = set()
4121 4201 |             # Checkpoints are safetensors
@@ -4137,19 +4217,20 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4137 4217 |             else:
4138 4218 |                 raise ValueError("Neither a state dict nor checkpoint files were found.")
4139 4219 |
4140      | -             missing_keys, unexpected_keys, mismatched_keys, disk_offload_index,
     4220 | +             missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
4141 4221 |                 convert_and_load_state_dict_in_model(
4142      | -                     model,
4143      | -                     merged_state_dict,
4144      | -                     weight_mapping,
4145      | -                     tp_plan,
4146      | -                     hf_quantizer,
4147      | -                     dtype,
4148      | -                     device_map,
4149      | -                     model.dtype_plan,
4150      | -                     device_mesh,
4151      | -                     disk_offload_index,
4152      | -                     disk_offload_folder,
     4222 | +                     model=model,
     4223 | +                     state_dict=merged_state_dict,
     4224 | +                     weight_mapping=weight_mapping,
     4225 | +                     tp_plan=model._tp_plan,
     4226 | +                     hf_quantizer=hf_quantizer,
     4227 | +                     dtype=dtype,
     4228 | +                     device_map=device_map,
     4229 | +                     dtype_plan=model.dtype_plan,
     4230 | +                     device_mesh=device_mesh,
     4231 | +                     disk_offload_index=disk_offload_index,
     4232 | +                     disk_offload_folder=disk_offload_folder,
     4233 | +                     offload_buffers=offload_buffers,
4153 4234 |                 )
4154 4235 |             )
4155 4236 |
@@ -4160,12 +4241,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4160 4241 |         # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
4161 4242 |         model.mark_tied_weights_as_initialized()
4162 4243 |
4163      | -         # Move missing (and potentially mismatched) keys back to
4164      | -         # loading the weights as they
4165      | -
4166      | -         model.
     4244 | +         # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
     4245 | +         # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
     4246 | +         missing_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
     4247 | +         model._move_missing_keys_from_meta_to_device(missing_and_mismatched, device_map, device_mesh, hf_quantizer)
4167 4248 |
4168      | -         # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `
     4249 | +         # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
4169 4250 |         model._initialize_missing_keys(is_quantized)
4170 4251 |
4171 4252 |         # Tie the weights
@@ -4174,34 +4255,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4174 4255 |         # Adjust missing and unexpected keys
4175 4256 |         missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys)
4176 4257 |
4177      | -         # Post-processing for tensor parallelism
4178      | -         if device_mesh is not None:
4179      | -             # When using TP, the device map is a single device for all parameters
4180      | -             tp_device = list(device_map.values())[0]
4181      | -             # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
4182      | -             # not part of the state_dict (persistent=False)
4183      | -             for buffer in model.buffers():  # TODO to avaoid this buffer could be added to the ckpt
4184      | -                 if buffer.device != tp_device:
4185      | -                     buffer.data = buffer.to(tp_device)
4186      | -
4187      | -             # In this case, the top-most task module weights were not moved to device and parallelized as they
4188      | -             # were not part of the loaded weights: do it now
4189      | -             if missing_keys:
4190      | -                 state_dict = model.state_dict()
4191      | -                 for name in missing_keys:
4192      | -                     param = state_dict[name]
4193      | -                     # Shard the param
4194      | -                     shard_and_distribute_module(
4195      | -                         model,
4196      | -                         param.to(tp_device),
4197      | -                         param,
4198      | -                         name,
4199      | -                         None,
4200      | -                         False,
4201      | -                         device_mesh.get_local_rank(),
4202      | -                         device_mesh,
4203      | -                     )
4204      | -
4205 4258 |         log_state_dict_report(
4206 4259 |             model=model,
4207 4260 |             pretrained_model_name_or_path=pretrained_model_name_or_path,
@@ -4211,7 +4264,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4211 4264 |             missing_keys=missing_keys,
4212 4265 |             mismatched_keys=mismatched_keys,
4213 4266 |             mismatched_shapes=mismatched_keys,
4214      | -
     4267 | +             conversion_errors=conversion_errors,
4215 4268 |             ignore_mismatched_sizes=ignore_mismatched_sizes,
4216 4269 |         )
4217 4270 |
@@ -4399,33 +4452,54 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4399 4452 |     def is_backend_compatible(cls):
4400 4453 |         return cls._supports_attention_backend
4401 4454 |
4402      | -     def
4403      | -         self,
     4455 | +     def _move_missing_keys_from_meta_to_device(
     4456 | +         self,
     4457 | +         missing_keys: list[str],
     4458 | +         device_map: dict | None,
     4459 | +         device_mesh: "torch.distributed.device_mesh.DeviceMesh | None",
     4460 | +         hf_quantizer: HfQuantizer | None,
4404 4461 |     ) -> None:
4405      | -         """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts)
4406      | -         from meta device to cpu.
     4462 | +         """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts)
     4463 | +         back from meta device to their device according to the `device_map` if any, else cpu. Takes care of sharding those
     4464 | +         missing parameters if `device_mesh` is provided, i.e. we are using TP.
     4465 | +         All non-persistent buffers are also moved back to the correct device (they are not part of the state_dict, but are
     4466 | +         not missing either).
4407 4467 |         """
4408 4468 |         is_quantized = hf_quantizer is not None
     4469 | +         # This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
     4470 | +         if is_deepspeed_zero3_enabled() and not is_quantized:
     4471 | +             return
4409 4472 |
4410 4473 |         # In this case we need to move everything back
4411 4474 |         if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
4412      | -             # We only do it for the parameters, as the buffers are not initialized on the meta device by default
4413 4475 |             for key, param in self.named_parameters():
4414      | -                 value = torch.empty_like(param,
     4476 | +                 value = torch.empty_like(param, device="cpu")
     4477 | +                 _load_parameter_into_model(self, key, value)
     4478 | +             for key, buffer in self.named_buffers():
     4479 | +                 value = torch.empty_like(buffer, device="cpu")
4415 4480 |                 _load_parameter_into_model(self, key, value)
4416 4481 |             return
4417 4482 |
4418      | -         model_state_dict = self.state_dict()
4419 4483 |         # The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway)
4420 4484 |         # This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they
4421 4485 |         # will be re-initialized for nothing (which can be quite long)
4422 4486 |         for key in missing_keys - self.all_tied_weights_keys.keys():
4423      | -             param =
4424      | -
4425      | -
4426      | -
4427      | -
4428      | -
     4487 | +             param = self.get_parameter_or_buffer(key)
     4488 | +             param_device = get_device(device_map, key, valid_torch_device=True)
     4489 | +             value = torch.empty_like(param, device=param_device)
     4490 | +             # For TP, we may need to shard the param
     4491 | +             if device_mesh is not None:
     4492 | +                 shard_and_distribute_module(
     4493 | +                     self, value, param, key, None, False, device_mesh.get_local_rank(), device_mesh
     4494 | +                 )
     4495 | +             # Otherwise, just move it to device
     4496 | +             else:
     4497 | +                 _load_parameter_into_model(self, key, value)
     4498 | +         # We need to move back non-persistent buffers as well, as they are not part of loaded weights anyway
     4499 | +         for key, buffer in self.named_non_persistent_buffers():
     4500 | +             buffer_device = get_device(device_map, key, valid_torch_device=True)
     4501 | +             value = torch.empty_like(buffer, device=buffer_device)
     4502 | +             _load_parameter_into_model(self, key, value)
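A minimal, standalone example of the meta-to-device move this method performs, using only public torch APIs: parameters created under the meta device own no storage until re-materialized, and the new values are deliberately left uninitialized since `_initialize_missing_keys` runs afterwards.

    import torch
    import torch.nn as nn

    with torch.device("meta"):
        layer = nn.Linear(8, 8)  # no real memory allocated

    # Re-materialize a "missing" parameter on its target device, uninitialized
    layer.weight = nn.Parameter(torch.empty_like(layer.weight, device="cpu"))
    print(layer.weight.device)  # cpu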
4429 4503 |
4430 4504 |     def _initialize_missing_keys(self, is_quantized: bool) -> None:
4431 4505 |         """
@@ -4453,8 +4527,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4453 4527 |     ) -> tuple[set[str], set[str]]:
4454 4528 |         """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
4455 4529 |         raising unneeded warnings/errors.
4456      | -         Also, set the `_is_hf_initialized` on tied weight keys, to avoid initializing them as they are going to
4457      | -         be tied anyway.
4458 4530 |         """
4459 4531 |         # Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model
4460 4532 |         # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
@@ -4513,6 +4585,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

4513 4585 |
4514 4586 |         raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")
4515 4587 |
     4588 | +     def named_non_persistent_buffers(
     4589 | +         self, recurse: bool = True, remove_duplicate: bool = True
     4590 | +     ) -> Iterator[tuple[str, torch.Tensor]]:
     4591 | +         """Similar to `named_buffers`, but only yield non-persistent ones. It is handy as it's not perfectly straightforward
     4592 | +         to know if they are persistent or not"""
     4593 | +         for name, tensor in self.named_buffers(recurse=recurse, remove_duplicate=remove_duplicate):
     4594 | +             # We have to grab the parent here, as the attribute `_non_persistent_buffers_set` is on the immediate
     4595 | +             # parent only
     4596 | +             parent, buf_name = name.rsplit(".", 1) if "." in name else ("", name)
     4597 | +             parent = self.get_submodule(parent)
     4598 | +             if buf_name in parent._non_persistent_buffers_set:
     4599 | +                 yield name, tensor
     4600 | +
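Why the dedicated iterator is needed: non-persistent buffers never appear in the state dict, yet `named_buffers` yields them anyway, and only the owning submodule records the distinction. A small standalone illustration with public torch APIs:

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("persistent_buf", torch.zeros(2))
            self.register_buffer("inv_freq", torch.ones(2), persistent=False)

    m = Toy()
    print("inv_freq" in m.state_dict())                  # False: excluded from saving
    print([n for n, _ in m.named_buffers()])             # ['persistent_buf', 'inv_freq']
    print("inv_freq" in m._non_persistent_buffers_set)   # True, kept on the owning module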
4516 4601 |     def train(self, mode: bool = True):
4517 4602 |         out = super().train(mode)
4518 4603 |         if self.use_kernels:
@@ -4565,6 +4650,40 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:

4565 4650 |     return torch.device(device).type not in ["meta", "cpu"]
4566 4651 |
4567 4652 |
     4653 | + def get_total_byte_count(
     4654 | +     model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: Optional[HfQuantizer] = None
     4655 | + ):
     4656 | +     """
     4657 | +     This utility function calculates the total bytes count needed to load the model on each device.
     4658 | +     This is useful for caching_allocator_warmup as we want to know how much cache we need to pre-allocate.
     4659 | +     """
     4660 | +
     4661 | +     total_byte_count = defaultdict(lambda: 0)
     4662 | +     tied_param_names = model.all_tied_weights_keys.keys()
     4663 | +     tp_plan = model._tp_plan if torch.distributed.is_available() and torch.distributed.is_initialized() else []
     4664 | +
     4665 | +     for param_name, device in accelerator_device_map.items():
     4666 | +         # Skip if the parameter has already been accounted for (tied weights)
     4667 | +         if param_name in tied_param_names:
     4668 | +             continue
     4669 | +
     4670 | +         param = model.get_parameter_or_buffer(param_name)
     4671 | +
     4672 | +         if hf_quantizer is not None:
     4673 | +             dtype_size = hf_quantizer.param_element_size(model, param_name, param)
     4674 | +         else:
     4675 | +             dtype_size = param.element_size()
     4676 | +
     4677 | +         param_byte_count = param.numel() * dtype_size
     4678 | +
     4679 | +         if len(tp_plan) > 0:
     4680 | +             is_part_of_plan = _get_parameter_tp_plan(param_name, tp_plan, is_weight=True) is not None
     4681 | +             param_byte_count //= torch.distributed.get_world_size() if is_part_of_plan else 1
     4682 | +
     4683 | +         total_byte_count[device] += param_byte_count
     4684 | +     return total_byte_count
     4685 | +
     4686 | +
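The accounting is plain numel times element size, halved per rank for TP-sharded parameters. A quick standalone illustration of the arithmetic for a single fp32 weight:

    import torch.nn as nn

    layer = nn.Linear(1024, 1024, bias=False)          # fp32 by default
    param = layer.weight
    byte_count = param.numel() * param.element_size()  # 1_048_576 * 4 bytes
    print(byte_count / 2**20, "MiB")                   # 4.0 MiB on the target device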
4568 4687 | def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: Optional[HfQuantizer]):
4569 4688 |     """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
4570 4689 |     device. It allows to have one large call to Malloc, instead of recursively calling it later when loading

@@ -4584,8 +4703,6 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,

4584 4703 |     - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
4585 4704 |     However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
4586 4705 |     """
4587      | -     factor = 2 if hf_quantizer is None else hf_quantizer.get_accelerator_warm_up_factor()
4588      | -
4589 4706 |     # Remove disk, cpu and meta devices, and cast to proper torch.device
4590 4707 |     accelerator_device_map = {
4591 4708 |         param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)

@@ -4593,40 +4710,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,

4593 4710 |     if not accelerator_device_map:
4594 4711 |         return
4595 4712 |
4596      | -
4597      | -     tp_plan_regex = (
4598      | -         re.compile("|".join([re.escape(plan) for plan in tp_plan]))
4599      | -         if _torch_distributed_available and torch.distributed.is_initialized()
4600      | -         else None
4601      | -     )
4602      | -     total_byte_count = defaultdict(lambda: 0)
4603      | -     tied_param_names = model.all_tied_weights_keys.keys()
4604      | -     for param_name, device in accelerator_device_map.items():
4605      | -         # Skip if the parameter has already been accounted for (tied weights)
4606      | -         if param_name in tied_param_names:
4607      | -             continue
4608      | -
4609      | -         # For example in the case of MXFP4 quantization, we need to update the param name to the original param name
4610      | -         # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
4611      | -         if hf_quantizer is not None:
4612      | -             param_name = hf_quantizer.get_param_name(param_name)
4613      | -
4614      | -         try:
4615      | -             param = model.get_parameter_or_buffer(param_name)
4616      | -         except AttributeError:
4617      | -             # TODO: for now let's skip if we can't find the parameters
4618      | -             if hf_quantizer is not None:
4619      | -                 continue
4620      | -             raise AttributeError(f"Parameter {param_name} not found in model")
4621      | -
4622      | -         # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
4623      | -         param_byte_count = param.numel() * param.element_size()
4624      | -
4625      | -         if tp_plan_regex is not None:
4626      | -             generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
4627      | -             param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
4628      | -
4629      | -         total_byte_count[device] += param_byte_count
     4713 | +     total_byte_count = get_total_byte_count(model, accelerator_device_map, hf_quantizer)
4630 4714 |
4631 4715 |     # This will kick off the caching allocator to avoid having to Malloc afterwards
4632 4716 |     for device, byte_count in total_byte_count.items():
@@ -4646,9 +4730,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,

4646 4730 |         unused_memory = torch_accelerator_module.memory_reserved(
4647 4731 |             index
4648 4732 |         ) - torch_accelerator_module.memory_allocated(index)
4649      | -         byte_count = max(0, byte_count - unused_memory)
4650      | -         #
4651      | -         _ = torch.empty(byte_count //
     4733 | +         byte_count = int(max(0, byte_count - unused_memory))
     4734 | +         # We divide by 2 here as we allocate in fp16
     4735 | +         _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
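The trick in isolation: a float16 element is 2 bytes, so an fp16 tensor of byte_count // 2 elements makes the caching allocator reserve exactly byte_count bytes in a single call; the tensor can be dropped immediately because the allocator keeps the block cached for later loads.

    import torch

    byte_count = 8 * 2**30  # e.g. 8 GiB of weights headed for this device
    if torch.cuda.is_available():
        _ = torch.empty(byte_count // 2, dtype=torch.float16, device="cuda", requires_grad=False)
        del _  # the block stays reserved in the caching allocator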
4652 4736 |
4653 4737 |
4654 4738 | class AttentionInterface(GeneralInterface):