transformers 5.0.0rc3-py3-none-any.whl → 5.1.0-py3-none-any.whl
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/modeling_utils.py CHANGED

@@ -24,8 +24,9 @@ import sys
 import warnings
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Callable, Iterator
+from collections.abc import Callable, Iterator
 from contextlib import contextmanager
+from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial, wraps
 from itertools import cycle
@@ -77,9 +78,8 @@ from .integrations.tensor_parallel import (
     ALL_PARALLEL_STYLES,
     _get_parameter_tp_plan,
     distribute_model,
+    gather_state_dict_for_save,
     initialize_tensor_parallelism,
-    repack_weights,
-    replace_state_dict_local_with_dtensor,
     shard_and_distribute_module,
     verify_tp_plan,
 )
@@ -106,25 +106,26 @@ from .utils import (
     copy_func,
     has_file,
     is_accelerate_available,
+    is_bitsandbytes_available,
+    is_env_variable_true,
     is_flash_attn_2_available,
     is_flash_attn_3_available,
     is_grouped_mm_available,
     is_kernels_available,
     is_torch_flex_attn_available,
-    is_torch_greater_or_equal,
     is_torch_mlu_available,
     is_torch_npu_available,
     is_torch_xpu_available,
     logging,
 )
-from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
+from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder, is_flash_attention_requested
 from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
     is_huggingface_hub_greater_or_equal,
     is_sagemaker_mp_enabled,
     is_tracing,
 )
-from .utils.loading_report import log_state_dict_report
+from .utils.loading_report import LoadStateDictInfo, log_state_dict_report
 from .utils.quantization_config import QuantizationMethod


@@ -134,9 +135,6 @@ if is_accelerate_available():


 _torch_distributed_available = torch.distributed.is_available()
-_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
-if _is_dtensor_available:
-    from torch.distributed.tensor import DTensor

 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
@@ -162,6 +160,33 @@ FLASH_ATTN_KERNEL_FALLBACK = {
 }


+@dataclass(frozen=True)
+class LoadStateDictConfig:
+    """
+    Config for loading weights. This allows bundling arguments that are just
+    passed around.
+    """
+
+    pretrained_model_name_or_path: str | None = None
+    download_kwargs: DownloadKwargs | None = field(default_factory=DownloadKwargs)
+    use_safetensors: bool | None = None
+    ignore_mismatched_sizes: bool = False
+    sharded_metadata: dict | None = None
+    device_map: dict | None = None
+    disk_offload_folder: str | None = None
+    offload_buffers: bool = False
+    dtype: torch.dtype | None = None
+    dtype_plan: dict = field(default_factory=dict)
+    hf_quantizer: HfQuantizer | None = None
+    device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None
+    weights_only: bool = True
+    weight_mapping: list[WeightConverter | WeightRenaming] | None = None
+
+    @property
+    def is_quantized(self) -> bool:
+        return self.hf_quantizer is not None
+
+
 def is_local_dist_rank_0():
     return (
         torch.distributed.is_available()
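The new `LoadStateDictConfig` introduced in the hunk above bundles the many loading arguments that were previously threaded through helper functions one by one. A minimal, self-contained sketch of the same frozen-dataclass pattern (shortened, hypothetical field names — not the actual class):

from dataclasses import dataclass, field


# Illustrative only: a stripped-down stand-in for the bundling pattern shown above.
@dataclass(frozen=True)
class MiniLoadConfig:
    device_map: dict | None = None
    dtype_plan: dict = field(default_factory=dict)
    hf_quantizer: object | None = None

    @property
    def is_quantized(self) -> bool:
        # mirrors the convenience property on the real config
        return self.hf_quantizer is not None


cfg = MiniLoadConfig(device_map={"": "cpu"})
assert not cfg.is_quantized  # no quantizer attached

Because the dataclass is frozen, the loading helpers can pass a single immutable object around instead of a dozen positional arguments.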
@@ -223,8 +248,7 @@ def get_torch_context_manager_or_global_device():
     is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
     """
     device_in_context = torch.tensor([]).device
-
-    default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu")
+    default_device = torch.get_default_device()
     # This case means no context manager was used -> we still check if the default that was potentially set is not cpu
     if device_in_context == default_device:
         if default_device != torch.device("cpu"):
@@ -252,23 +276,20 @@ str_to_torch_dtype = {
     "U8": torch.uint8,
     "I8": torch.int8,
     "I16": torch.int16,
+    "U16": torch.uint16,
     "F16": torch.float16,
     "BF16": torch.bfloat16,
     "I32": torch.int32,
+    "U32": torch.uint32,
     "F32": torch.float32,
     "F64": torch.float64,
     "I64": torch.int64,
+    "U64": torch.uint64,
     "F8_E4M3": torch.float8_e4m3fn,
     "F8_E5M2": torch.float8_e5m2,
 }


-if is_torch_greater_or_equal("2.3.0"):
-    str_to_torch_dtype["U16"] = torch.uint16
-    str_to_torch_dtype["U32"] = torch.uint32
-    str_to_torch_dtype["U64"] = torch.uint64
-
-
 def load_state_dict(
     checkpoint_file: str | os.PathLike, map_location: str | torch.device = "cpu", weights_only: bool = True
 ) -> dict[str, torch.Tensor]:
@@ -472,15 +493,16 @@ def _get_resolved_checkpoint_files(
     variant: str | None,
     gguf_file: str | None,
     use_safetensors: bool | None,
-
-    user_agent: dict,
+    user_agent: dict | None,
     is_remote_code: bool,  # Because we can't determine this inside this function, we need it to be passed in
     transformers_explicit_filename: str | None = None,
+    download_kwargs: DownloadKwargs | None = None,
 ) -> tuple[list[str] | None, dict | None]:
     """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
     checkpoints are sharded.
     This function will download the data if necessary.
     """
+    download_kwargs = download_kwargs or DownloadKwargs()
     cache_dir = download_kwargs.get("cache_dir")
     force_download = download_kwargs.get("force_download", False)
     proxies = download_kwargs.get("proxies")
@@ -493,17 +515,19 @@ def _get_resolved_checkpoint_files(
         if not transformers_explicit_filename.endswith(".safetensors") and not transformers_explicit_filename.endswith(
             ".safetensors.index.json"
         ):
-
-
-
-
-
+            if transformers_explicit_filename != "adapter_model.bin":
+                raise ValueError(
+                    "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
+                    "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
+                    f"{transformers_explicit_filename}"
+                )

     is_sharded = False

     if pretrained_model_name_or_path is not None and gguf_file is None:
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         is_local = os.path.isdir(pretrained_model_name_or_path)
+        # If the file is a local folder (but not in the HF_HOME cache, even if it's technically local)
         if is_local:
             if transformers_explicit_filename is not None:
                 # If the filename is explicitly defined, load this by default.
@@ -562,25 +586,38 @@ def _get_resolved_checkpoint_files(
             else:
                 filename = _add_variant(WEIGHTS_NAME, variant)

+            # Prepare set of kwargs for hub functions
+            has_file_kwargs = {
+                "revision": revision,
+                "proxies": proxies,
+                "token": token,
+                "cache_dir": cache_dir,
+                "local_files_only": local_files_only,
+            }
+            cached_file_kwargs = {
+                "force_download": force_download,
+                "user_agent": user_agent,
+                "subfolder": subfolder,
+                "_raise_exceptions_for_gated_repo": False,
+                "_raise_exceptions_for_missing_entries": False,
+                "_commit_hash": commit_hash,
+                **has_file_kwargs,
+            }
+            can_auto_convert = (
+                not is_offline_mode()  # for obvious reasons
+                # If we are in a CI environment or in a pytest run, we prevent the conversion
+                and not is_env_variable_true("DISABLE_SAFETENSORS_CONVERSION")
+                and not is_remote_code  # converter bot does not work on remote code
+                and subfolder == ""  # converter bot does not work on subfolders
+            )
+
             try:
                 # Load from URL or cache if already cached
-                cached_file_kwargs = {
-                    "cache_dir": cache_dir,
-                    "force_download": force_download,
-                    "proxies": proxies,
-                    "local_files_only": local_files_only,
-                    "token": token,
-                    "user_agent": user_agent,
-                    "revision": revision,
-                    "subfolder": subfolder,
-                    "_raise_exceptions_for_gated_repo": False,
-                    "_raise_exceptions_for_missing_entries": False,
-                    "_commit_hash": commit_hash,
-                }
-                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
-
                 # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                 # result when internet is up, the repo and revision exist, but the file does not.
+                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+                # Try safetensors files first if not already found
                 if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
                     # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                     resolved_archive_file = cached_file(
@@ -591,7 +628,7 @@ def _get_resolved_checkpoint_files(
                     if resolved_archive_file is not None:
                         is_sharded = True
                     elif use_safetensors:
-                        if revision == "main" and
+                        if revision == "main" and can_auto_convert:
                             resolved_archive_file, revision, is_sharded = auto_conversion(
                                 pretrained_model_name_or_path, **cached_file_kwargs
                             )
@@ -608,6 +645,8 @@ def _get_resolved_checkpoint_files(
                         resolved_archive_file = cached_file(
                             pretrained_model_name_or_path, filename, **cached_file_kwargs
                         )
+
+                # Then try `.bin` files
                 if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
                     # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                     resolved_archive_file = cached_file(
@@ -617,67 +656,38 @@ def _get_resolved_checkpoint_files(
                     )
                     if resolved_archive_file is not None:
                         is_sharded = True
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        **has_file_kwargs,
-                    }
-                    if (
-                        not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
-                        and not is_remote_code
-                    ):
-                        Thread(
-                            target=auto_conversion,
-                            args=(pretrained_model_name_or_path,),
-                            kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
-                            name="Thread-auto_conversion",
-                        ).start()
+
+                # If we have a match, but it's `.bin` format, try to launch safetensors conversion for next time
+                if resolved_archive_file is not None:
+                    safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
+                    if (
+                        filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
+                        and not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
+                        and can_auto_convert
+                    ):
+                        Thread(
+                            target=auto_conversion,
+                            args=(pretrained_model_name_or_path,),
+                            kwargs={"ignore_errors_during_conversion": False, **cached_file_kwargs},
+                            name="Thread-auto_conversion",
+                        ).start()
+
+                # If no match, raise appropriare errors
+                else:
+                    # Otherwise, no PyTorch file was found
+                    if variant is not None and has_file(
+                        pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
+                    ):
+                        raise OSError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
+                            f" {variant}. Use `variant=None` to load this model from those weights."
+                        )
                     else:
-
-
-                        "
-
-                        "token": token,
-                        "cache_dir": cache_dir,
-                        "local_files_only": local_files_only,
-                    }
-                    if variant is not None and has_file(
-                        pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
-                    ):
-                        raise OSError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named"
-                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
-                            f" {variant}. Use `variant=None` to load this model from those weights."
-                        )
-                    else:
-                        raise OSError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named"
-                            f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
-                        )
+                        raise OSError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
+                        )

         except OSError:
             # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
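The `can_auto_convert` flag added above gates the background safetensors auto-conversion on a few environment conditions (offline mode, an opt-out environment variable, remote code, subfolders). A rough sketch of that kind of gating, using a simplified inline environment check in place of the library's `is_env_variable_true` helper:

import os


def can_auto_convert_sketch(is_offline: bool, is_remote_code: bool, subfolder: str) -> bool:
    # simplified stand-in for is_env_variable_true("DISABLE_SAFETENSORS_CONVERSION")
    disabled = os.environ.get("DISABLE_SAFETENSORS_CONVERSION", "").lower() in {"1", "true", "yes"}
    return not is_offline and not disabled and not is_remote_code and subfolder == ""


print(can_auto_convert_sketch(is_offline=False, is_remote_code=False, subfolder=""))  # True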
@@ -922,7 +932,7 @@ class ModuleUtilsMixin:
         # Provided a padding mask of dimensions [batch_size, seq_length]
         # - if the model is a decoder, apply a causal mask in addition to the padding mask
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if self.config
+        if getattr(self.config, "is_decoder", None):
             extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                 input_shape, attention_mask
             )
@@ -1095,83 +1105,67 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         - **can_record_outputs** (dict):
     """

-
-
-    main_input_name = "input_ids"
-    model_tags = None
-
-    _checkpoint_conversion_mapping = {}  # used for BC support in VLMs, not meant to be used by new models
-
+    # General model properties
+    config_class: type[PreTrainedConfig] | None = None
     _auto_class = None
-
-
-
-    _keep_in_fp32_modules = None
-    # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
-    # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
-    _keep_in_fp32_modules_strict = None
-
-    dtype_plan: dict[str, torch.dtype] | None = None
-
-    # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
-    # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
-    _keys_to_ignore_on_load_missing = None
-    # a list of `re` patterns of `state_dict` keys that should be removed from the list of
-    # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
-    # warnings.
-    _keys_to_ignore_on_load_unexpected = None
-    # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
-    # trained, but which are either deterministic or tied variables)
-    _keys_to_ignore_on_save = None
-    # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
-    _tied_weights_keys = None
-
-    supports_gradient_checkpointing = False
-    _is_stateful = False
-
-    # Flash Attention support
-    _supports_flash_attn = False
-
-    # SDPA support
-    _supports_sdpa = False
-
-    # Flex Attention support
-    _supports_flex_attn = False
-
-    _can_compile_fullgraph = False
-
-    # A tensor parallel plan to be applied to the model when TP is enabled. For
-    # top-level models, this attribute is currently defined in respective model
-    # code. For base models, this attribute comes from
-    # `config.base_model_tp_plan` during `__init__`.
-    # It should identify the layers exactly: if you want to TP model.language_model.layers.fc1
-    # by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"}
-    # for example.
-    _tp_plan = None
-
-    # tensor parallel degree to which model is sharded to.
-    _tp_size = None
-
-    # A pipeline parallel plan specifying the layers which may not be present
-    # on all ranks when PP is enabled. For top-level models, this attribute is
-    # currently defined in respective model code. For base models, this
-    # attribute comes from `config.base_model_pp_plan` during `post_init`.
-    #
-    # The variable names for the inputs and outputs of the specified layers can
-    # be indexed using the `PipelineParallel` enum as follows:
-    # - `_pp_plan["layers"][PipelineParallel.inputs]`
-    # - `_pp_plan["layers"][PipelineParallel.outputs]`
-    _pp_plan = None
+    base_model_prefix: str = ""
+    _is_stateful: bool = False
+    model_tags: list[str] | None = None

+    # Input-related properties
+    main_input_name: str = "input_ids"
+    # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
+    # Possible values are: text, image, video, audio and time
+    input_modalities: str | list[str] = "text"
+
+    # Device-map related properties
+    _no_split_modules: set[str] | list[str] | None = None
+    _skip_keys_device_placement: str | list[str] | None = None
+
+    # Specific dtype upcasting
+    # `_keep_in_fp32_modules` will upcast to fp32 only if the requested dtype is fp16
+    # `_keep_in_fp32_modules_strict` will upcast to fp32 independently if the requested dtype is fp16 or bf16
+    _keep_in_fp32_modules: set[str] | list[str] | None = None
+    _keep_in_fp32_modules_strict: set[str] | list[str] | None = None
+
+    # Loading-specific properties
+    # A dictionary `{"target": "source"}` of checkpoint keys that are potentially tied to one another
+    _tied_weights_keys: dict[str, str] = None
+    # Used for BC support in VLMs, not meant to be used by new models
+    _checkpoint_conversion_mapping: dict[str, str] = {}
+    # A list of `re` patterns describing keys to ignore if they are missing from checkpoints to avoid warnings
+    _keys_to_ignore_on_load_missing: list[str] | None = None
+    # A list of `re` patterns describing keys to ignore if they are unexpected in the checkpoints to avoid warnings
+    _keys_to_ignore_on_load_unexpected: list[str] | None = None
+    # A list of keys to ignore when saving the model
+    _keys_to_ignore_on_save: list[str] | None = None
+
+    # Attention interfaces support properties
+    _supports_sdpa: bool = False
+    _supports_flash_attn: bool = False
+    _supports_flex_attn: bool = False
+
+    # Tensor-parallelism-related properties
+    # A tensor parallel plan of the form `{"model.layer.mlp.param": "colwise"}` to be applied to the model when TP is enabled.
+    # For top-level models, this attribute is currently defined in respective model code. For base models, this attribute comes
+    # from `config.base_model_tp_plan` during `post_init`.
+    _tp_plan: dict[str, str] = None
+    # Tensor parallel degree to which model is sharded to
+    _tp_size = None
+    # A pipeline parallel plan specifying the layers which may not be present on all ranks when PP is enabled. For top-level
+    # models, this attribute is currently defined in respective model code. For base models, it comes from
+    # `config.base_model_pp_plan` during `post_init`.
+    _pp_plan: dict[str, PipelineParallel] | None = None
+
+    # Advanced functionalities support
+    supports_gradient_checkpointing: bool = False
+    _can_compile_fullgraph: bool = False
     # This flag signal that the model can be used as an efficient backend in TGI and vLLM
     # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
     # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
-    _supports_attention_backend = False
-
-
-    # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
-    # Possible values are: text, image, video, audio and time
-    input_modalities: str | list[str] = "text"  # most models are text
+    _supports_attention_backend: bool = False
+    # A mapping describing what outputs can be captured by `check_model_inputs` decorator during the forward pass
+    _can_record_outputs: dict | None = None

     @property
     @torch._dynamo.allow_in_graph
@@ -1256,6 +1250,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
         )
         self.config = config
+        self.name_or_path = config.name_or_path

         # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
         # setting it recursively)
@@ -1281,38 +1276,33 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             loss_type = None
         self.loss_type = loss_type

-        self.name_or_path = config.name_or_path
-        self.warnings_issued = {}
-        # Overwrite the class attribute to make it an instance attribute, so models like
-        # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
-        # when a different component (e.g. language_model) is used.
-        self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
-        self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
-        self.dtype_plan = {}
-
-        if isinstance(self._keep_in_fp32_modules, list):
-            self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
-        if isinstance(self._keep_in_fp32_modules_strict, list):
-            self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
-
-        self._no_split_modules = self._no_split_modules or []
         _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs  # added for executorch support only

     def post_init(self):
         """
         A method executed at the end of each Transformer model initialization, to execute code that needs the model's
         modules properly initialized (such as weight initialization).
+        It is also used to obtain all correct static properties (parallelism plans, tied_weights_keys, _keep_in_fp32_modules, etc)
+        correctly in the case of composite models (that is, the top level model should know about those properties from its children).
         """
         # Attach the different parallel plans and tied weight keys to the top-most model, so that everything is
         # easily available
         self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {}
-        # Current submodel should register its tied weights
-        self.all_tied_weights_keys = self.get_expanded_tied_weights_keys(all_submodels=False)
         # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
         if self.base_model is self:
             self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
             self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
             self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
+        # Current submodel should register its tied weights
+        self.all_tied_weights_keys = self.get_expanded_tied_weights_keys(all_submodels=False)
+        # Current submodel should register its `_keep_in_fp32_modules`
+        self._keep_in_fp32_modules = set(self._keep_in_fp32_modules or [])
+        self._keep_in_fp32_modules_strict = set(self._keep_in_fp32_modules_strict or [])
+        # Current submodel must register its `_no_split_modules` as well
+        self._no_split_modules = set(self._no_split_modules or [])
+
+        # Iterate over children only: as the final model is created, this is enough to gather the properties from all submodels.
+        # This works because the way the `__init__` and `post_init` are called on all submodules is depth-first in the graph
         for name, module in self.named_children():
             # Parallel plans
             if plan := getattr(module, "_ep_plan", None):
@@ -1324,6 +1314,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             # Always attach the keys of the children (if the children's config says to NOT tie, then it's empty)
             if tied_keys := getattr(module, "all_tied_weights_keys", None):
                 self.all_tied_weights_keys.update({f"{name}.{k}": f"{name}.{v}" for k, v in tied_keys.copy().items()})
+            # Record keep_in_fp_32 modules from the children as well
+            if keep_fp32 := getattr(module, "_keep_in_fp32_modules", None):
+                self._keep_in_fp32_modules.update(keep_fp32)
+            if keep_fp32_strict := getattr(module, "_keep_in_fp32_modules_strict", None):
+                self._keep_in_fp32_modules_strict.update(keep_fp32_strict)
+            # Record `_no_split_modules` from the children
+            if no_split := getattr(module, "_no_split_modules", None):
+                self._no_split_modules.update(no_split)

         # Maybe initialize the weights and tie the keys
         self.init_weights()
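The `post_init` changes above make a composite model collect `_keep_in_fp32_modules`, `_keep_in_fp32_modules_strict` and `_no_split_modules` from its direct children; since submodules are constructed and post-initialized before their parent, one level of lookup is enough. A framework-free sketch of that depth-first aggregation (simplified, hypothetical names — not the real `PreTrainedModel` logic):

class Submodel:
    def __init__(self, no_split=(), children=()):
        self._no_split_modules = set(no_split)
        # children were built first, so their own sets are already aggregated
        for child in children:
            self._no_split_modules.update(child._no_split_modules)


vision = Submodel(no_split={"VisionLayer"})
text = Submodel(no_split={"DecoderLayer"})
composite = Submodel(children=[vision, text])
print(composite._no_split_modules)  # contains both 'VisionLayer' and 'DecoderLayer'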
@@ -1842,7 +1840,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         )

         # preload flash attention here to allow compile with fullgraph
-        if
+        if is_flash_attention_requested(requested_attention_implementation=applicable_attn_implementation):
             lazy_import_flash_attention(applicable_attn_implementation)

         return applicable_attn_implementation
@@ -1919,15 +1917,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
         opening the file, but avoids maintaining yet another property flag.
         """
-
-
+        class_module = sys.modules[cls.__module__]
+        # This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
+        if not hasattr(class_module, "__file__"):
+            return False
+        class_file = class_module.__file__
+        with open(class_file, "r", encoding="utf-8") as f:
            code = f.read()
         # heuristic -> if we find those patterns, the model uses the correct interface
         if re.search(r"class \w+Attention\(nn.Module\)", code):
-            return (
-                "eager_attention_forward" in code
-                and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code
-            )
+            return "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS.get_interface(" in code
         else:
             # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
             return True
@@ -1937,8 +1936,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
         opening the file, but avoids maintaining yet another property flag.
         """
-
-
+        class_module = sys.modules[cls.__module__]
+        # This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
+        if not hasattr(class_module, "__file__"):
+            return False
+        class_file = class_module.__file__
+        with open(class_file, "r", encoding="utf-8") as f:
            code = f.read()
         # heuristic -> if we the use_experts_implementation decorator is used, then we can set it
         return "@use_experts_implementation" in code
@@ -2404,7 +2407,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         tied_mapping = self._tied_weights_keys
         # If the config does not specify any tying, return empty dict
-
+        # NOTE: not all modules have `tie_word_embeddings` attr, for example vision-only
+        # modules do not have any word embeddings!
+        tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False)
+        if not tie_word_embeddings:
             return {}
         # If None, return empty dict
         elif tied_mapping is None:
@@ -2542,35 +2548,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
             output_embeddings.out_features = input_embeddings.num_embeddings

-    def _get_no_split_modules(self, device_map: str):
-        """
-        Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
-        get the underlying `_no_split_modules`.
-
-        Args:
-            device_map (`str`):
-                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
-
-        Returns:
-            `list[str]`: List of modules that should not be split
-        """
-        _no_split_modules = set()
-        modules_to_check = [self]
-        while len(modules_to_check) > 0:
-            module = modules_to_check.pop(-1)
-            # if the module does not appear in _no_split_modules, we also check the children
-            if module.__class__.__name__ not in _no_split_modules:
-                if isinstance(module, PreTrainedModel):
-                    if module._no_split_modules is None:
-                        raise ValueError(
-                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
-                            "class needs to implement the `_no_split_modules` attribute."
-                        )
-                    else:
-                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
-            modules_to_check += list(module.children())
-        return list(_no_split_modules)
-
     def resize_token_embeddings(
         self,
         new_num_tokens: int | None = None,
@@ -2654,10 +2631,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         new_num_tokens = new_embeddings.weight.shape[0]

         # if word embeddings are not tied, make sure that lm head is resized as well
-        if (
-            self.get_output_embeddings() is not None
-            and not self.config.get_text_config(decoder=True).tie_word_embeddings
-        ):
+        if self.get_output_embeddings() is not None:
             old_lm_head = self.get_output_embeddings()
             if isinstance(old_lm_head, torch.nn.Embedding):
                 new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
@@ -3038,15 +3012,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

     def init_weights(self):
         """
-
+        Initialize and tie the weights if needed. If using a custom `PreTrainedModel`, you need to implement any
         initialization logic in `_init_weights`.
         """
         # If we are initializing on meta device, there is no point in trying to run inits
         if get_torch_context_manager_or_global_device() != torch.device("meta"):
             # Initialize weights
             self.initialize_weights()
-
-
+        # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
+        self.tie_weights(recompute_mapping=False)

     def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
         """
@@ -3063,7 +3037,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

         if gradient_checkpointing_kwargs is None:
-            gradient_checkpointing_kwargs = {"use_reentrant":
+            gradient_checkpointing_kwargs = {"use_reentrant": False}

         gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

@@ -3316,16 +3290,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             if ignore_key in state_dict:
                 del state_dict[ignore_key]

-        # If model was sharded
-        # therefore we replace them with DTensors that are equivalently sharded
+        # If model was sharded with TP, gather full tensors for saving
         if self._tp_size is not None:
-            state_dict =
+            state_dict = gather_state_dict_for_save(state_dict, self._tp_plan, self._device_mesh, self._tp_size)

         # Remove tied weights as safetensors do not handle them
         state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)

         # Revert all renaming and/or weight operations
-        if save_original_format:
+        if save_original_format and not _hf_peft_config_loaded:
             state_dict = revert_weight_conversion(model_to_save, state_dict)

         # Shard the model if it is too big.
@@ -3377,13 +3350,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             # Get the tensor, and remove it from state_dict to avoid keeping the ref
             tensor = state_dict.pop(tensor_name)

-            # In case of TP, get the full parameter back
-            if _is_dtensor_available and isinstance(tensor, DTensor):
-                tensor = tensor.full_tensor()
-            # to get the correctly ordered tensor we need to repack if packed
-            if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
-                tensor = repack_weights(tensor, -1, self._tp_size, 2)
-
             # If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
             # but it would otherwise not be contained in the saved shard if we were to simply move the file
             # or something
@@ -3541,10 +3507,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 " desired `dtype` by passing the correct `dtype` argument."
             )

-        if getattr(self, "is_loaded_in_8bit", False):
+        if getattr(self, "is_loaded_in_8bit", False) and not is_bitsandbytes_available("0.48"):
             raise ValueError(
-                "
-                " model has already been set to the correct devices and casted to the correct `dtype`."
+                "You need to install `pip install bitsandbytes>=0.48.0` if you want to move a 8-bit model across devices using to()."
             )
         elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
             if dtype_present_in_args:
@@ -3577,7 +3542,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     @classmethod
     def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
         # Need to instantiate with correct dtype
-        init_contexts = [local_torch_dtype(dtype, cls.__name__)]
+        init_contexts = [local_torch_dtype(dtype, cls.__name__), init.no_tie_weights()]
         if is_deepspeed_zero3_enabled():
             import deepspeed

@@ -3598,7 +3563,31 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         return init_contexts

-    def
+    def _get_dtype_plan(self, dtype: torch.dtype) -> dict:
+        """Create the dtype_plan describing modules/parameters that should use the `keep_in_fp32` flag."""
+        dtype_plan = {}
+
+        # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
+        # in case of force loading a model that should stay in bf16 in fp16
+        # See https://github.com/huggingface/transformers/issues/20287 for details.
+        if self._keep_in_fp32_modules is not None and dtype == torch.float16:
+            dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
+
+        # The _keep_in_fp32_modules_strict was introduced to always force upcast to fp32, for both fp16 and bf16
+        if self._keep_in_fp32_modules_strict is not None and dtype in (torch.float16, torch.bfloat16):
+            dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
+
+        return dtype_plan
+
+    def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None):
+        """
+        Set whether or not to use the `kernels` library to kernelize some layers of the model.
+        Args:
+            use_kernels (`bool`):
+                Whether or not to use the `kernels` library to kernelize some layers of the model.
+            kernel_config (`KernelConfig`, *optional*):
+                The kernel configuration to use to kernelize the model. If `None`, the default kernel mapping will be used.
+        """
         if use_kernels:
             if not is_kernels_available():
                 raise ValueError(
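The `_get_dtype_plan` helper added above maps `keep_in_fp32` module names to `torch.float32` depending on the requested dtype: the plain set only applies when loading in fp16, while the strict set applies to both fp16 and bf16. A standalone sketch of the same decision logic (illustrative function, not the method itself):

import torch


def dtype_plan_sketch(keep_fp32, keep_fp32_strict, dtype):
    plan = {}
    # plain set: only protects against fp16 precision loss
    if keep_fp32 and dtype == torch.float16:
        plan.update(dict.fromkeys(keep_fp32, torch.float32))
    # strict set: also protects against bf16
    if keep_fp32_strict and dtype in (torch.float16, torch.bfloat16):
        plan.update(dict.fromkeys(keep_fp32_strict, torch.float32))
    return plan


print(dtype_plan_sketch({"norm"}, {"router"}, torch.bfloat16))  # {'router': torch.float32}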
@@ -3641,7 +3630,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3641
3630
|
local_files_only: bool = False,
|
|
3642
3631
|
token: str | bool | None = None,
|
|
3643
3632
|
revision: str = "main",
|
|
3644
|
-
use_safetensors: bool | None =
|
|
3633
|
+
use_safetensors: bool | None = None,
|
|
3645
3634
|
weights_only: bool = True,
|
|
3646
3635
|
**kwargs,
|
|
3647
3636
|
) -> SpecificPreTrainedModelType:
|
|
@@ -4040,6 +4029,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
4040
4029
|
use_kernels=use_kernels,
|
|
4041
4030
|
)
|
|
4042
4031
|
|
|
4032
|
+
# Create the dtype_plan to potentially use the `keep_in_fp32` flags (this needs to be called on the already
|
|
4033
|
+
# instantiated model, as the flags can be modified by instances sometimes)
|
|
4034
|
+
dtype_plan = model._get_dtype_plan(dtype)
|
|
4035
|
+
|
|
4043
4036
|
# Obtain the weight conversion mapping for this model if any are registered
|
|
4044
4037
|
weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
|
|
4045
4038
|
|
|
@@ -4051,29 +4044,30 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
4051
4044
|
device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
|
|
4052
4045
|
|
|
4053
4046
|
# Finalize model weight initialization
|
|
4054
|
-
|
|
4055
|
-
|
|
4056
|
-
state_dict,
|
|
4057
|
-
checkpoint_files,
|
|
4058
|
-
pretrained_model_name_or_path,
|
|
4047
|
+
load_config = LoadStateDictConfig(
|
|
4048
|
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
|
4059
4049
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
|
4060
4050
|
sharded_metadata=sharded_metadata,
|
|
4061
4051
|
device_map=device_map,
|
|
4062
4052
|
disk_offload_folder=offload_folder,
|
|
4063
4053
|
offload_buffers=offload_buffers,
|
|
4064
4054
|
dtype=dtype,
|
|
4055
|
+
dtype_plan=dtype_plan,
|
|
4065
4056
|
hf_quantizer=hf_quantizer,
|
|
4066
4057
|
device_mesh=device_mesh,
|
|
4067
4058
|
weights_only=weights_only,
|
|
4068
4059
|
weight_mapping=weight_conversions,
|
|
4060
|
+
use_safetensors=use_safetensors,
|
|
4061
|
+
download_kwargs=download_kwargs,
|
|
4069
4062
|
)
|
|
4070
|
-
|
|
4063
|
+
loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
|
|
4064
|
+
loading_info = cls._finalize_model_loading(model, load_config, loading_info)
|
|
4071
4065
|
model.eval() # Set model in evaluation mode to deactivate Dropout modules by default
|
|
4072
4066
|
model.set_use_kernels(use_kernels, kernel_config)
|
|
4073
4067
|
|
|
4074
4068
|
# If it is a model with generation capabilities, attempt to load generation files (generation config,
|
|
4075
4069
|
# custom generate function)
|
|
4076
|
-
if model.can_generate() and hasattr(model, "adjust_generation_fn"):
|
|
4070
|
+
if model.can_generate() and hasattr(model, "adjust_generation_fn") and not gguf_file:
|
|
4077
4071
|
model.adjust_generation_fn(
|
|
4078
4072
|
generation_config,
|
|
4079
4073
|
from_auto_class,
|
|
@@ -4086,7 +4080,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
4086
4080
|
|
|
4087
4081
|
# If the device_map has more than 1 device: dispatch model with hooks on all devices
|
|
4088
4082
|
if device_map is not None and len(set(device_map.values())) > 1:
|
|
4089
|
-
accelerate_dispatch(model, hf_quantizer, device_map, offload_folder,
|
|
4083
|
+
accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, disk_offload_index, offload_buffers)
|
|
4090
4084
|
|
|
4091
4085
|
if hf_quantizer is not None:
|
|
4092
4086
|
model.hf_quantizer = hf_quantizer
|
|
@@ -4095,44 +4089,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
4095
4089
|
) # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
|
|
4096
4090
|
|
|
4097
4091
|
if _adapter_model_path is not None:
|
|
4098
|
-
|
|
4099
|
-
|
|
4092
|
+
if token is not None:
|
|
4093
|
+
adapter_kwargs["token"] = token
|
|
4094
|
+
loading_info = model.load_adapter(
|
|
4100
4095
|
_adapter_model_path,
|
|
4101
4096
|
adapter_name=adapter_name,
|
|
4102
|
-
|
|
4097
|
+
load_config=load_config,
|
|
4103
4098
|
adapter_kwargs=adapter_kwargs,
|
|
4104
4099
|
)
|
|
4105
4100
|
|
|
4106
4101
|
if output_loading_info:
|
|
4107
|
-
|
|
4108
|
-
"missing_keys": missing_keys,
|
|
4109
|
-
"unexpected_keys": unexpected_keys,
|
|
4110
|
-
"mismatched_keys": mismatched_keys,
|
|
4111
|
-
"error_msgs": error_msgs,
|
|
4112
|
-
}
|
|
4113
|
-
return model, loading_info
|
|
4102
|
+
return model, loading_info.to_dict()
|
|
4114
4103
|
return model
|
|
4115
4104
|
|
|
4116
|
-
@
|
|
4105
|
+
@staticmethod
|
|
4117
4106
|
def _load_pretrained_model(
|
|
4118
|
-
cls,
|
|
4119
4107
|
model: "PreTrainedModel",
|
|
4120
4108
|
state_dict: dict | None,
|
|
4121
4109
|
checkpoint_files: list[str] | None,
|
|
4122
|
-
|
|
4123
|
-
|
|
4124
|
-
|
|
4125
|
-
|
|
4126
|
-
|
|
4127
|
-
offload_buffers: bool = False,
|
|
4128
|
-
dtype: torch.dtype | None = None,
|
|
4129
|
-
hf_quantizer: HfQuantizer | None = None,
|
|
4130
|
-
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
|
4131
|
-
weights_only: bool = True,
|
|
4132
|
-
weight_mapping: Sequence[WeightConverter | WeightRenaming] | None = None,
|
|
4133
|
-
):
|
|
4134
|
-
is_quantized = hf_quantizer is not None
|
|
4135
|
-
is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
|
4110
|
+
load_config: LoadStateDictConfig,
|
|
4111
|
+
) -> tuple[LoadStateDictInfo, dict]:
|
|
4112
|
+
"""Perform the actual loading of some checkpoints into a `model`, by reading them from disk and dispatching them accordingly."""
|
|
4113
|
+
is_quantized = load_config.is_quantized
|
|
4114
|
+
is_hqq_or_quark = is_quantized and load_config.hf_quantizer.quantization_config.quant_method in {
|
|
4136
4115
|
QuantizationMethod.HQQ,
|
|
4137
4116
|
QuantizationMethod.QUARK,
|
|
4138
4117
|
}
|
|
@@ -4146,21 +4125,21 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
4146
4125
|
# This offload index if for params explicitly on the "disk" in the device_map
|
|
4147
4126
|
disk_offload_index = None
|
|
4148
4127
|
# Prepare parameters offloading if needed
|
|
4149
|
-
if device_map is not None and "disk" in device_map.values():
|
|
4128
|
+
if load_config.device_map is not None and "disk" in load_config.device_map.values():
|
|
4150
4129
|
disk_offload_index = accelerate_disk_offload(
|
|
4151
4130
|
model,
|
|
4152
|
-
disk_offload_folder,
|
|
4131
|
+
load_config.disk_offload_folder,
|
|
4153
4132
|
checkpoint_files,
|
|
4154
|
-
device_map,
|
|
4155
|
-
sharded_metadata,
|
|
4156
|
-
dtype,
|
|
4157
|
-
weight_mapping,
|
|
4133
|
+
load_config.device_map,
|
|
4134
|
+
load_config.sharded_metadata,
|
|
4135
|
+
load_config.dtype,
|
|
4136
|
+
load_config.weight_mapping,
|
|
4158
4137
|
)
|
|
4159
4138
|
|
|
4160
4139
|
# Warmup cuda to load the weights much faster on devices
|
|
4161
|
-
if device_map is not None and not is_hqq_or_quark:
|
|
4162
|
-
expanded_device_map = expand_device_map(device_map, expected_keys)
|
|
4163
|
-
caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
|
|
4140
|
+
if load_config.device_map is not None and not is_hqq_or_quark:
|
|
4141
|
+
expanded_device_map = expand_device_map(load_config.device_map, expected_keys)
|
|
4142
|
+
caching_allocator_warmup(model, expanded_device_map, load_config.hf_quantizer)
|
|
4164
4143
|
|
|
4165
4144
|
error_msgs = []
|
|
4166
4145
|
|
|
@@ -4168,24 +4147,30 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             if state_dict is None:
                 merged_state_dict = {}
                 for ckpt_file in checkpoint_files:
-                    merged_state_dict.update(
+                    merged_state_dict.update(
+                        load_state_dict(ckpt_file, map_location="cpu", weights_only=load_config.weights_only)
+                    )
                 state_dict = merged_state_dict
-            error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict)
+            error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict, load_config)
             # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
-
+            loading_info = LoadStateDictInfo(
+                missing_keys=missing_keys,
+                error_msgs=error_msgs,
+                unexpected_keys=set(),
+                mismatched_keys=set(),
+                conversion_errors={},
+            )
         else:
             all_pointer = set()
-
-
+            if state_dict is not None:
+                merged_state_dict = state_dict
+            elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None:
                 merged_state_dict = {}
                 for file in checkpoint_files:
                     file_pointer = safe_open(file, framework="pt", device="cpu")
                     all_pointer.add(file_pointer)
                     for k in file_pointer.keys():
                         merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
-            # User passed an explicit state_dict
-            elif state_dict is not None:
-                merged_state_dict = state_dict
             # Checkpoints are .bin
             elif checkpoint_files is not None:
                 merged_state_dict = {}
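
The `.safetensors` branch above deliberately avoids materializing tensors: `safe_open(...).get_slice(key)` returns a lazy slice handle, and data is only read when the slice is indexed. A minimal standalone illustration of that pattern (the file name is a placeholder):

```python
# Lazy-loading pattern used above: open the file, collect slice handles, materialize later.
# "model.safetensors" is a placeholder path for illustration.
from safetensors import safe_open

merged_state_dict = {}
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        merged_state_dict[name] = f.get_slice(name)  # no tensor data read yet

# Materialize a single tensor only when it is actually needed
some_key = next(iter(merged_state_dict))
tensor = merged_state_dict[some_key][:]  # indexing the slice reads the data from disk
print(some_key, tuple(tensor.shape))
```
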
@@ -4194,58 +4179,58 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             else:
                 raise ValueError("Neither a state dict nor checkpoint files were found.")

-
-
-
-
-
-
-                hf_quantizer=hf_quantizer,
-                dtype=dtype,
-                device_map=device_map,
-                dtype_plan=model.dtype_plan,
-                device_mesh=device_mesh,
-                disk_offload_index=disk_offload_index,
-                disk_offload_folder=disk_offload_folder,
-                offload_buffers=offload_buffers,
-            )
+            loading_info, disk_offload_index = convert_and_load_state_dict_in_model(
+                model=model,
+                state_dict=merged_state_dict,
+                load_config=load_config,
+                tp_plan=model._tp_plan,
+                disk_offload_index=disk_offload_index,
             )

             # finally close all opened file pointers
             for k in all_pointer:
                 k.__exit__(None, None, None)

-
-        model.mark_tied_weights_as_initialized()
-
-        # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
-        # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
-        missing_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
-        model._move_missing_keys_from_meta_to_device(missing_and_mismatched, device_map, device_mesh, hf_quantizer)
-
-        # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
-        model._initialize_missing_keys(is_quantized)
+        return loading_info, disk_offload_index

-
-
+    @staticmethod
+    def _finalize_model_loading(
+        model, load_config: LoadStateDictConfig, loading_info: LoadStateDictInfo
+    ) -> LoadStateDictInfo:
+        """Perform all post processing operations after having loaded some checkpoints into a model, such as moving
+        missing keys from meta device to their expected device, reinitializing missing weights according to proper
+        distributions, tying the weights and logging the loading report."""
+        try:
+            # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
+            model.mark_tied_weights_as_initialized()
+
+            # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
+            # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
+            model._move_missing_keys_from_meta_to_device(
+                loading_info.missing_and_mismatched(),
+                load_config.device_map,
+                load_config.device_mesh,
+                load_config.hf_quantizer,
+            )

-
-
+            # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
+            model._initialize_missing_keys(load_config.is_quantized)

-
-                model=
-
-
-
-
-
-
-
-
-
+            # Tie the weights
+            model.tie_weights(missing_keys=loading_info.missing_keys, recompute_mapping=False)
+
+            # Adjust missing and unexpected keys
+            model._adjust_missing_and_unexpected_keys(loading_info)
+        finally:
+            log_state_dict_report(
+                model=model,
+                pretrained_model_name_or_path=load_config.pretrained_model_name_or_path,
+                ignore_mismatched_sizes=load_config.ignore_mismatched_sizes,
+                loading_info=loading_info,
+                logger=logger,
+            )

-        return
+        return loading_info

     def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
         module_keys = {".".join(key.split(".")[:-1]) for key in names}
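
The post-processing in `_finalize_model_loading` now flows through a `LoadStateDictInfo` object instead of loose `missing_keys`/`mismatched_keys` variables. The sketch below mirrors only the fields constructed in this diff plus the `missing_and_mismatched()` helper it calls (the removed code computed the same union inline); the real class may carry more state.

```python
# Illustrative only: a minimal container mirroring the LoadStateDictInfo fields this diff
# constructs and the missing_and_mismatched() helper it calls. Not the library's definition.
from dataclasses import dataclass, field


@dataclass
class LoadStateDictInfoSketch:
    missing_keys: set[str] = field(default_factory=set)
    unexpected_keys: set[str] = field(default_factory=set)
    mismatched_keys: set[tuple] = field(default_factory=set)  # e.g. (key, checkpoint_shape, model_shape)
    error_msgs: list[str] = field(default_factory=list)
    conversion_errors: dict[str, str] = field(default_factory=dict)

    def missing_and_mismatched(self) -> set[str]:
        # Matches the removed inline expression: missing_keys | {k[0] for k in mismatched_keys}
        return self.missing_keys | {k[0] for k in self.mismatched_keys}
```
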
@@ -4314,15 +4299,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
         # attention_mask or not. In this case, we should still show a warning because this is a rare case.
+        # NOTE: `sep_token_id` is not used in all models and it can be absent in the config
+        sep_token_id = getattr(self.config, "sep_token_id", None)
         if (
             (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
             or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
-            or (
+            or (sep_token_id is not None and sep_token_id == self.config.pad_token_id)
         ):
             warn_string += (
                 f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
                 f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
-                f"or the `sep_token_id` ({
+                f"or the `sep_token_id` ({sep_token_id}), and your input is not padded."
             )

         logger.warning_once(warn_string)
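
The new `getattr(self.config, "sep_token_id", None)` guard exists because not every model config defines `sep_token_id`, so a direct attribute access could raise. A tiny standalone illustration (the config class here is a stand-in, not a transformers class):

```python
# Illustrative only: defensive attribute access as used in the hunk above.
class DummyConfig:
    pad_token_id = 0
    bos_token_id = 1
    eos_token_id = 2
    # no sep_token_id attribute on purpose


config = DummyConfig()
sep_token_id = getattr(config, "sep_token_id", None)  # None instead of AttributeError
print(sep_token_id is not None and sep_token_id == config.pad_token_id)  # False, the check short-circuits safely
```
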
@@ -4499,11 +4486,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         else:
             self.initialize_weights()

-    def _adjust_missing_and_unexpected_keys(
-        self, missing_keys: set[str], unexpected_keys: set[str]
-    ) -> tuple[set[str], set[str]]:
+    def _adjust_missing_and_unexpected_keys(self, loading_info: LoadStateDictInfo) -> None:
         """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
-        raising unneeded warnings/errors.
+        raising unneeded warnings/errors. This is performed in-place.
         """
         # Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model
         # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
@@ -4521,13 +4506,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         # Clean-up missing keys
         if ignore_missing_regex is not None:
-            missing_keys = {
+            loading_info.missing_keys = {
+                key for key in loading_info.missing_keys if ignore_missing_regex.search(key) is None
+            }

         # Clean-up unexpected keys
         if ignore_unexpected_regex is not None:
-            unexpected_keys = {
-
-
+            loading_info.unexpected_keys = {
+                key for key in loading_info.unexpected_keys if ignore_unexpected_regex.search(key) is None
+            }

     def mark_tied_weights_as_initialized(self):
         """Adds the `_is_hf_initialized` flag on parameters that will be tied, in order to avoid initializing them
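
Both clean-up steps now mutate `loading_info` in place, filtering keys through precompiled "ignore" regexes. A self-contained illustration of the same filtering pattern, with a made-up regex and keys:

```python
# Illustrative only: the in-place clean-up pattern used above, with hypothetical inputs.
import re

ignore_missing_regex = re.compile(r"rotary_emb\.inv_freq$")
missing_keys = {"model.layers.0.rotary_emb.inv_freq", "lm_head.weight"}

# Keys matching the "ignore" pattern are dropped before any warning is emitted
missing_keys = {key for key in missing_keys if ignore_missing_regex.search(key) is None}
print(missing_keys)  # {'lm_head.weight'}
```
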
@@ -4709,7 +4696,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
         ) - torch_accelerator_module.memory_allocated(index)
         byte_count = int(max(0, byte_count - unused_memory))
         # We divide by 2 here as we allocate in fp16
-        _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
+        _ = torch.empty(int(byte_count // 2), dtype=torch.float16, device=device, requires_grad=False)


 class AttentionInterface(GeneralInterface):
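
The warm-up allocates `byte_count // 2` fp16 elements because each element occupies 2 bytes, so the allocation touches roughly `byte_count` bytes; the added `int(...)` only guarantees that `torch.empty` receives an integer size. A small check of that arithmetic (run on CPU here purely so the sketch is runnable; in the library the allocation targets the accelerator's caching allocator):

```python
# Illustrative only: why half as many fp16 elements pre-touch byte_count bytes.
import torch

byte_count = 1_000_000
warmup = torch.empty(int(byte_count // 2), dtype=torch.float16, device="cpu", requires_grad=False)
print(warmup.numel() * warmup.element_size())  # 1000000 bytes allocated (2 bytes per fp16 element)
```
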
@@ -4732,6 +4719,20 @@ class AttentionInterface(GeneralInterface):
         "paged|eager": eager_paged_attention_forward,
     }

+    def get_interface(self, attn_implementation: str, default: Callable) -> Callable:
+        """Return the requested `attn_implementation`. Also strictly check its validity, and raise if invalid."""
+        if attn_implementation is None:
+            logger.warning_once(
+                "You tried to access the `AttentionInterface` with a `config._attn_implementation` set to `None`. This "
+                "is expected if you use an Attention Module as a standalone Module. If this is not the case, something went "
+                "wrong with the dispatch of `config._attn_implementation`"
+            )
+        elif attn_implementation != "eager" and attn_implementation not in self:
+            raise KeyError(
+                f"`{attn_implementation}` is not a valid attention implementation registered in the `AttentionInterface`"
+            )
+        return super().get(attn_implementation, default)
+

 # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
 ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
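
The new `get_interface` both dispatches and validates: `None` only produces a warning, `"eager"` is always allowed to fall through to the default, and any other unregistered name raises `KeyError`. The sketch below reproduces that behaviour on a plain `dict` subclass so it runs without transformers; the registry contents and function names are made up.

```python
# Illustrative only: the dispatch-and-validate behaviour added to AttentionInterface,
# reproduced on a hypothetical dict-based registry.
from typing import Callable, Optional


def eager_attention_sketch(*args, **kwargs):
    return "eager attention (sketch)"


class AttentionRegistrySketch(dict):
    def get_interface(self, attn_implementation: Optional[str], default: Callable) -> Callable:
        if attn_implementation is None:
            print("warning: _attn_implementation is None, falling back to the default")
        elif attn_implementation != "eager" and attn_implementation not in self:
            raise KeyError(f"`{attn_implementation}` is not a registered attention implementation")
        return super().get(attn_implementation, default)


registry = AttentionRegistrySketch({"sdpa": lambda *a, **k: "sdpa (sketch)"})
print(registry.get_interface("sdpa", eager_attention_sketch)())   # sdpa (sketch)
print(registry.get_interface("eager", eager_attention_sketch)())  # eager attention (sketch)
# registry.get_interface("flash_attention_2", eager_attention_sketch)  # would raise KeyError
```
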