transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc3-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registry.
- transformers/__init__.py +27 -27
- transformers/activations.py +1 -1
- transformers/audio_utils.py +32 -33
- transformers/cache_utils.py +32 -139
- transformers/cli/chat.py +3 -3
- transformers/cli/serve.py +2 -2
- transformers/cli/transformers.py +2 -1
- transformers/configuration_utils.py +143 -101
- transformers/conversion_mapping.py +73 -6
- transformers/convert_slow_tokenizer.py +3 -8
- transformers/core_model_loading.py +215 -50
- transformers/data/processors/glue.py +0 -1
- transformers/data/processors/utils.py +0 -1
- transformers/data/processors/xnli.py +0 -1
- transformers/dependency_versions_table.py +5 -5
- transformers/distributed/configuration_utils.py +1 -2
- transformers/dynamic_module_utils.py +23 -23
- transformers/feature_extraction_sequence_utils.py +19 -23
- transformers/feature_extraction_utils.py +63 -31
- transformers/generation/candidate_generator.py +80 -33
- transformers/generation/configuration_utils.py +186 -131
- transformers/generation/continuous_batching/__init__.py +0 -1
- transformers/generation/continuous_batching/cache.py +81 -24
- transformers/generation/continuous_batching/cache_manager.py +155 -45
- transformers/generation/continuous_batching/continuous_api.py +152 -84
- transformers/generation/continuous_batching/requests.py +51 -3
- transformers/generation/continuous_batching/scheduler.py +127 -52
- transformers/generation/logits_process.py +0 -128
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/streamers.py +0 -1
- transformers/generation/utils.py +107 -119
- transformers/generation/watermarking.py +8 -6
- transformers/hf_argparser.py +9 -13
- transformers/hyperparameter_search.py +1 -2
- transformers/image_processing_base.py +11 -21
- transformers/image_processing_utils.py +11 -12
- transformers/image_processing_utils_fast.py +68 -57
- transformers/image_transforms.py +29 -29
- transformers/image_utils.py +30 -32
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +3 -8
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +19 -4
- transformers/integrations/eetq.py +3 -6
- transformers/integrations/fbgemm_fp8.py +2 -3
- transformers/integrations/finegrained_fp8.py +14 -23
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/flex_attention.py +1 -1
- transformers/integrations/fp_quant.py +4 -6
- transformers/integrations/ggml.py +0 -1
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +37 -3
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +9 -16
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/quark.py +2 -4
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +4 -6
- transformers/integrations/vptq.py +3 -5
- transformers/loss/loss_lw_detr.py +356 -0
- transformers/loss/loss_utils.py +2 -0
- transformers/masking_utils.py +47 -51
- transformers/model_debugging_utils.py +4 -5
- transformers/modelcard.py +14 -192
- transformers/modeling_attn_mask_utils.py +19 -19
- transformers/modeling_flash_attention_utils.py +27 -27
- transformers/modeling_gguf_pytorch_utils.py +71 -24
- transformers/modeling_layers.py +21 -22
- transformers/modeling_outputs.py +242 -253
- transformers/modeling_rope_utils.py +110 -113
- transformers/modeling_utils.py +633 -576
- transformers/models/__init__.py +23 -0
- transformers/models/afmoe/configuration_afmoe.py +26 -29
- transformers/models/afmoe/modeling_afmoe.py +37 -49
- transformers/models/afmoe/modular_afmoe.py +21 -31
- transformers/models/aimv2/configuration_aimv2.py +2 -5
- transformers/models/aimv2/modeling_aimv2.py +24 -21
- transformers/models/aimv2/modular_aimv2.py +11 -9
- transformers/models/albert/configuration_albert.py +0 -1
- transformers/models/albert/modeling_albert.py +70 -69
- transformers/models/albert/tokenization_albert.py +1 -4
- transformers/models/align/configuration_align.py +0 -1
- transformers/models/align/modeling_align.py +73 -68
- transformers/models/align/processing_align.py +2 -30
- transformers/models/altclip/configuration_altclip.py +0 -1
- transformers/models/altclip/modeling_altclip.py +83 -80
- transformers/models/altclip/processing_altclip.py +2 -15
- transformers/models/apertus/__init__.py +0 -1
- transformers/models/apertus/configuration_apertus.py +18 -21
- transformers/models/apertus/modeling_apertus.py +35 -36
- transformers/models/apertus/modular_apertus.py +32 -31
- transformers/models/arcee/configuration_arcee.py +20 -23
- transformers/models/arcee/modeling_arcee.py +32 -35
- transformers/models/arcee/modular_arcee.py +20 -23
- transformers/models/aria/configuration_aria.py +20 -23
- transformers/models/aria/image_processing_aria.py +25 -27
- transformers/models/aria/modeling_aria.py +71 -70
- transformers/models/aria/modular_aria.py +85 -88
- transformers/models/aria/processing_aria.py +28 -35
- transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
- transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +6 -8
- transformers/models/audioflamingo3/__init__.py +0 -1
- transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +22 -23
- transformers/models/audioflamingo3/modular_audioflamingo3.py +12 -17
- transformers/models/audioflamingo3/processing_audioflamingo3.py +33 -30
- transformers/models/auto/auto_factory.py +5 -6
- transformers/models/auto/configuration_auto.py +53 -5
- transformers/models/auto/feature_extraction_auto.py +12 -10
- transformers/models/auto/image_processing_auto.py +17 -28
- transformers/models/auto/modeling_auto.py +38 -188
- transformers/models/auto/processing_auto.py +6 -1
- transformers/models/auto/tokenization_auto.py +147 -169
- transformers/models/auto/video_processing_auto.py +12 -10
- transformers/models/autoformer/configuration_autoformer.py +4 -7
- transformers/models/autoformer/modeling_autoformer.py +98 -100
- transformers/models/aya_vision/configuration_aya_vision.py +0 -1
- transformers/models/aya_vision/modeling_aya_vision.py +42 -40
- transformers/models/aya_vision/modular_aya_vision.py +26 -29
- transformers/models/aya_vision/processing_aya_vision.py +25 -53
- transformers/models/bamba/configuration_bamba.py +29 -32
- transformers/models/bamba/modeling_bamba.py +78 -83
- transformers/models/bamba/modular_bamba.py +68 -71
- transformers/models/bark/configuration_bark.py +4 -7
- transformers/models/bark/generation_configuration_bark.py +3 -5
- transformers/models/bark/modeling_bark.py +49 -55
- transformers/models/bark/processing_bark.py +19 -41
- transformers/models/bart/configuration_bart.py +0 -2
- transformers/models/bart/modeling_bart.py +122 -117
- transformers/models/barthez/tokenization_barthez.py +1 -4
- transformers/models/bartpho/tokenization_bartpho.py +6 -7
- transformers/models/beit/configuration_beit.py +0 -11
- transformers/models/beit/image_processing_beit.py +53 -56
- transformers/models/beit/image_processing_beit_fast.py +8 -10
- transformers/models/beit/modeling_beit.py +51 -53
- transformers/models/bert/configuration_bert.py +0 -1
- transformers/models/bert/modeling_bert.py +114 -122
- transformers/models/bert/tokenization_bert.py +2 -4
- transformers/models/bert/tokenization_bert_legacy.py +3 -5
- transformers/models/bert_generation/configuration_bert_generation.py +0 -1
- transformers/models/bert_generation/modeling_bert_generation.py +49 -49
- transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
- transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
- transformers/models/bertweet/tokenization_bertweet.py +1 -3
- transformers/models/big_bird/configuration_big_bird.py +0 -1
- transformers/models/big_bird/modeling_big_bird.py +110 -109
- transformers/models/big_bird/tokenization_big_bird.py +1 -4
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +0 -1
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +116 -111
- transformers/models/biogpt/configuration_biogpt.py +0 -1
- transformers/models/biogpt/modeling_biogpt.py +69 -71
- transformers/models/biogpt/modular_biogpt.py +59 -61
- transformers/models/biogpt/tokenization_biogpt.py +3 -5
- transformers/models/bit/configuration_bit.py +0 -1
- transformers/models/bit/image_processing_bit.py +21 -24
- transformers/models/bit/image_processing_bit_fast.py +0 -1
- transformers/models/bit/modeling_bit.py +14 -12
- transformers/models/bitnet/configuration_bitnet.py +18 -21
- transformers/models/bitnet/modeling_bitnet.py +32 -35
- transformers/models/bitnet/modular_bitnet.py +4 -6
- transformers/models/blenderbot/configuration_blenderbot.py +0 -1
- transformers/models/blenderbot/modeling_blenderbot.py +71 -95
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -8
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +0 -1
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +73 -68
- transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
- transformers/models/blip/configuration_blip.py +0 -1
- transformers/models/blip/image_processing_blip.py +17 -20
- transformers/models/blip/image_processing_blip_fast.py +0 -1
- transformers/models/blip/modeling_blip.py +62 -71
- transformers/models/blip/modeling_blip_text.py +71 -65
- transformers/models/blip/processing_blip.py +5 -36
- transformers/models/blip_2/configuration_blip_2.py +0 -1
- transformers/models/blip_2/modeling_blip_2.py +72 -71
- transformers/models/blip_2/processing_blip_2.py +8 -38
- transformers/models/bloom/configuration_bloom.py +0 -1
- transformers/models/bloom/modeling_bloom.py +71 -103
- transformers/models/blt/configuration_blt.py +71 -74
- transformers/models/blt/modeling_blt.py +235 -78
- transformers/models/blt/modular_blt.py +225 -62
- transformers/models/bridgetower/configuration_bridgetower.py +0 -1
- transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +7 -10
- transformers/models/bridgetower/modeling_bridgetower.py +113 -109
- transformers/models/bridgetower/processing_bridgetower.py +2 -16
- transformers/models/bros/configuration_bros.py +0 -1
- transformers/models/bros/modeling_bros.py +86 -80
- transformers/models/bros/processing_bros.py +2 -12
- transformers/models/byt5/tokenization_byt5.py +4 -6
- transformers/models/camembert/configuration_camembert.py +0 -1
- transformers/models/camembert/modeling_camembert.py +196 -195
- transformers/models/camembert/modular_camembert.py +51 -54
- transformers/models/camembert/tokenization_camembert.py +1 -4
- transformers/models/canine/configuration_canine.py +0 -1
- transformers/models/canine/modeling_canine.py +79 -75
- transformers/models/canine/tokenization_canine.py +2 -1
- transformers/models/chameleon/configuration_chameleon.py +24 -27
- transformers/models/chameleon/image_processing_chameleon.py +21 -24
- transformers/models/chameleon/image_processing_chameleon_fast.py +0 -1
- transformers/models/chameleon/modeling_chameleon.py +62 -60
- transformers/models/chameleon/processing_chameleon.py +16 -41
- transformers/models/chinese_clip/configuration_chinese_clip.py +0 -1
- transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
- transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +71 -69
- transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
- transformers/models/clap/configuration_clap.py +0 -1
- transformers/models/clap/feature_extraction_clap.py +11 -12
- transformers/models/clap/modeling_clap.py +113 -104
- transformers/models/clap/processing_clap.py +2 -15
- transformers/models/clip/configuration_clip.py +0 -1
- transformers/models/clip/image_processing_clip.py +21 -24
- transformers/models/clip/image_processing_clip_fast.py +0 -1
- transformers/models/clip/modeling_clip.py +47 -46
- transformers/models/clip/processing_clip.py +2 -14
- transformers/models/clip/tokenization_clip.py +2 -5
- transformers/models/clipseg/configuration_clipseg.py +0 -1
- transformers/models/clipseg/modeling_clipseg.py +90 -87
- transformers/models/clipseg/processing_clipseg.py +8 -39
- transformers/models/clvp/configuration_clvp.py +1 -3
- transformers/models/clvp/feature_extraction_clvp.py +7 -10
- transformers/models/clvp/modeling_clvp.py +133 -118
- transformers/models/clvp/number_normalizer.py +1 -2
- transformers/models/clvp/processing_clvp.py +3 -20
- transformers/models/clvp/tokenization_clvp.py +0 -1
- transformers/models/code_llama/tokenization_code_llama.py +4 -7
- transformers/models/codegen/configuration_codegen.py +0 -1
- transformers/models/codegen/modeling_codegen.py +61 -52
- transformers/models/codegen/tokenization_codegen.py +5 -6
- transformers/models/cohere/configuration_cohere.py +20 -23
- transformers/models/cohere/modeling_cohere.py +36 -39
- transformers/models/cohere/modular_cohere.py +24 -28
- transformers/models/cohere/tokenization_cohere.py +5 -6
- transformers/models/cohere2/configuration_cohere2.py +21 -24
- transformers/models/cohere2/modeling_cohere2.py +35 -38
- transformers/models/cohere2/modular_cohere2.py +39 -41
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +6 -8
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +35 -33
- transformers/models/cohere2_vision/modular_cohere2_vision.py +21 -23
- transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
- transformers/models/colpali/configuration_colpali.py +0 -1
- transformers/models/colpali/modeling_colpali.py +14 -16
- transformers/models/colpali/modular_colpali.py +11 -51
- transformers/models/colpali/processing_colpali.py +14 -52
- transformers/models/colqwen2/modeling_colqwen2.py +20 -22
- transformers/models/colqwen2/modular_colqwen2.py +29 -68
- transformers/models/colqwen2/processing_colqwen2.py +16 -52
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -2
- transformers/models/conditional_detr/image_processing_conditional_detr.py +64 -66
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +22 -22
- transformers/models/conditional_detr/modeling_conditional_detr.py +82 -81
- transformers/models/conditional_detr/modular_conditional_detr.py +1 -3
- transformers/models/convbert/configuration_convbert.py +0 -1
- transformers/models/convbert/modeling_convbert.py +88 -87
- transformers/models/convbert/tokenization_convbert.py +0 -1
- transformers/models/convnext/configuration_convnext.py +0 -1
- transformers/models/convnext/image_processing_convnext.py +20 -23
- transformers/models/convnext/image_processing_convnext_fast.py +14 -19
- transformers/models/convnext/modeling_convnext.py +5 -8
- transformers/models/convnextv2/configuration_convnextv2.py +0 -1
- transformers/models/convnextv2/modeling_convnextv2.py +5 -8
- transformers/models/cpm/tokenization_cpm.py +6 -7
- transformers/models/cpm/tokenization_cpm_fast.py +3 -5
- transformers/models/cpmant/configuration_cpmant.py +0 -1
- transformers/models/cpmant/modeling_cpmant.py +38 -40
- transformers/models/cpmant/tokenization_cpmant.py +1 -3
- transformers/models/csm/configuration_csm.py +49 -51
- transformers/models/csm/generation_csm.py +31 -35
- transformers/models/csm/modeling_csm.py +81 -82
- transformers/models/csm/modular_csm.py +58 -58
- transformers/models/csm/processing_csm.py +25 -68
- transformers/models/ctrl/configuration_ctrl.py +0 -1
- transformers/models/ctrl/modeling_ctrl.py +52 -43
- transformers/models/ctrl/tokenization_ctrl.py +0 -1
- transformers/models/cvt/configuration_cvt.py +0 -1
- transformers/models/cvt/modeling_cvt.py +18 -16
- transformers/models/cwm/__init__.py +0 -1
- transformers/models/cwm/configuration_cwm.py +3 -5
- transformers/models/cwm/modeling_cwm.py +33 -35
- transformers/models/cwm/modular_cwm.py +10 -12
- transformers/models/d_fine/configuration_d_fine.py +3 -5
- transformers/models/d_fine/modeling_d_fine.py +127 -121
- transformers/models/d_fine/modular_d_fine.py +23 -13
- transformers/models/dab_detr/configuration_dab_detr.py +2 -3
- transformers/models/dab_detr/modeling_dab_detr.py +69 -71
- transformers/models/dac/configuration_dac.py +0 -1
- transformers/models/dac/feature_extraction_dac.py +6 -9
- transformers/models/dac/modeling_dac.py +21 -23
- transformers/models/data2vec/configuration_data2vec_audio.py +0 -1
- transformers/models/data2vec/configuration_data2vec_text.py +0 -1
- transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
- transformers/models/data2vec/modeling_data2vec_audio.py +52 -56
- transformers/models/data2vec/modeling_data2vec_text.py +98 -93
- transformers/models/data2vec/modeling_data2vec_vision.py +41 -42
- transformers/models/data2vec/modular_data2vec_audio.py +6 -1
- transformers/models/data2vec/modular_data2vec_text.py +58 -54
- transformers/models/dbrx/configuration_dbrx.py +27 -20
- transformers/models/dbrx/modeling_dbrx.py +40 -43
- transformers/models/dbrx/modular_dbrx.py +31 -33
- transformers/models/deberta/configuration_deberta.py +0 -1
- transformers/models/deberta/modeling_deberta.py +59 -60
- transformers/models/deberta/tokenization_deberta.py +2 -5
- transformers/models/deberta_v2/configuration_deberta_v2.py +0 -1
- transformers/models/deberta_v2/modeling_deberta_v2.py +65 -65
- transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
- transformers/models/decision_transformer/configuration_decision_transformer.py +0 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +56 -55
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +34 -37
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -37
- transformers/models/deepseek_v2/modular_deepseek_v2.py +44 -44
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +35 -38
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +40 -38
- transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -7
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +2 -3
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +25 -26
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +7 -7
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +40 -36
- transformers/models/deepseek_vl/modular_deepseek_vl.py +14 -43
- transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +3 -5
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +35 -35
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +16 -20
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +42 -38
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +80 -99
- transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -3
- transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +17 -17
- transformers/models/deformable_detr/modeling_deformable_detr.py +67 -68
- transformers/models/deformable_detr/modular_deformable_detr.py +1 -3
- transformers/models/deit/configuration_deit.py +0 -1
- transformers/models/deit/image_processing_deit.py +18 -21
- transformers/models/deit/image_processing_deit_fast.py +0 -1
- transformers/models/deit/modeling_deit.py +16 -18
- transformers/models/depth_anything/configuration_depth_anything.py +2 -4
- transformers/models/depth_anything/modeling_depth_anything.py +5 -8
- transformers/models/depth_pro/configuration_depth_pro.py +0 -1
- transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +6 -8
- transformers/models/depth_pro/modeling_depth_pro.py +21 -23
- transformers/models/detr/configuration_detr.py +1 -2
- transformers/models/detr/image_processing_detr.py +64 -66
- transformers/models/detr/image_processing_detr_fast.py +22 -23
- transformers/models/detr/modeling_detr.py +78 -73
- transformers/models/dia/configuration_dia.py +5 -8
- transformers/models/dia/feature_extraction_dia.py +6 -9
- transformers/models/dia/generation_dia.py +42 -45
- transformers/models/dia/modeling_dia.py +73 -65
- transformers/models/dia/modular_dia.py +63 -54
- transformers/models/dia/processing_dia.py +39 -29
- transformers/models/dia/tokenization_dia.py +3 -6
- transformers/models/diffllama/configuration_diffllama.py +20 -23
- transformers/models/diffllama/modeling_diffllama.py +44 -47
- transformers/models/diffllama/modular_diffllama.py +17 -19
- transformers/models/dinat/configuration_dinat.py +0 -1
- transformers/models/dinat/modeling_dinat.py +40 -42
- transformers/models/dinov2/configuration_dinov2.py +0 -1
- transformers/models/dinov2/modeling_dinov2.py +11 -13
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +1 -1
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +12 -13
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +5 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +4 -7
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +3 -6
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +5 -8
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +5 -7
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +17 -16
- transformers/models/dinov3_vit/modular_dinov3_vit.py +14 -13
- transformers/models/distilbert/configuration_distilbert.py +0 -1
- transformers/models/distilbert/modeling_distilbert.py +55 -55
- transformers/models/distilbert/tokenization_distilbert.py +0 -1
- transformers/models/doge/__init__.py +0 -1
- transformers/models/doge/configuration_doge.py +25 -28
- transformers/models/doge/modeling_doge.py +43 -46
- transformers/models/doge/modular_doge.py +57 -58
- transformers/models/donut/configuration_donut_swin.py +0 -1
- transformers/models/donut/image_processing_donut.py +26 -29
- transformers/models/donut/image_processing_donut_fast.py +5 -11
- transformers/models/donut/modeling_donut_swin.py +60 -58
- transformers/models/donut/processing_donut.py +5 -26
- transformers/models/dots1/configuration_dots1.py +27 -29
- transformers/models/dots1/modeling_dots1.py +45 -39
- transformers/models/dots1/modular_dots1.py +0 -1
- transformers/models/dpr/configuration_dpr.py +0 -1
- transformers/models/dpr/modeling_dpr.py +37 -39
- transformers/models/dpr/tokenization_dpr.py +7 -9
- transformers/models/dpr/tokenization_dpr_fast.py +7 -9
- transformers/models/dpt/configuration_dpt.py +1 -2
- transformers/models/dpt/image_processing_dpt.py +65 -66
- transformers/models/dpt/image_processing_dpt_fast.py +14 -16
- transformers/models/dpt/modeling_dpt.py +19 -21
- transformers/models/dpt/modular_dpt.py +11 -13
- transformers/models/edgetam/configuration_edgetam.py +1 -2
- transformers/models/edgetam/modeling_edgetam.py +44 -43
- transformers/models/edgetam/modular_edgetam.py +17 -20
- transformers/models/edgetam_video/__init__.py +0 -1
- transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +131 -120
- transformers/models/edgetam_video/modular_edgetam_video.py +29 -37
- transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
- transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +5 -6
- transformers/models/efficientloftr/modeling_efficientloftr.py +41 -30
- transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
- transformers/models/efficientnet/configuration_efficientnet.py +0 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +28 -32
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +15 -17
- transformers/models/efficientnet/modeling_efficientnet.py +17 -15
- transformers/models/electra/configuration_electra.py +0 -1
- transformers/models/electra/modeling_electra.py +108 -103
- transformers/models/emu3/configuration_emu3.py +5 -7
- transformers/models/emu3/image_processing_emu3.py +44 -39
- transformers/models/emu3/modeling_emu3.py +67 -64
- transformers/models/emu3/modular_emu3.py +39 -35
- transformers/models/emu3/processing_emu3.py +18 -43
- transformers/models/encodec/configuration_encodec.py +2 -4
- transformers/models/encodec/feature_extraction_encodec.py +10 -13
- transformers/models/encodec/modeling_encodec.py +39 -29
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +0 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +17 -19
- transformers/models/eomt/configuration_eomt.py +0 -1
- transformers/models/eomt/image_processing_eomt.py +53 -55
- transformers/models/eomt/image_processing_eomt_fast.py +59 -28
- transformers/models/eomt/modeling_eomt.py +23 -18
- transformers/models/eomt/modular_eomt.py +18 -13
- transformers/models/ernie/configuration_ernie.py +0 -1
- transformers/models/ernie/modeling_ernie.py +127 -132
- transformers/models/ernie/modular_ernie.py +97 -103
- transformers/models/ernie4_5/configuration_ernie4_5.py +18 -20
- transformers/models/ernie4_5/modeling_ernie4_5.py +32 -34
- transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +27 -29
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +52 -51
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +16 -44
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +329 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +455 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +231 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1895 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1901 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +249 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +593 -0
- transformers/models/esm/configuration_esm.py +2 -4
- transformers/models/esm/modeling_esm.py +38 -34
- transformers/models/esm/modeling_esmfold.py +48 -45
- transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
- transformers/models/esm/openfold_utils/loss.py +1 -2
- transformers/models/esm/openfold_utils/protein.py +13 -13
- transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
- transformers/models/esm/tokenization_esm.py +2 -4
- transformers/models/evolla/configuration_evolla.py +29 -32
- transformers/models/evolla/modeling_evolla.py +67 -62
- transformers/models/evolla/modular_evolla.py +53 -47
- transformers/models/evolla/processing_evolla.py +23 -35
- transformers/models/exaone4/configuration_exaone4.py +19 -22
- transformers/models/exaone4/modeling_exaone4.py +33 -36
- transformers/models/exaone4/modular_exaone4.py +40 -42
- transformers/models/falcon/configuration_falcon.py +22 -25
- transformers/models/falcon/modeling_falcon.py +75 -78
- transformers/models/falcon_h1/configuration_falcon_h1.py +40 -43
- transformers/models/falcon_h1/modeling_falcon_h1.py +80 -78
- transformers/models/falcon_h1/modular_falcon_h1.py +54 -50
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +0 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +50 -47
- transformers/models/falcon_mamba/modular_falcon_mamba.py +16 -14
- transformers/models/fast_vlm/configuration_fast_vlm.py +1 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +43 -39
- transformers/models/fast_vlm/modular_fast_vlm.py +2 -3
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -5
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +68 -57
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +2 -3
- transformers/models/flaubert/configuration_flaubert.py +0 -1
- transformers/models/flaubert/modeling_flaubert.py +138 -143
- transformers/models/flaubert/tokenization_flaubert.py +3 -5
- transformers/models/flava/configuration_flava.py +5 -6
- transformers/models/flava/image_processing_flava.py +66 -67
- transformers/models/flava/image_processing_flava_fast.py +42 -45
- transformers/models/flava/modeling_flava.py +111 -107
- transformers/models/flava/processing_flava.py +2 -12
- transformers/models/flex_olmo/__init__.py +0 -1
- transformers/models/flex_olmo/configuration_flex_olmo.py +23 -25
- transformers/models/flex_olmo/modeling_flex_olmo.py +44 -43
- transformers/models/flex_olmo/modular_flex_olmo.py +35 -37
- transformers/models/florence2/configuration_florence2.py +0 -1
- transformers/models/florence2/modeling_florence2.py +59 -43
- transformers/models/florence2/modular_florence2.py +65 -81
- transformers/models/florence2/processing_florence2.py +18 -47
- transformers/models/fnet/configuration_fnet.py +0 -1
- transformers/models/fnet/modeling_fnet.py +76 -80
- transformers/models/fnet/tokenization_fnet.py +0 -1
- transformers/models/focalnet/configuration_focalnet.py +0 -1
- transformers/models/focalnet/modeling_focalnet.py +39 -41
- transformers/models/fsmt/configuration_fsmt.py +0 -1
- transformers/models/fsmt/modeling_fsmt.py +47 -48
- transformers/models/fsmt/tokenization_fsmt.py +3 -5
- transformers/models/funnel/configuration_funnel.py +0 -1
- transformers/models/funnel/modeling_funnel.py +91 -93
- transformers/models/funnel/tokenization_funnel.py +2 -5
- transformers/models/fuyu/configuration_fuyu.py +23 -26
- transformers/models/fuyu/image_processing_fuyu.py +29 -31
- transformers/models/fuyu/image_processing_fuyu_fast.py +12 -13
- transformers/models/fuyu/modeling_fuyu.py +29 -30
- transformers/models/fuyu/processing_fuyu.py +23 -34
- transformers/models/gemma/configuration_gemma.py +20 -23
- transformers/models/gemma/modeling_gemma.py +42 -46
- transformers/models/gemma/modular_gemma.py +37 -40
- transformers/models/gemma/tokenization_gemma.py +3 -6
- transformers/models/gemma2/configuration_gemma2.py +25 -28
- transformers/models/gemma2/modeling_gemma2.py +35 -38
- transformers/models/gemma2/modular_gemma2.py +56 -58
- transformers/models/gemma3/configuration_gemma3.py +28 -29
- transformers/models/gemma3/image_processing_gemma3.py +29 -31
- transformers/models/gemma3/image_processing_gemma3_fast.py +9 -11
- transformers/models/gemma3/modeling_gemma3.py +112 -94
- transformers/models/gemma3/modular_gemma3.py +110 -91
- transformers/models/gemma3/processing_gemma3.py +5 -5
- transformers/models/gemma3n/configuration_gemma3n.py +12 -10
- transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
- transformers/models/gemma3n/modeling_gemma3n.py +127 -98
- transformers/models/gemma3n/modular_gemma3n.py +117 -84
- transformers/models/gemma3n/processing_gemma3n.py +12 -26
- transformers/models/git/configuration_git.py +0 -1
- transformers/models/git/modeling_git.py +250 -197
- transformers/models/git/processing_git.py +2 -14
- transformers/models/glm/configuration_glm.py +19 -21
- transformers/models/glm/modeling_glm.py +33 -36
- transformers/models/glm/modular_glm.py +4 -7
- transformers/models/glm4/configuration_glm4.py +19 -21
- transformers/models/glm4/modeling_glm4.py +36 -38
- transformers/models/glm4/modular_glm4.py +8 -10
- transformers/models/glm46v/configuration_glm46v.py +0 -1
- transformers/models/glm46v/image_processing_glm46v.py +35 -40
- transformers/models/glm46v/image_processing_glm46v_fast.py +7 -7
- transformers/models/glm46v/modeling_glm46v.py +54 -52
- transformers/models/glm46v/modular_glm46v.py +4 -3
- transformers/models/glm46v/processing_glm46v.py +7 -41
- transformers/models/glm46v/video_processing_glm46v.py +9 -11
- transformers/models/glm4_moe/configuration_glm4_moe.py +25 -28
- transformers/models/glm4_moe/modeling_glm4_moe.py +41 -40
- transformers/models/glm4_moe/modular_glm4_moe.py +27 -30
- transformers/models/glm4_moe_lite/__init__.py +28 -0
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +235 -0
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +304 -0
- transformers/models/glm4v/configuration_glm4v.py +14 -17
- transformers/models/glm4v/image_processing_glm4v.py +34 -40
- transformers/models/glm4v/image_processing_glm4v_fast.py +6 -7
- transformers/models/glm4v/modeling_glm4v.py +148 -156
- transformers/models/glm4v/modular_glm4v.py +142 -185
- transformers/models/glm4v/processing_glm4v.py +7 -41
- transformers/models/glm4v/video_processing_glm4v.py +9 -11
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +119 -122
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +275 -319
- transformers/models/glm4v_moe/modular_glm4v_moe.py +66 -163
- transformers/models/glm_image/__init__.py +31 -0
- transformers/models/glm_image/configuration_glm_image.py +352 -0
- transformers/models/glm_image/image_processing_glm_image.py +503 -0
- transformers/models/glm_image/image_processing_glm_image_fast.py +296 -0
- transformers/models/glm_image/modeling_glm_image.py +1590 -0
- transformers/models/glm_image/modular_glm_image.py +1480 -0
- transformers/models/glm_image/processing_glm_image.py +217 -0
- transformers/models/glmasr/__init__.py +29 -0
- transformers/models/glmasr/configuration_glmasr.py +196 -0
- transformers/models/glmasr/modeling_glmasr.py +511 -0
- transformers/models/glmasr/modular_glmasr.py +431 -0
- transformers/models/glmasr/processing_glmasr.py +331 -0
- transformers/models/glpn/configuration_glpn.py +0 -1
- transformers/models/glpn/image_processing_glpn.py +11 -12
- transformers/models/glpn/image_processing_glpn_fast.py +8 -10
- transformers/models/glpn/modeling_glpn.py +10 -12
- transformers/models/got_ocr2/configuration_got_ocr2.py +5 -8
- transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +6 -8
- transformers/models/got_ocr2/modeling_got_ocr2.py +48 -45
- transformers/models/got_ocr2/modular_got_ocr2.py +31 -34
- transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
- transformers/models/gpt2/configuration_gpt2.py +0 -1
- transformers/models/gpt2/modeling_gpt2.py +114 -113
- transformers/models/gpt2/tokenization_gpt2.py +6 -9
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +0 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +76 -88
- transformers/models/gpt_neo/configuration_gpt_neo.py +0 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +77 -66
- transformers/models/gpt_neox/configuration_gpt_neox.py +19 -22
- transformers/models/gpt_neox/modeling_gpt_neox.py +71 -73
- transformers/models/gpt_neox/modular_gpt_neox.py +64 -66
- transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +15 -18
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +42 -45
- transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
- transformers/models/gpt_oss/configuration_gpt_oss.py +38 -24
- transformers/models/gpt_oss/modeling_gpt_oss.py +40 -44
- transformers/models/gpt_oss/modular_gpt_oss.py +22 -26
- transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
- transformers/models/gptj/configuration_gptj.py +0 -1
- transformers/models/gptj/modeling_gptj.py +96 -86
- transformers/models/granite/configuration_granite.py +23 -26
- transformers/models/granite/modeling_granite.py +40 -42
- transformers/models/granite/modular_granite.py +29 -31
- transformers/models/granite_speech/configuration_granite_speech.py +0 -1
- transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
- transformers/models/granite_speech/modeling_granite_speech.py +36 -24
- transformers/models/granite_speech/processing_granite_speech.py +11 -4
- transformers/models/granitemoe/configuration_granitemoe.py +26 -29
- transformers/models/granitemoe/modeling_granitemoe.py +37 -40
- transformers/models/granitemoe/modular_granitemoe.py +22 -25
- transformers/models/granitemoehybrid/__init__.py +0 -1
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +41 -40
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +92 -86
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +29 -21
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +27 -30
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +50 -55
- transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -4
- transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +17 -18
- transformers/models/grounding_dino/modeling_grounding_dino.py +95 -97
- transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
- transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
- transformers/models/groupvit/configuration_groupvit.py +0 -1
- transformers/models/groupvit/modeling_groupvit.py +75 -71
- transformers/models/helium/configuration_helium.py +20 -22
- transformers/models/helium/modeling_helium.py +34 -37
- transformers/models/helium/modular_helium.py +3 -7
- transformers/models/herbert/tokenization_herbert.py +4 -6
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +0 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -9
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -9
- transformers/models/hiera/configuration_hiera.py +0 -1
- transformers/models/hiera/modeling_hiera.py +60 -62
- transformers/models/hubert/configuration_hubert.py +0 -1
- transformers/models/hubert/modeling_hubert.py +39 -37
- transformers/models/hubert/modular_hubert.py +12 -11
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +21 -24
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +31 -34
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +4 -6
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +25 -28
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +44 -39
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +9 -9
- transformers/models/ibert/configuration_ibert.py +0 -1
- transformers/models/ibert/modeling_ibert.py +76 -62
- transformers/models/ibert/quant_modules.py +0 -1
- transformers/models/idefics/configuration_idefics.py +0 -1
- transformers/models/idefics/image_processing_idefics.py +13 -15
- transformers/models/idefics/modeling_idefics.py +70 -61
- transformers/models/idefics/perceiver.py +1 -3
- transformers/models/idefics/processing_idefics.py +32 -48
- transformers/models/idefics/vision.py +22 -24
- transformers/models/idefics2/configuration_idefics2.py +0 -1
- transformers/models/idefics2/image_processing_idefics2.py +31 -32
- transformers/models/idefics2/image_processing_idefics2_fast.py +7 -8
- transformers/models/idefics2/modeling_idefics2.py +63 -59
- transformers/models/idefics2/processing_idefics2.py +10 -68
- transformers/models/idefics3/configuration_idefics3.py +0 -1
- transformers/models/idefics3/image_processing_idefics3.py +42 -43
- transformers/models/idefics3/image_processing_idefics3_fast.py +11 -12
- transformers/models/idefics3/modeling_idefics3.py +57 -55
- transformers/models/idefics3/processing_idefics3.py +15 -69
- transformers/models/ijepa/configuration_ijepa.py +0 -1
- transformers/models/ijepa/modeling_ijepa.py +10 -11
- transformers/models/ijepa/modular_ijepa.py +5 -7
- transformers/models/imagegpt/configuration_imagegpt.py +0 -1
- transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +9 -14
- transformers/models/imagegpt/modeling_imagegpt.py +66 -60
- transformers/models/informer/configuration_informer.py +6 -9
- transformers/models/informer/modeling_informer.py +84 -86
- transformers/models/informer/modular_informer.py +13 -16
- transformers/models/instructblip/configuration_instructblip.py +0 -1
- transformers/models/instructblip/modeling_instructblip.py +45 -44
- transformers/models/instructblip/processing_instructblip.py +10 -36
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +0 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +107 -105
- transformers/models/instructblipvideo/modular_instructblipvideo.py +34 -36
- transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +4 -6
- transformers/models/internvl/configuration_internvl.py +0 -1
- transformers/models/internvl/modeling_internvl.py +52 -51
- transformers/models/internvl/modular_internvl.py +24 -30
- transformers/models/internvl/processing_internvl.py +12 -45
- transformers/models/internvl/video_processing_internvl.py +8 -10
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +150 -0
- transformers/models/jais2/modeling_jais2.py +484 -0
- transformers/models/jais2/modular_jais2.py +194 -0
- transformers/models/jamba/configuration_jamba.py +0 -1
- transformers/models/jamba/modeling_jamba.py +67 -65
- transformers/models/jamba/modular_jamba.py +54 -55
- transformers/models/janus/configuration_janus.py +0 -1
- transformers/models/janus/image_processing_janus.py +35 -37
- transformers/models/janus/image_processing_janus_fast.py +12 -14
- transformers/models/janus/modeling_janus.py +56 -50
- transformers/models/janus/modular_janus.py +76 -70
- transformers/models/janus/processing_janus.py +17 -43
- transformers/models/jetmoe/configuration_jetmoe.py +20 -23
- transformers/models/jetmoe/modeling_jetmoe.py +41 -44
- transformers/models/jetmoe/modular_jetmoe.py +31 -33
- transformers/models/kosmos2/configuration_kosmos2.py +0 -1
- transformers/models/kosmos2/modeling_kosmos2.py +159 -148
- transformers/models/kosmos2/processing_kosmos2.py +40 -55
- transformers/models/kosmos2_5/__init__.py +0 -1
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +0 -1
- transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +4 -13
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +118 -110
- transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +23 -25
- transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +67 -68
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +28 -22
- transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
- transformers/models/lasr/configuration_lasr.py +5 -3
- transformers/models/lasr/feature_extraction_lasr.py +10 -12
- transformers/models/lasr/modeling_lasr.py +21 -23
- transformers/models/lasr/modular_lasr.py +16 -11
- transformers/models/lasr/processing_lasr.py +12 -8
- transformers/models/lasr/tokenization_lasr.py +2 -4
- transformers/models/layoutlm/configuration_layoutlm.py +0 -1
- transformers/models/layoutlm/modeling_layoutlm.py +72 -72
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +0 -1
- transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +5 -7
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +60 -50
- transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +64 -74
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +0 -1
- transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +7 -9
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +78 -56
- transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
- transformers/models/layoutxlm/configuration_layoutxlm.py +0 -1
- transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
- transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
- transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
- transformers/models/led/configuration_led.py +1 -4
- transformers/models/led/modeling_led.py +119 -267
- transformers/models/levit/configuration_levit.py +0 -1
- transformers/models/levit/image_processing_levit.py +19 -21
- transformers/models/levit/image_processing_levit_fast.py +0 -1
- transformers/models/levit/modeling_levit.py +35 -19
- transformers/models/lfm2/configuration_lfm2.py +22 -23
- transformers/models/lfm2/modeling_lfm2.py +43 -45
- transformers/models/lfm2/modular_lfm2.py +29 -29
- transformers/models/lfm2_moe/__init__.py +0 -1
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +1 -2
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +58 -49
- transformers/models/lfm2_moe/modular_lfm2_moe.py +13 -37
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +34 -5
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -38
- transformers/models/lfm2_vl/modular_lfm2_vl.py +28 -29
- transformers/models/lfm2_vl/processing_lfm2_vl.py +96 -76
- transformers/models/lightglue/image_processing_lightglue.py +16 -15
- transformers/models/lightglue/image_processing_lightglue_fast.py +5 -6
- transformers/models/lightglue/modeling_lightglue.py +28 -30
- transformers/models/lightglue/modular_lightglue.py +28 -28
- transformers/models/lighton_ocr/__init__.py +28 -0
- transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +460 -0
- transformers/models/lighton_ocr/modular_lighton_ocr.py +403 -0
- transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
- transformers/models/lilt/configuration_lilt.py +0 -1
- transformers/models/lilt/modeling_lilt.py +72 -70
- transformers/models/llama/configuration_llama.py +21 -24
- transformers/models/llama/modeling_llama.py +32 -35
- transformers/models/llama/tokenization_llama.py +2 -4
- transformers/models/llama4/configuration_llama4.py +20 -22
- transformers/models/llama4/image_processing_llama4_fast.py +9 -11
- transformers/models/llama4/modeling_llama4.py +78 -75
- transformers/models/llama4/processing_llama4.py +33 -57
- transformers/models/llava/configuration_llava.py +0 -1
- transformers/models/llava/image_processing_llava.py +25 -28
- transformers/models/llava/image_processing_llava_fast.py +6 -8
- transformers/models/llava/modeling_llava.py +47 -44
- transformers/models/llava/processing_llava.py +18 -51
- transformers/models/llava_next/configuration_llava_next.py +0 -1
- transformers/models/llava_next/image_processing_llava_next.py +43 -45
- transformers/models/llava_next/image_processing_llava_next_fast.py +5 -7
- transformers/models/llava_next/modeling_llava_next.py +49 -47
- transformers/models/llava_next/processing_llava_next.py +18 -47
- transformers/models/llava_next_video/configuration_llava_next_video.py +0 -1
- transformers/models/llava_next_video/modeling_llava_next_video.py +60 -58
- transformers/models/llava_next_video/modular_llava_next_video.py +51 -49
- transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
- transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
- transformers/models/llava_onevision/configuration_llava_onevision.py +0 -1
- transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +6 -8
- transformers/models/llava_onevision/modeling_llava_onevision.py +67 -65
- transformers/models/llava_onevision/modular_llava_onevision.py +58 -56
- transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
- transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
- transformers/models/longcat_flash/__init__.py +0 -1
- transformers/models/longcat_flash/configuration_longcat_flash.py +32 -35
- transformers/models/longcat_flash/modeling_longcat_flash.py +32 -32
- transformers/models/longcat_flash/modular_longcat_flash.py +18 -19
- transformers/models/longformer/configuration_longformer.py +1 -4
- transformers/models/longformer/modeling_longformer.py +99 -101
- transformers/models/longt5/configuration_longt5.py +0 -1
- transformers/models/longt5/modeling_longt5.py +43 -48
- transformers/models/luke/configuration_luke.py +0 -1
- transformers/models/luke/modeling_luke.py +179 -181
- transformers/models/luke/tokenization_luke.py +99 -105
- transformers/models/lw_detr/__init__.py +27 -0
- transformers/models/lw_detr/configuration_lw_detr.py +374 -0
- transformers/models/lw_detr/modeling_lw_detr.py +1698 -0
- transformers/models/lw_detr/modular_lw_detr.py +1611 -0
- transformers/models/lxmert/configuration_lxmert.py +0 -1
- transformers/models/lxmert/modeling_lxmert.py +63 -74
- transformers/models/m2m_100/configuration_m2m_100.py +0 -1
- transformers/models/m2m_100/modeling_m2m_100.py +79 -71
- transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
- transformers/models/mamba/configuration_mamba.py +0 -1
- transformers/models/mamba/modeling_mamba.py +44 -44
- transformers/models/mamba2/configuration_mamba2.py +0 -1
- transformers/models/mamba2/modeling_mamba2.py +67 -68
- transformers/models/marian/configuration_marian.py +1 -2
- transformers/models/marian/modeling_marian.py +87 -86
- transformers/models/marian/tokenization_marian.py +6 -6
- transformers/models/markuplm/configuration_markuplm.py +0 -1
- transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
- transformers/models/markuplm/modeling_markuplm.py +65 -70
- transformers/models/markuplm/processing_markuplm.py +31 -38
- transformers/models/markuplm/tokenization_markuplm.py +67 -77
- transformers/models/mask2former/configuration_mask2former.py +5 -8
- transformers/models/mask2former/image_processing_mask2former.py +84 -85
- transformers/models/mask2former/image_processing_mask2former_fast.py +30 -33
- transformers/models/mask2former/modeling_mask2former.py +99 -92
- transformers/models/mask2former/modular_mask2former.py +6 -8
- transformers/models/maskformer/configuration_maskformer.py +6 -9
- transformers/models/maskformer/configuration_maskformer_swin.py +0 -1
- transformers/models/maskformer/image_processing_maskformer.py +84 -85
- transformers/models/maskformer/image_processing_maskformer_fast.py +29 -33
- transformers/models/maskformer/modeling_maskformer.py +65 -59
- transformers/models/maskformer/modeling_maskformer_swin.py +34 -32
- transformers/models/mbart/configuration_mbart.py +1 -1
- transformers/models/mbart/modeling_mbart.py +118 -113
- transformers/models/mbart/tokenization_mbart.py +2 -4
- transformers/models/mbart50/tokenization_mbart50.py +3 -5
- transformers/models/megatron_bert/configuration_megatron_bert.py +0 -1
- transformers/models/megatron_bert/modeling_megatron_bert.py +141 -150
- transformers/models/metaclip_2/modeling_metaclip_2.py +48 -46
- transformers/models/metaclip_2/modular_metaclip_2.py +21 -21
- transformers/models/mgp_str/configuration_mgp_str.py +0 -1
- transformers/models/mgp_str/modeling_mgp_str.py +14 -16
- transformers/models/mgp_str/processing_mgp_str.py +3 -20
- transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
- transformers/models/mimi/configuration_mimi.py +38 -40
- transformers/models/mimi/modeling_mimi.py +100 -82
- transformers/models/minimax/__init__.py +0 -1
- transformers/models/minimax/configuration_minimax.py +32 -36
- transformers/models/minimax/modeling_minimax.py +57 -47
- transformers/models/minimax/modular_minimax.py +62 -54
- transformers/models/minimax_m2/__init__.py +28 -0
- transformers/models/minimax_m2/configuration_minimax_m2.py +211 -0
- transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
- transformers/models/minimax_m2/modular_minimax_m2.py +369 -0
- transformers/models/ministral/configuration_ministral.py +20 -22
- transformers/models/ministral/modeling_ministral.py +32 -34
- transformers/models/ministral/modular_ministral.py +27 -29
- transformers/models/ministral3/configuration_ministral3.py +19 -22
- transformers/models/ministral3/modeling_ministral3.py +32 -34
- transformers/models/ministral3/modular_ministral3.py +4 -5
- transformers/models/mistral/configuration_mistral.py +19 -22
- transformers/models/mistral/modeling_mistral.py +32 -34
- transformers/models/mistral/modular_mistral.py +11 -12
- transformers/models/mistral3/configuration_mistral3.py +0 -1
- transformers/models/mistral3/modeling_mistral3.py +53 -46
- transformers/models/mistral3/modular_mistral3.py +38 -36
- transformers/models/mixtral/configuration_mixtral.py +24 -27
- transformers/models/mixtral/modeling_mixtral.py +47 -42
- transformers/models/mixtral/modular_mixtral.py +32 -31
- transformers/models/mlcd/configuration_mlcd.py +0 -1
- transformers/models/mlcd/modeling_mlcd.py +16 -12
- transformers/models/mlcd/modular_mlcd.py +13 -11
- transformers/models/mllama/configuration_mllama.py +5 -8
- transformers/models/mllama/image_processing_mllama.py +23 -25
- transformers/models/mllama/image_processing_mllama_fast.py +5 -6
- transformers/models/mllama/modeling_mllama.py +94 -86
- transformers/models/mllama/processing_mllama.py +6 -55
- transformers/models/mluke/tokenization_mluke.py +97 -103
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -3
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +95 -97
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -3
- transformers/models/mobilebert/configuration_mobilebert.py +0 -1
- transformers/models/mobilebert/modeling_mobilebert.py +77 -85
- transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
- transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
- transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +10 -12
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +17 -20
- transformers/models/mobilevit/configuration_mobilevit.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +46 -49
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +9 -11
- transformers/models/mobilevit/modeling_mobilevit.py +21 -19
- transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -20
- transformers/models/modernbert/configuration_modernbert.py +34 -34
- transformers/models/modernbert/modeling_modernbert.py +135 -126
- transformers/models/modernbert/modular_modernbert.py +167 -156
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +30 -32
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -48
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +78 -71
- transformers/models/moonshine/configuration_moonshine.py +22 -24
- transformers/models/moonshine/modeling_moonshine.py +64 -66
- transformers/models/moonshine/modular_moonshine.py +72 -73
- transformers/models/moshi/configuration_moshi.py +18 -21
- transformers/models/moshi/modeling_moshi.py +150 -183
- transformers/models/mpnet/configuration_mpnet.py +0 -1
- transformers/models/mpnet/modeling_mpnet.py +57 -57
- transformers/models/mpnet/tokenization_mpnet.py +1 -4
- transformers/models/mpt/configuration_mpt.py +1 -9
- transformers/models/mpt/modeling_mpt.py +58 -60
- transformers/models/mra/configuration_mra.py +0 -1
- transformers/models/mra/modeling_mra.py +58 -57
- transformers/models/mt5/configuration_mt5.py +2 -4
- transformers/models/mt5/modeling_mt5.py +75 -87
- transformers/models/musicgen/configuration_musicgen.py +0 -1
- transformers/models/musicgen/modeling_musicgen.py +113 -120
- transformers/models/musicgen/processing_musicgen.py +3 -21
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +0 -1
- transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +110 -109
- transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
- transformers/models/mvp/configuration_mvp.py +0 -1
- transformers/models/mvp/modeling_mvp.py +122 -119
- transformers/models/myt5/tokenization_myt5.py +8 -10
- transformers/models/nanochat/configuration_nanochat.py +0 -1
- transformers/models/nanochat/modeling_nanochat.py +33 -36
- transformers/models/nanochat/modular_nanochat.py +12 -14
- transformers/models/nemotron/configuration_nemotron.py +20 -23
- transformers/models/nemotron/modeling_nemotron.py +51 -54
- transformers/models/nllb/tokenization_nllb.py +7 -9
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -1
- transformers/models/nllb_moe/modeling_nllb_moe.py +77 -69
- transformers/models/nougat/image_processing_nougat.py +29 -32
- transformers/models/nougat/image_processing_nougat_fast.py +4 -6
- transformers/models/nougat/processing_nougat.py +37 -39
- transformers/models/nougat/tokenization_nougat.py +16 -23
- transformers/models/nystromformer/configuration_nystromformer.py +0 -1
- transformers/models/nystromformer/modeling_nystromformer.py +68 -63
- transformers/models/olmo/configuration_olmo.py +18 -21
- transformers/models/olmo/modeling_olmo.py +32 -35
- transformers/models/olmo/modular_olmo.py +5 -9
- transformers/models/olmo2/configuration_olmo2.py +18 -21
- transformers/models/olmo2/modeling_olmo2.py +33 -36
- transformers/models/olmo2/modular_olmo2.py +29 -31
- transformers/models/olmo3/__init__.py +0 -1
- transformers/models/olmo3/configuration_olmo3.py +20 -23
- transformers/models/olmo3/modeling_olmo3.py +32 -35
- transformers/models/olmo3/modular_olmo3.py +31 -33
- transformers/models/olmoe/configuration_olmoe.py +24 -26
- transformers/models/olmoe/modeling_olmoe.py +49 -43
- transformers/models/olmoe/modular_olmoe.py +16 -15
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -3
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +42 -40
- transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
- transformers/models/oneformer/configuration_oneformer.py +5 -8
- transformers/models/oneformer/image_processing_oneformer.py +83 -84
- transformers/models/oneformer/image_processing_oneformer_fast.py +33 -34
- transformers/models/oneformer/modeling_oneformer.py +130 -162
- transformers/models/oneformer/processing_oneformer.py +28 -43
- transformers/models/openai/configuration_openai.py +0 -1
- transformers/models/openai/modeling_openai.py +62 -51
- transformers/models/openai/tokenization_openai.py +2 -5
- transformers/models/opt/configuration_opt.py +0 -1
- transformers/models/opt/modeling_opt.py +74 -75
- transformers/models/ovis2/__init__.py +0 -1
- transformers/models/ovis2/configuration_ovis2.py +0 -1
- transformers/models/ovis2/image_processing_ovis2.py +22 -24
- transformers/models/ovis2/image_processing_ovis2_fast.py +6 -8
- transformers/models/ovis2/modeling_ovis2.py +58 -48
- transformers/models/ovis2/modular_ovis2.py +38 -32
- transformers/models/ovis2/processing_ovis2.py +12 -40
- transformers/models/owlv2/configuration_owlv2.py +0 -1
- transformers/models/owlv2/image_processing_owlv2.py +20 -21
- transformers/models/owlv2/image_processing_owlv2_fast.py +7 -10
- transformers/models/owlv2/modeling_owlv2.py +89 -90
- transformers/models/owlv2/modular_owlv2.py +6 -9
- transformers/models/owlv2/processing_owlv2.py +20 -49
- transformers/models/owlvit/configuration_owlvit.py +0 -1
- transformers/models/owlvit/image_processing_owlvit.py +21 -22
- transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
- transformers/models/owlvit/modeling_owlvit.py +88 -89
- transformers/models/owlvit/processing_owlvit.py +20 -48
- transformers/models/paddleocr_vl/__init__.py +0 -1
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +19 -19
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +37 -37
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +104 -90
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +90 -80
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
- transformers/models/paligemma/configuration_paligemma.py +0 -1
- transformers/models/paligemma/modeling_paligemma.py +73 -67
- transformers/models/paligemma/processing_paligemma.py +13 -66
- transformers/models/parakeet/configuration_parakeet.py +1 -4
- transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
- transformers/models/parakeet/modeling_parakeet.py +23 -22
- transformers/models/parakeet/modular_parakeet.py +21 -18
- transformers/models/parakeet/processing_parakeet.py +12 -5
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +5 -7
- transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +64 -62
- transformers/models/patchtst/configuration_patchtst.py +6 -9
- transformers/models/patchtst/modeling_patchtst.py +77 -78
- transformers/models/pe_audio/__init__.py +29 -0
- transformers/models/pe_audio/configuration_pe_audio.py +204 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +160 -0
- transformers/models/pe_audio/modeling_pe_audio.py +819 -0
- transformers/models/pe_audio/modular_pe_audio.py +298 -0
- transformers/models/pe_audio/processing_pe_audio.py +23 -0
- transformers/models/pe_audio_video/__init__.py +28 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +223 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +971 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +763 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +24 -0
- transformers/models/pe_video/__init__.py +29 -0
- transformers/models/pe_video/configuration_pe_video.py +209 -0
- transformers/models/pe_video/modeling_pe_video.py +635 -0
- transformers/models/pe_video/modular_pe_video.py +218 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +64 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -1
- transformers/models/pegasus/modeling_pegasus.py +66 -65
- transformers/models/pegasus/tokenization_pegasus.py +1 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +0 -1
- transformers/models/pegasus_x/modeling_pegasus_x.py +51 -52
- transformers/models/perceiver/configuration_perceiver.py +0 -1
- transformers/models/perceiver/image_processing_perceiver.py +22 -25
- transformers/models/perceiver/image_processing_perceiver_fast.py +5 -7
- transformers/models/perceiver/modeling_perceiver.py +140 -137
- transformers/models/perceiver/tokenization_perceiver.py +3 -6
- transformers/models/perception_lm/configuration_perception_lm.py +0 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -10
- transformers/models/perception_lm/modeling_perception_lm.py +45 -43
- transformers/models/perception_lm/modular_perception_lm.py +38 -36
- transformers/models/perception_lm/processing_perception_lm.py +13 -47
- transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
- transformers/models/persimmon/configuration_persimmon.py +18 -21
- transformers/models/persimmon/modeling_persimmon.py +40 -43
- transformers/models/phi/configuration_phi.py +19 -22
- transformers/models/phi/modeling_phi.py +36 -38
- transformers/models/phi/modular_phi.py +23 -23
- transformers/models/phi3/configuration_phi3.py +23 -26
- transformers/models/phi3/modeling_phi3.py +34 -37
- transformers/models/phi3/modular_phi3.py +13 -17
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +25 -26
- transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +7 -7
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +58 -57
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +62 -60
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -44
- transformers/models/phimoe/configuration_phimoe.py +26 -29
- transformers/models/phimoe/modeling_phimoe.py +47 -42
- transformers/models/phimoe/modular_phimoe.py +1 -2
- transformers/models/phobert/tokenization_phobert.py +4 -6
- transformers/models/pix2struct/configuration_pix2struct.py +0 -1
- transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +7 -10
- transformers/models/pix2struct/modeling_pix2struct.py +42 -45
- transformers/models/pix2struct/processing_pix2struct.py +5 -30
- transformers/models/pixio/__init__.py +29 -0
- transformers/models/pixio/configuration_pixio.py +150 -0
- transformers/models/pixio/modeling_pixio.py +505 -0
- transformers/models/pixio/modular_pixio.py +401 -0
- transformers/models/pixtral/configuration_pixtral.py +11 -14
- transformers/models/pixtral/image_processing_pixtral.py +26 -28
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -6
- transformers/models/pixtral/modeling_pixtral.py +23 -26
- transformers/models/pixtral/processing_pixtral.py +21 -53
- transformers/models/plbart/configuration_plbart.py +1 -1
- transformers/models/plbart/modeling_plbart.py +107 -102
- transformers/models/plbart/modular_plbart.py +36 -32
- transformers/models/plbart/tokenization_plbart.py +4 -5
- transformers/models/poolformer/configuration_poolformer.py +0 -1
- transformers/models/poolformer/image_processing_poolformer.py +21 -24
- transformers/models/poolformer/image_processing_poolformer_fast.py +6 -8
- transformers/models/poolformer/modeling_poolformer.py +21 -13
- transformers/models/pop2piano/configuration_pop2piano.py +0 -2
- transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
- transformers/models/pop2piano/modeling_pop2piano.py +22 -23
- transformers/models/pop2piano/processing_pop2piano.py +25 -33
- transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +3 -3
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +14 -15
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +9 -10
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +9 -10
- transformers/models/prophetnet/configuration_prophetnet.py +26 -28
- transformers/models/prophetnet/modeling_prophetnet.py +111 -131
- transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
- transformers/models/pvt/configuration_pvt.py +0 -1
- transformers/models/pvt/image_processing_pvt.py +17 -20
- transformers/models/pvt/image_processing_pvt_fast.py +0 -1
- transformers/models/pvt/modeling_pvt.py +19 -21
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +21 -23
- transformers/models/qwen2/configuration_qwen2.py +18 -21
- transformers/models/qwen2/modeling_qwen2.py +32 -34
- transformers/models/qwen2/modular_qwen2.py +11 -12
- transformers/models/qwen2/tokenization_qwen2.py +2 -5
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +20 -23
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +239 -192
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +174 -127
- transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +22 -25
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +112 -101
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +72 -107
- transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
- transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +29 -31
- transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +28 -31
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +48 -43
- transformers/models/qwen2_moe/modular_qwen2_moe.py +7 -10
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +22 -24
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +41 -42
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +8 -9
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +108 -96
- transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +35 -13
- transformers/models/qwen3/configuration_qwen3.py +20 -23
- transformers/models/qwen3/modeling_qwen3.py +32 -35
- transformers/models/qwen3/modular_qwen3.py +4 -6
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +25 -28
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +48 -43
- transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
- transformers/models/qwen3_next/configuration_qwen3_next.py +31 -34
- transformers/models/qwen3_next/modeling_qwen3_next.py +43 -48
- transformers/models/qwen3_next/modular_qwen3_next.py +33 -34
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +89 -88
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +199 -156
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +170 -152
- transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +21 -24
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +91 -81
- transformers/models/qwen3_vl/modular_qwen3_vl.py +86 -112
- transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
- transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +21 -25
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +174 -195
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +65 -117
- transformers/models/rag/configuration_rag.py +0 -9
- transformers/models/rag/modeling_rag.py +123 -127
- transformers/models/rag/retrieval_rag.py +2 -4
- transformers/models/rag/tokenization_rag.py +0 -50
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +21 -24
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +34 -36
- transformers/models/reformer/configuration_reformer.py +0 -1
- transformers/models/reformer/modeling_reformer.py +76 -69
- transformers/models/reformer/tokenization_reformer.py +3 -6
- transformers/models/regnet/configuration_regnet.py +0 -1
- transformers/models/regnet/modeling_regnet.py +11 -9
- transformers/models/rembert/configuration_rembert.py +0 -1
- transformers/models/rembert/modeling_rembert.py +115 -111
- transformers/models/rembert/tokenization_rembert.py +1 -4
- transformers/models/resnet/configuration_resnet.py +0 -1
- transformers/models/resnet/modeling_resnet.py +16 -13
- transformers/models/roberta/configuration_roberta.py +0 -1
- transformers/models/roberta/modeling_roberta.py +94 -93
- transformers/models/roberta/modular_roberta.py +58 -58
- transformers/models/roberta/tokenization_roberta.py +2 -5
- transformers/models/roberta/tokenization_roberta_old.py +2 -4
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +0 -1
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +94 -93
- transformers/models/roc_bert/configuration_roc_bert.py +0 -1
- transformers/models/roc_bert/modeling_roc_bert.py +122 -121
- transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
- transformers/models/roformer/configuration_roformer.py +0 -1
- transformers/models/roformer/modeling_roformer.py +79 -81
- transformers/models/roformer/tokenization_roformer.py +3 -6
- transformers/models/roformer/tokenization_utils.py +0 -1
- transformers/models/rt_detr/configuration_rt_detr.py +1 -2
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +0 -1
- transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +15 -15
- transformers/models/rt_detr/modeling_rt_detr.py +84 -82
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +10 -7
- transformers/models/rt_detr/modular_rt_detr.py +14 -14
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -4
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +86 -81
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +10 -7
- transformers/models/rwkv/configuration_rwkv.py +0 -1
- transformers/models/rwkv/modeling_rwkv.py +30 -32
- transformers/models/sam/configuration_sam.py +1 -1
- transformers/models/sam/image_processing_sam.py +59 -60
- transformers/models/sam/image_processing_sam_fast.py +21 -23
- transformers/models/sam/modeling_sam.py +37 -36
- transformers/models/sam/processing_sam.py +39 -27
- transformers/models/sam2/configuration_sam2.py +1 -2
- transformers/models/sam2/image_processing_sam2_fast.py +14 -15
- transformers/models/sam2/modeling_sam2.py +50 -48
- transformers/models/sam2/modular_sam2.py +48 -45
- transformers/models/sam2/processing_sam2.py +31 -47
- transformers/models/sam2_video/configuration_sam2_video.py +0 -1
- transformers/models/sam2_video/modeling_sam2_video.py +119 -112
- transformers/models/sam2_video/modular_sam2_video.py +91 -97
- transformers/models/sam2_video/processing_sam2_video.py +49 -66
- transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
- transformers/models/sam3/configuration_sam3.py +21 -2
- transformers/models/sam3/image_processing_sam3_fast.py +17 -20
- transformers/models/sam3/modeling_sam3.py +77 -56
- transformers/models/sam3/modular_sam3.py +3 -8
- transformers/models/sam3/processing_sam3.py +29 -48
- transformers/models/sam3_tracker/__init__.py +0 -1
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +36 -36
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -47
- transformers/models/sam3_tracker_video/__init__.py +0 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +96 -85
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +27 -6
- transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
- transformers/models/sam3_video/configuration_sam3_video.py +14 -1
- transformers/models/sam3_video/modeling_sam3_video.py +32 -34
- transformers/models/sam3_video/processing_sam3_video.py +26 -46
- transformers/models/sam_hq/__init__.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -1
- transformers/models/sam_hq/modeling_sam_hq.py +65 -64
- transformers/models/sam_hq/modular_sam_hq.py +17 -19
- transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +39 -28
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +0 -1
- transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +207 -193
- transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +0 -1
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +199 -195
- transformers/models/seed_oss/configuration_seed_oss.py +23 -25
- transformers/models/seed_oss/modeling_seed_oss.py +31 -33
- transformers/models/seed_oss/modular_seed_oss.py +3 -4
- transformers/models/segformer/configuration_segformer.py +0 -10
- transformers/models/segformer/image_processing_segformer.py +39 -42
- transformers/models/segformer/image_processing_segformer_fast.py +7 -9
- transformers/models/segformer/modeling_segformer.py +26 -28
- transformers/models/segformer/modular_segformer.py +5 -7
- transformers/models/seggpt/configuration_seggpt.py +0 -1
- transformers/models/seggpt/image_processing_seggpt.py +38 -41
- transformers/models/seggpt/modeling_seggpt.py +28 -30
- transformers/models/sew/configuration_sew.py +0 -1
- transformers/models/sew/modeling_sew.py +33 -35
- transformers/models/sew/modular_sew.py +10 -12
- transformers/models/sew_d/configuration_sew_d.py +0 -1
- transformers/models/sew_d/modeling_sew_d.py +28 -30
- transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +16 -17
- transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
- transformers/models/siglip/configuration_siglip.py +0 -1
- transformers/models/siglip/image_processing_siglip.py +17 -20
- transformers/models/siglip/image_processing_siglip_fast.py +0 -1
- transformers/models/siglip/modeling_siglip.py +62 -41
- transformers/models/siglip/processing_siglip.py +2 -14
- transformers/models/siglip/tokenization_siglip.py +6 -7
- transformers/models/siglip2/configuration_siglip2.py +1 -1
- transformers/models/siglip2/image_processing_siglip2.py +15 -16
- transformers/models/siglip2/image_processing_siglip2_fast.py +4 -5
- transformers/models/siglip2/modeling_siglip2.py +114 -92
- transformers/models/siglip2/modular_siglip2.py +23 -25
- transformers/models/siglip2/processing_siglip2.py +2 -14
- transformers/models/smollm3/configuration_smollm3.py +23 -26
- transformers/models/smollm3/modeling_smollm3.py +32 -35
- transformers/models/smollm3/modular_smollm3.py +27 -29
- transformers/models/smolvlm/configuration_smolvlm.py +1 -1
- transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +12 -12
- transformers/models/smolvlm/modeling_smolvlm.py +56 -53
- transformers/models/smolvlm/modular_smolvlm.py +15 -17
- transformers/models/smolvlm/processing_smolvlm.py +15 -76
- transformers/models/smolvlm/video_processing_smolvlm.py +7 -9
- transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +20 -23
- transformers/models/speech_to_text/configuration_speech_to_text.py +0 -1
- transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
- transformers/models/speech_to_text/modeling_speech_to_text.py +62 -54
- transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
- transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
- transformers/models/speecht5/configuration_speecht5.py +0 -1
- transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
- transformers/models/speecht5/modeling_speecht5.py +200 -174
- transformers/models/speecht5/number_normalizer.py +0 -1
- transformers/models/speecht5/processing_speecht5.py +3 -37
- transformers/models/speecht5/tokenization_speecht5.py +4 -5
- transformers/models/splinter/configuration_splinter.py +0 -1
- transformers/models/splinter/modeling_splinter.py +63 -59
- transformers/models/splinter/tokenization_splinter.py +2 -4
- transformers/models/squeezebert/configuration_squeezebert.py +0 -1
- transformers/models/squeezebert/modeling_squeezebert.py +62 -62
- transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
- transformers/models/stablelm/configuration_stablelm.py +20 -23
- transformers/models/stablelm/modeling_stablelm.py +40 -43
- transformers/models/starcoder2/configuration_starcoder2.py +19 -22
- transformers/models/starcoder2/modeling_starcoder2.py +34 -37
- transformers/models/starcoder2/modular_starcoder2.py +13 -15
- transformers/models/superglue/configuration_superglue.py +3 -3
- transformers/models/superglue/image_processing_superglue.py +15 -15
- transformers/models/superglue/image_processing_superglue_fast.py +5 -7
- transformers/models/superglue/modeling_superglue.py +32 -33
- transformers/models/superpoint/image_processing_superpoint.py +15 -15
- transformers/models/superpoint/image_processing_superpoint_fast.py +5 -7
- transformers/models/superpoint/modeling_superpoint.py +13 -14
- transformers/models/swiftformer/configuration_swiftformer.py +0 -1
- transformers/models/swiftformer/modeling_swiftformer.py +16 -14
- transformers/models/swin/configuration_swin.py +0 -1
- transformers/models/swin/modeling_swin.py +74 -82
- transformers/models/swin2sr/configuration_swin2sr.py +0 -1
- transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -6
- transformers/models/swin2sr/modeling_swin2sr.py +75 -61
- transformers/models/swinv2/configuration_swinv2.py +0 -1
- transformers/models/swinv2/modeling_swinv2.py +96 -100
- transformers/models/switch_transformers/configuration_switch_transformers.py +0 -1
- transformers/models/switch_transformers/modeling_switch_transformers.py +34 -41
- transformers/models/switch_transformers/modular_switch_transformers.py +31 -38
- transformers/models/t5/configuration_t5.py +7 -2
- transformers/models/t5/modeling_t5.py +76 -84
- transformers/models/t5/tokenization_t5.py +1 -3
- transformers/models/t5gemma/configuration_t5gemma.py +33 -34
- transformers/models/t5gemma/modeling_t5gemma.py +97 -100
- transformers/models/t5gemma/modular_t5gemma.py +117 -118
- transformers/models/t5gemma2/configuration_t5gemma2.py +59 -96
- transformers/models/t5gemma2/modeling_t5gemma2.py +109 -103
- transformers/models/t5gemma2/modular_t5gemma2.py +375 -91
- transformers/models/table_transformer/configuration_table_transformer.py +1 -2
- transformers/models/table_transformer/modeling_table_transformer.py +47 -49
- transformers/models/tapas/configuration_tapas.py +0 -1
- transformers/models/tapas/modeling_tapas.py +64 -66
- transformers/models/tapas/tokenization_tapas.py +115 -153
- transformers/models/textnet/configuration_textnet.py +0 -1
- transformers/models/textnet/image_processing_textnet.py +22 -25
- transformers/models/textnet/image_processing_textnet_fast.py +5 -7
- transformers/models/textnet/modeling_textnet.py +13 -14
- transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +79 -81
- transformers/models/timesfm/configuration_timesfm.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +29 -19
- transformers/models/timesfm/modular_timesfm.py +28 -18
- transformers/models/timesformer/configuration_timesformer.py +0 -1
- transformers/models/timesformer/modeling_timesformer.py +13 -16
- transformers/models/timm_backbone/configuration_timm_backbone.py +0 -1
- transformers/models/timm_backbone/modeling_timm_backbone.py +17 -15
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +5 -3
- transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +32 -28
- transformers/models/trocr/configuration_trocr.py +0 -1
- transformers/models/trocr/modeling_trocr.py +39 -42
- transformers/models/trocr/processing_trocr.py +5 -25
- transformers/models/tvp/configuration_tvp.py +5 -2
- transformers/models/tvp/image_processing_tvp.py +50 -52
- transformers/models/tvp/image_processing_tvp_fast.py +9 -10
- transformers/models/tvp/modeling_tvp.py +25 -27
- transformers/models/tvp/processing_tvp.py +2 -14
- transformers/models/udop/configuration_udop.py +1 -1
- transformers/models/udop/modeling_udop.py +63 -70
- transformers/models/udop/processing_udop.py +7 -26
- transformers/models/udop/tokenization_udop.py +80 -93
- transformers/models/umt5/configuration_umt5.py +2 -3
- transformers/models/umt5/modeling_umt5.py +80 -87
- transformers/models/unispeech/configuration_unispeech.py +0 -1
- transformers/models/unispeech/modeling_unispeech.py +47 -49
- transformers/models/unispeech/modular_unispeech.py +20 -22
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +0 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +63 -65
- transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
- transformers/models/univnet/feature_extraction_univnet.py +14 -14
- transformers/models/univnet/modeling_univnet.py +7 -8
- transformers/models/upernet/configuration_upernet.py +0 -1
- transformers/models/upernet/modeling_upernet.py +10 -13
- transformers/models/vaultgemma/__init__.py +0 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +24 -26
- transformers/models/vaultgemma/modeling_vaultgemma.py +35 -37
- transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
- transformers/models/video_llama_3/image_processing_video_llama_3.py +43 -42
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +8 -8
- transformers/models/video_llama_3/modeling_video_llama_3.py +77 -66
- transformers/models/video_llama_3/modular_video_llama_3.py +110 -112
- transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
- transformers/models/video_llama_3/video_processing_video_llama_3.py +18 -18
- transformers/models/video_llava/configuration_video_llava.py +0 -1
- transformers/models/video_llava/image_processing_video_llava.py +35 -38
- transformers/models/video_llava/modeling_video_llava.py +59 -57
- transformers/models/video_llava/processing_video_llava.py +38 -78
- transformers/models/video_llava/video_processing_video_llava.py +0 -1
- transformers/models/videomae/configuration_videomae.py +0 -1
- transformers/models/videomae/image_processing_videomae.py +31 -34
- transformers/models/videomae/modeling_videomae.py +13 -15
- transformers/models/videomae/video_processing_videomae.py +0 -1
- transformers/models/vilt/configuration_vilt.py +2 -3
- transformers/models/vilt/image_processing_vilt.py +29 -30
- transformers/models/vilt/image_processing_vilt_fast.py +9 -10
- transformers/models/vilt/modeling_vilt.py +83 -78
- transformers/models/vilt/processing_vilt.py +2 -14
- transformers/models/vipllava/configuration_vipllava.py +0 -1
- transformers/models/vipllava/modeling_vipllava.py +45 -42
- transformers/models/vipllava/modular_vipllava.py +30 -32
- transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +18 -21
- transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +18 -21
- transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
- transformers/models/visual_bert/configuration_visual_bert.py +0 -1
- transformers/models/visual_bert/modeling_visual_bert.py +92 -92
- transformers/models/vit/configuration_vit.py +0 -1
- transformers/models/vit/image_processing_vit.py +19 -22
- transformers/models/vit/image_processing_vit_fast.py +0 -1
- transformers/models/vit/modeling_vit.py +13 -15
- transformers/models/vit_mae/configuration_vit_mae.py +0 -1
- transformers/models/vit_mae/modeling_vit_mae.py +21 -23
- transformers/models/vit_msn/configuration_vit_msn.py +0 -1
- transformers/models/vit_msn/modeling_vit_msn.py +10 -12
- transformers/models/vitdet/configuration_vitdet.py +0 -1
- transformers/models/vitdet/modeling_vitdet.py +12 -14
- transformers/models/vitmatte/configuration_vitmatte.py +2 -5
- transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +14 -16
- transformers/models/vitmatte/modeling_vitmatte.py +13 -11
- transformers/models/vitpose/configuration_vitpose.py +4 -7
- transformers/models/vitpose/image_processing_vitpose.py +24 -25
- transformers/models/vitpose/image_processing_vitpose_fast.py +9 -11
- transformers/models/vitpose/modeling_vitpose.py +10 -12
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +0 -1
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +8 -10
- transformers/models/vits/configuration_vits.py +0 -1
- transformers/models/vits/modeling_vits.py +34 -35
- transformers/models/vits/tokenization_vits.py +3 -4
- transformers/models/vivit/configuration_vivit.py +0 -1
- transformers/models/vivit/image_processing_vivit.py +36 -39
- transformers/models/vivit/modeling_vivit.py +5 -7
- transformers/models/vjepa2/__init__.py +0 -1
- transformers/models/vjepa2/configuration_vjepa2.py +0 -1
- transformers/models/vjepa2/modeling_vjepa2.py +30 -32
- transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
- transformers/models/voxtral/__init__.py +0 -1
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +19 -27
- transformers/models/voxtral/modular_voxtral.py +12 -21
- transformers/models/voxtral/processing_voxtral.py +25 -48
- transformers/models/wav2vec2/configuration_wav2vec2.py +0 -1
- transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
- transformers/models/wav2vec2/modeling_wav2vec2.py +67 -122
- transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
- transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +0 -1
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +65 -62
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +52 -48
- transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +0 -1
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +84 -77
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +37 -30
- transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
- transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
- transformers/models/wavlm/configuration_wavlm.py +0 -1
- transformers/models/wavlm/modeling_wavlm.py +45 -48
- transformers/models/wavlm/modular_wavlm.py +4 -5
- transformers/models/whisper/configuration_whisper.py +0 -1
- transformers/models/whisper/english_normalizer.py +3 -4
- transformers/models/whisper/feature_extraction_whisper.py +9 -24
- transformers/models/whisper/generation_whisper.py +27 -48
- transformers/models/whisper/modeling_whisper.py +73 -73
- transformers/models/whisper/processing_whisper.py +3 -20
- transformers/models/whisper/tokenization_whisper.py +9 -30
- transformers/models/x_clip/configuration_x_clip.py +0 -1
- transformers/models/x_clip/modeling_x_clip.py +70 -69
- transformers/models/x_clip/processing_x_clip.py +2 -14
- transformers/models/xcodec/configuration_xcodec.py +4 -6
- transformers/models/xcodec/modeling_xcodec.py +20 -17
- transformers/models/xglm/configuration_xglm.py +0 -1
- transformers/models/xglm/modeling_xglm.py +59 -55
- transformers/models/xglm/tokenization_xglm.py +1 -4
- transformers/models/xlm/configuration_xlm.py +0 -1
- transformers/models/xlm/modeling_xlm.py +139 -144
- transformers/models/xlm/tokenization_xlm.py +3 -5
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +0 -1
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +195 -194
- transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +0 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +94 -93
- transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
- transformers/models/xlnet/configuration_xlnet.py +0 -11
- transformers/models/xlnet/modeling_xlnet.py +152 -163
- transformers/models/xlnet/tokenization_xlnet.py +1 -4
- transformers/models/xlstm/configuration_xlstm.py +3 -5
- transformers/models/xlstm/modeling_xlstm.py +62 -65
- transformers/models/xmod/configuration_xmod.py +0 -1
- transformers/models/xmod/modeling_xmod.py +101 -100
- transformers/models/yolos/configuration_yolos.py +0 -1
- transformers/models/yolos/image_processing_yolos.py +60 -62
- transformers/models/yolos/image_processing_yolos_fast.py +18 -18
- transformers/models/yolos/modeling_yolos.py +12 -14
- transformers/models/yolos/modular_yolos.py +2 -4
- transformers/models/yoso/configuration_yoso.py +0 -1
- transformers/models/yoso/modeling_yoso.py +64 -63
- transformers/models/zamba/configuration_zamba.py +0 -1
- transformers/models/zamba/modeling_zamba.py +70 -70
- transformers/models/zamba2/configuration_zamba2.py +36 -37
- transformers/models/zamba2/modeling_zamba2.py +87 -89
- transformers/models/zamba2/modular_zamba2.py +43 -45
- transformers/models/zoedepth/configuration_zoedepth.py +1 -2
- transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +12 -15
- transformers/models/zoedepth/modeling_zoedepth.py +21 -16
- transformers/pipelines/__init__.py +59 -55
- transformers/pipelines/any_to_any.py +14 -22
- transformers/pipelines/audio_utils.py +1 -2
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +13 -17
- transformers/pipelines/deprecated/__init__.py +0 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/image_text_to_text.py +0 -1
- transformers/pipelines/image_to_text.py +4 -44
- transformers/pipelines/question_answering.py +5 -44
- transformers/pipelines/text_classification.py +1 -14
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/pipelines/token_classification.py +1 -22
- transformers/pipelines/video_classification.py +1 -9
- transformers/pipelines/zero_shot_audio_classification.py +0 -1
- transformers/pipelines/zero_shot_classification.py +0 -6
- transformers/pipelines/zero_shot_image_classification.py +0 -7
- transformers/processing_utils.py +222 -151
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +19 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_quark.py +0 -1
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +31 -127
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +31 -49
- transformers/tokenization_mistral_common.py +554 -902
- transformers/tokenization_utils_base.py +112 -124
- transformers/tokenization_utils_sentencepiece.py +5 -6
- transformers/tokenization_utils_tokenizers.py +30 -7
- transformers/trainer.py +30 -11
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +1 -2
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +11 -13
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +5 -5
- transformers/utils/auto_docstring.py +598 -37
- transformers/utils/doc.py +1 -1
- transformers/utils/dummy_pt_objects.py +0 -42
- transformers/utils/generic.py +21 -1
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/loading_report.py +3 -3
- transformers/utils/quantization_config.py +16 -18
- transformers/video_processing_utils.py +35 -32
- transformers/video_utils.py +18 -22
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc3.dist-info}/METADATA +23 -24
- transformers-5.0.0rc3.dist-info/RECORD +2067 -0
- transformers-5.0.0rc1.dist-info/RECORD +0 -2003
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc3.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc3.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc3.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc3.dist-info}/top_level.txt +0 -0
transformers/modeling_utils.py
CHANGED
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 #
@@ -16,7 +15,6 @@
 import collections
 import copy
 import functools
-import gc
 import importlib.metadata
 import inspect
 import json
@@ -26,13 +24,13 @@ import sys
 import warnings
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Iterator, Sequence
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial, wraps
 from itertools import cycle
 from threading import Thread
-from typing import Optional, TypeVar,
+from typing import Optional, TypeVar, get_type_hints
 from zipfile import is_zipfile

 import torch
@@ -63,7 +61,8 @@ from .integrations.accelerate import (
     accelerate_dispatch,
     check_and_set_device_map,
     expand_device_map,
-
+    get_device,
+    load_offloaded_parameter,
 )
 from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.eager_paged import eager_paged_attention_forward
@@ -86,6 +85,7 @@ from .integrations.tensor_parallel import (
 )
 from .loss.loss_utils import LOSS_MAPPING
 from .modeling_flash_attention_utils import lazy_import_flash_attention, lazy_import_paged_flash_attention
+from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from .pytorch_utils import id_tensor_storage
 from .quantizers import HfQuantizer
 from .quantizers.auto import get_hf_quantizer
@@ -108,6 +108,7 @@ from .utils import (
     is_accelerate_available,
     is_flash_attn_2_available,
     is_flash_attn_3_available,
+    is_grouped_mm_available,
     is_kernels_available,
     is_torch_flex_attn_available,
     is_torch_greater_or_equal,
@@ -130,7 +131,6 @@ from .utils.quantization_config import QuantizationMethod
 if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import extract_model_from_parallel
-    from accelerate.utils.modeling import get_state_dict_from_offload


 _torch_distributed_available = torch.distributed.is_available()
@@ -152,10 +152,15 @@ logger = logging.get_logger(__name__)
 XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
 XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
 SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
-_init_weights = True
 _is_quantized = False
 _is_ds_init_called = False

+# Mapping from flash attention implementations to their kernel fallback repositories
+FLASH_ATTN_KERNEL_FALLBACK = {
+    "flash_attention_2": "kernels-community/flash-attn2",
+    "flash_attention_3": "kernels-community/vllm-flash-attn3",
+}
+

 def is_local_dist_rank_0():
     return (
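Note: the new FLASH_ATTN_KERNEL_FALLBACK table above maps each flash-attention implementation name to a Hub kernel repository. A minimal sketch of how such a fallback table might be consulted when the native flash-attn package is unavailable (the resolver below is illustrative, not a transformers API):

    # Illustrative only: mirrors the mapping added in this diff and shows one way a
    # caller could fall back to a Hub kernel repo when the native package is missing.
    FLASH_ATTN_KERNEL_FALLBACK = {
        "flash_attention_2": "kernels-community/flash-attn2",
        "flash_attention_3": "kernels-community/vllm-flash-attn3",
    }

    def resolve_attention_backend(requested: str, native_available: bool) -> str:
        """Return the requested implementation, or the Hub kernel repo to load instead."""
        if native_available:
            return requested
        repo = FLASH_ATTN_KERNEL_FALLBACK.get(requested)
        if repo is None:
            raise ValueError(f"No kernel fallback registered for {requested!r}")
        return repo  # e.g. "kernels-community/flash-attn2"

    print(resolve_attention_backend("flash_attention_2", native_available=False))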
@@ -165,51 +170,6 @@ def is_local_dist_rank_0():
     )


-TORCH_INIT_FUNCTIONS = {
-    "uniform_": nn.init.uniform_,
-    "normal_": nn.init.normal_,
-    "trunc_normal_": nn.init.trunc_normal_,
-    "constant_": nn.init.constant_,
-    "xavier_uniform_": nn.init.xavier_uniform_,
-    "xavier_normal_": nn.init.xavier_normal_,
-    "kaiming_uniform_": nn.init.kaiming_uniform_,
-    "kaiming_normal_": nn.init.kaiming_normal_,
-    "uniform": nn.init.uniform,
-    "normal": nn.init.normal,
-    "xavier_uniform": nn.init.xavier_uniform,
-    "xavier_normal": nn.init.xavier_normal,
-    "kaiming_uniform": nn.init.kaiming_uniform,
-    "kaiming_normal": nn.init.kaiming_normal,
-    "orthogonal_": nn.init.orthogonal_,
-}
-
-
-@contextmanager
-def no_init_weights():
-    """
-    Context manager to globally disable weight initialization to speed up loading large models.
-    """
-    global _init_weights
-    old_init_weights = _init_weights
-
-    _init_weights = False
-
-    def _skip_init(*args, **kwargs):
-        pass
-
-    # Save the original initialization functions
-    for name, init_func in TORCH_INIT_FUNCTIONS.items():
-        setattr(torch.nn.init, name, _skip_init)
-
-    try:
-        yield
-    finally:
-        _init_weights = old_init_weights
-        # Restore the original initialization functions
-        for name, init_func in TORCH_INIT_FUNCTIONS.items():
-            setattr(torch.nn.init, name, init_func)
-
-
 @contextmanager
 def set_quantized_state():
     global _is_quantized
@@ -233,23 +193,28 @@ def set_zero3_state():
     _is_ds_init_called = False


-
+@contextmanager
+def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
     """
-
-
-    as a backup in case calling the function raises
-    an error after the function has changed the default dtype but before it could restore it.
+    Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context.
+    If `model_class_name` is provided, it's used to provide a more helpful error message if `dtype` is not valid.
     """
+    # Just a more helping error before we set `torch.set_default_dtype` later on which would crash in this case
+    if not dtype.is_floating_point:
+        if model_class_name is not None:
+            error_message = (
+                f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
+            )
+        else:
+            error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
+        raise ValueError(error_message)

-
-
-
-
-
-
-    torch.set_default_dtype(old_dtype)
-
-    return _wrapper
+    original_dtype = torch.get_default_dtype()
+    try:
+        torch.set_default_dtype(dtype)
+        yield
+    finally:
+        torch.set_default_dtype(original_dtype)


 def get_torch_context_manager_or_global_device():
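Note: local_torch_dtype above replaces the wrapper-based default-dtype save/restore visible in the removed lines with a plain context manager that always restores the previous default, even if the body raises. A usage sketch, assuming the helper is imported from transformers.modeling_utils (it is module-level there per this diff, though not necessarily public API):

    import torch
    from transformers.modeling_utils import local_torch_dtype  # internal helper; import path assumed

    print(torch.get_default_dtype())         # torch.float32 by default
    with local_torch_dtype(torch.bfloat16):
        x = torch.empty(2, 2)                # allocated with the temporary default dtype
        print(x.dtype)                       # torch.bfloat16 inside the context
    print(torch.get_default_dtype())         # restored to torch.float32 afterwards

    # Non-floating dtypes are rejected up front with a ValueError:
    # with local_torch_dtype(torch.int8): ...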
@@ -305,7 +270,7 @@ if is_torch_greater_or_equal("2.3.0"):


 def load_state_dict(
-    checkpoint_file:
+    checkpoint_file: str | os.PathLike, map_location: str | torch.device = "cpu", weights_only: bool = True
 ) -> dict[str, torch.Tensor]:
     """
     Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
@@ -405,14 +370,97 @@ def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]):
     return shared_tensors, identical


+def remove_tied_weights_from_state_dict(
+    state_dict: dict[str, torch.Tensor], model: "PreTrainedModel"
+) -> dict[str, torch.Tensor]:
+    """
+    Remove all tied weights from the given `state_dict`, making sure to keep only the main weight that `model`
+    will expect when reloading (even if we know tie weights symmetrically, it's better to keep the intended one).
+    This is because `safetensors` does not allow tensor aliasing - so we're going to remove aliases before saving.
+    """
+    # To avoid any potential mistakes and mismatches between config and actual tied weights, here we check the pointers
+    # of the Tensors themselves -> we are guaranteed to find all the actual tied weights
+    ptrs = collections.defaultdict(list)
+    for name, tensor in state_dict.items():
+        if not isinstance(tensor, torch.Tensor):
+            # Sometimes in the state_dict we have non-tensor objects.
+            # e.g. in bitsandbytes we have some `str` objects in the state_dict
+            # In the non-tensor case, fall back to the pointer of the object itself
+            ptrs[id(tensor)].append(name)
+
+        elif tensor.device.type == "meta":
+            # In offloaded cases, there may be meta tensors in the state_dict.
+            # For these cases, key by the pointer of the original tensor object
+            # (state_dict tensors are detached and therefore no longer shared)
+            tensor = model.get_parameter(name)
+            ptrs[id(tensor)].append(name)
+
+        else:
+            ptrs[id_tensor_storage(tensor)].append(name)
+
+    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
+    # Recursively descend to find tied weight keys
+    all_potential_tied_weights_keys = set(_get_tied_weight_keys(model))
+    error_names = []
+    to_delete_names = set()
+    # Removing the keys which are declared as known duplicates on load. This allows to make sure the name which is
+    # kept is consistent
+    if all_potential_tied_weights_keys is not None:
+        for names in shared_ptrs.values():
+            found = 0
+            for name in sorted(names):
+                matches_pattern = any(re.search(pat, name) for pat in all_potential_tied_weights_keys)
+                if matches_pattern and name in state_dict:
+                    found += 1
+                    if found < len(names):
+                        to_delete_names.add(name)
+    # We are entering a place where the weights and the transformers configuration do NOT match.
+    shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+    # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+    # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+    for name in disjoint_names:
+        state_dict[name] = state_dict[name].clone()
+
+    # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+    # If the link between tensors was done at runtime then `from_pretrained` will not get
+    # the key back leading to random tensor. A proper warning will be shown
+    # during reload (if applicable), but since the file is not necessarily compatible with
+    # the config, better show a proper warning.
+    shared_names, identical_names = _find_identical(shared_names, state_dict)
+    # delete tensors that have identical storage
+    for inames in identical_names:
+        known = inames.intersection(to_delete_names)
+        for name in known:
+            del state_dict[name]
+        unknown = inames.difference(to_delete_names)
+        if len(unknown) > 1:
+            error_names.append(unknown)
+
+    if shared_names:
+        error_names.extend(shared_names)
+
+    if len(error_names) > 0:
+        raise RuntimeError(
+            f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. "
+            f"We found all the potential target tied weights keys to be: {all_potential_tied_weights_keys}.\n"
+            "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
+        )
+
+    return state_dict
+
+
 def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
-    """Cast a single parameter `param_name` into the `model`, with value `tensor`."""
-
-
-
+    """Cast a single parameter or buffer `param_name` into the `model`, with value `tensor`."""
+    parent, param_type = get_module_from_name(model, param_name)
+    if param_type in parent._parameters and not isinstance(tensor, nn.Parameter):
+        tensor = nn.Parameter(tensor, requires_grad=tensor.is_floating_point())
+    # We need to use setattr here, as we set non-persistent buffers as well with this function (`load_state_dict`
+    # does not allow to do it)
+    setattr(parent, param_type, tensor)


-def _add_variant(weights_name: str, variant:
+def _add_variant(weights_name: str, variant: str | None = None) -> str:
     if variant is not None:
         path, name = weights_name.rsplit(".", 1)
         weights_name = f"{path}.{variant}.{name}"
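Note: remove_tied_weights_from_state_dict above exists because, as its docstring says, safetensors does not allow tensor aliasing, so tied weights must be reduced to a single canonical entry before saving. A self-contained sketch of the underlying shared-storage detection (plain PyTorch, not the transformers implementation, which additionally honours the model's declared tied-weight keys):

    import torch
    import torch.nn as nn

    embed = nn.Embedding(10, 4)
    lm_head = nn.Linear(4, 10, bias=False)
    lm_head.weight = embed.weight  # tie the two weights to the same storage

    state_dict = {"embed.weight": embed.weight, "lm_head.weight": lm_head.weight}
    by_storage = {}
    for name, tensor in state_dict.items():
        # Tensors backed by the same storage are aliases of each other
        by_storage.setdefault(tensor.untyped_storage().data_ptr(), []).append(name)

    aliases = [names for names in by_storage.values() if len(names) > 1]
    print(aliases)  # [['embed.weight', 'lm_head.weight']] -> keep one, drop the rest before saving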
@@ -420,15 +468,15 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:


 def _get_resolved_checkpoint_files(
-    pretrained_model_name_or_path:
-    variant:
-    gguf_file:
-    use_safetensors:
+    pretrained_model_name_or_path: str | os.PathLike | None,
+    variant: str | None,
+    gguf_file: str | None,
+    use_safetensors: bool | None,
     download_kwargs: DownloadKwargs,
     user_agent: dict,
     is_remote_code: bool,  # Because we can't determine this inside this function, we need it to be passed in
-    transformers_explicit_filename:
-) -> tuple[
+    transformers_explicit_filename: str | None = None,
+) -> tuple[list[str] | None, dict | None]:
     """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
     checkpoints are sharded.
     This function will download the data if necessary.
@@ -696,22 +744,20 @@ def _get_resolved_checkpoint_files(


 def _get_dtype(
-
-
-    checkpoint_files: Optional[list[str]],
+    dtype: str | torch.dtype | dict | None,
+    checkpoint_files: list[str] | None,
     config: PreTrainedConfig,
-    sharded_metadata:
-    state_dict:
+    sharded_metadata: dict | None,
+    state_dict: dict | None,
     weights_only: bool,
-
+    hf_quantizer: HfQuantizer | None = None,
+) -> tuple[PreTrainedConfig, torch.dtype]:
     """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
     inferred dtype. We do the following:
-    1. If dtype is
-
-
-    we also may have config.dtype available, but we won't rely on it till v5
+    1. If dtype is "auto", we try to read the config, else auto-detect dtype from the loaded state_dict, by checking
+    its first weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
+    2. Else, use the dtype provided as a dict or str
     """
-    dtype_orig = None
     is_sharded = sharded_metadata is not None

     if dtype is not None:
@@ -736,43 +782,46 @@ def _get_dtype(
             )
         elif hasattr(torch, dtype):
             dtype = getattr(torch, dtype)
-
-
-
-
-
-
-
-
-                sub_config.dtype = dtype
-        elif isinstance(dtype, dict):
-            for key, curr_dtype in dtype.items():
-                if hasattr(config, key):
-                    value = getattr(config, key)
-                    curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
-                    value.dtype = curr_dtype
-            # main torch dtype for modules that aren't part of any sub-config
-            dtype = dtype.get("")
-            dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
-            config.dtype = dtype
-            if dtype is None:
-                dtype = torch.float32
-        else:
+        else:
+            raise ValueError(
+                "`dtype` provided as a `str` can only be `'auto'`, or a string representation of a valid `torch.dtype`"
+            )
+
+        # cast it to a proper `torch.dtype` object
+        dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
+    elif not isinstance(dtype, (dict, torch.dtype)):
         raise ValueError(
             f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
             f"for each sub-config in composite configs, but received {dtype}"
         )
+    else:
+        # set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
+        dtype = torch.get_default_dtype()
+
+    if hf_quantizer is not None:
+        hf_quantizer.update_dtype(dtype)
+
+    # Get the main dtype
+    if isinstance(dtype, dict):
+        main_dtype = dtype.get("", torch.get_default_dtype())
+        main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
+
+        logger.warning_once(
+            "Using different dtypes per module is deprecated and will be removed in future versions "
+            "Setting different dtypes per backbone model might cause device errors downstream, therefore "
+            f"setting the dtype={main_dtype} for all modules."
+        )
 
-        dtype_orig = cls._set_default_dtype(dtype)
     else:
-
-
-
-
-
-
-
-
+        main_dtype = dtype
+
+    # Set it on the config and subconfigs
+    config.dtype = main_dtype
+    for sub_config_key in config.sub_configs:
+        if (sub_config := getattr(config, sub_config_key)) is not None:
+            sub_config.dtype = main_dtype
+
+    return config, main_dtype
 
 
 class PipelineParallel(Enum):
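A condensed, standalone sketch of the dtype normalization introduced above (config, quantizer and "auto" handling omitted; the function name is ours): strings become `torch.dtype` objects, dicts collapse to their main `""` entry, and `None` falls back to the torch default.

import torch

def resolve_main_dtype(dtype):
    if dtype is None:
        # mirror the new default: torch.get_default_dtype() (usually float32)
        return torch.get_default_dtype()
    if isinstance(dtype, str):
        dtype = getattr(torch, dtype)
    if isinstance(dtype, dict):
        # per-module dtype dicts are deprecated; only the main "" entry is kept
        main = dtype.get("", torch.get_default_dtype())
        return getattr(torch, main) if isinstance(main, str) else main
    return dtype

print(resolve_main_dtype("bfloat16"))                                   # torch.bfloat16
print(resolve_main_dtype({"": "float16", "vision_config": "float32"}))  # torch.float16
print(resolve_main_dtype(None))                                         # torch.float32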
@@ -798,11 +847,7 @@ class ModuleUtilsMixin:
         """
         `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
         """
-
-        if isinstance(dtype, str):
-            if hasattr(torch, dtype):
-                dtype = getattr(torch, dtype)
-        return dtype
+        return next(param.dtype for param in self.parameters() if param.is_floating_point())
 
     def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
         """
@@ -827,13 +872,8 @@ class ModuleUtilsMixin:
         return encoder_extended_attention_mask
 
     @staticmethod
-    def create_extended_attention_mask_for_decoder(input_shape, attention_mask
-
-            warnings.warn(
-                "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
-            )
-        else:
-            device = attention_mask.device
+    def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
+        device = attention_mask.device
         batch_size, seq_length = input_shape
         seq_ids = torch.arange(seq_length, device=device)
         causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
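A worked example of the causal-mask broadcast used in `create_extended_attention_mask_for_decoder` (standalone, with toy sizes):

import torch

batch_size, seq_length = 1, 4
seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
print(causal_mask[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])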
@@ -857,8 +897,7 @@ class ModuleUtilsMixin:
         self,
         attention_mask: Tensor,
         input_shape: tuple[int, ...],
-
-        dtype: Optional[torch.dtype] = None,
+        dtype: torch.dtype | None = None,
     ) -> Tensor:
         """
         Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -875,12 +914,6 @@ class ModuleUtilsMixin:
         if dtype is None:
             dtype = self.dtype
 
-        if not (attention_mask.dim() == 2 and self.config.is_decoder):
-            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
-            if device is not None:
-                warnings.warn(
-                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
-                )
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
@@ -891,7 +924,7 @@ class ModuleUtilsMixin:
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.config.is_decoder:
             extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
-                input_shape, attention_mask
+                input_shape, attention_mask
             )
         else:
             extended_attention_mask = attention_mask[:, None, None, :]
@@ -972,54 +1005,52 @@ class EmbeddingAccessMixin:
             `nn.Module`: A torch module mapping vocabulary to hidden states.
         """
 
-        # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
-        # for most NLP models), and if so, return it.
-
         name = getattr(self, "_input_embed_layer", "embed_tokens")
 
+        # 1) Direct attribute (most NLP models).
         if (default_embedding := getattr(self, name, None)) is not None:
             return default_embedding
-        # 2)
+        # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision/audio models).
+        if hasattr(self, "embeddings") and hasattr(self.embeddings, name):
+            return getattr(self.embeddings, name)
+        # 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides).
+        if hasattr(self, "model") and hasattr(self.model, name):
+            return getattr(self.model, name)
 
-        if hasattr(self, "
-
+        if hasattr(self, "base_model"):
+            base_model = self.base_model
+            if base_model is not None and base_model is not self:
+                return base_model.get_input_embeddings()
 
-
-
-
-        else:
-            base_model = getattr(self, "base_model_prefix", None)
-            if base_model is not None:
-                base_model = getattr(self, base_model, None)
-                if base_model is not None and base_model is not self:
-                    return base_model.get_input_embeddings()
-        raise NotImplementedError(
-            f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
-            "please override in the subclass."
-        )
+        raise NotImplementedError(
+            f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
+        )
 
     def set_input_embeddings(self, value: nn.Module):
         """Fallback setter that handles **~70%** of models in the code-base.
 
         Order of attempts:
-        1. `self
-        2. `self.
-        3.
-        4.
+        1. `self.<_input_embed_layer>` (direct attribute)
+        2. `self.embeddings.<_input_embed_layer>` (nested embeddings for vision/audio models)
+        3. `self.model.<_input_embed_layer>` (encoder/decoder models)
+        4. delegate to the *base model* if one exists
+        5. otherwise raise `NotImplementedError` so subclasses still can (and
            should) override for exotic layouts.
         """
 
-        # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
         name = getattr(self, "_input_embed_layer", "embed_tokens")
-
-
-        # 2) as well as vanilla decoder‑only architectures
-        elif hasattr(self, name):
+        # 1) Direct attribute (most NLP models)
+        if hasattr(self, name):
             setattr(self, name, value)
-        #
-        elif
-
-
+        # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models)
+        elif hasattr(self, "embeddings") and hasattr(self.embeddings, name):
+            setattr(self.embeddings, name, value)
+        # 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
+        elif hasattr(self, "model") and hasattr(self.model, name):
+            setattr(self.model, name, value)
+        # 4) recurse once into the registered *base* model (e.g. for encoder/decoder)
+        elif hasattr(self, "base_model") and self.base_model is not self:
+            self.base_model.set_input_embeddings(value)
         else:
             raise NotImplementedError(
                 f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
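A toy illustration of the lookup order in the getter above, with simplified stand-in classes (assumptions, not transformers models): the wrapper has no direct `embed_tokens`, so the embeddings are resolved through `self.model`.

from torch import nn

class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 8)

class TinyForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TinyBackbone()

    def get_input_embeddings(self):
        name = "embed_tokens"
        if hasattr(self, name):                                   # 1) direct attribute
            return getattr(self, name)
        if hasattr(self, "model") and hasattr(self.model, name):  # 3) encoder/decoder wrapper
            return getattr(self.model, name)
        raise NotImplementedError

print(TinyForCausalLM().get_input_embeddings())  # Embedding(10, 8)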
@@ -1080,8 +1111,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
     _keep_in_fp32_modules_strict = None
 
-    dtype_plan:
-    _dtype: Optional[Union[str, torch.dtype]] = torch.get_default_dtype()
+    dtype_plan: dict[str, torch.dtype] | None = None
 
     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
     # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -1141,7 +1171,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
     # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
     # Possible values are: text, image, video, audio and time
-    input_modalities:
+    input_modalities: str | list[str] = "text"  # most models are text
 
     @property
     @torch._dynamo.allow_in_graph
@@ -1226,14 +1256,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
             )
         self.config = config
-        default_dtype = torch.get_default_dtype()
-        self._dtype = default_dtype
 
         # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
         # setting it recursively)
         self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
             self.config._attn_implementation, is_init_check=True
         )
+        # Check the experts implementation is supported, or set it if not yet set (on the internal attr, to avoid
+        # setting it recursively)
+        self.config._experts_implementation_internal = self._check_and_adjust_experts_implementation(
+            self.config._experts_implementation
+        )
         if self.can_generate():
             self.generation_config = GenerationConfig.from_model_config(config)
 
@@ -1349,7 +1382,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     def pp_plan(self, plan: dict[str, tuple[str, str]]):
         self._pp_plan = plan
 
-    def dequantize(self):
+    def dequantize(self, dtype=None):
         """
         Potentially dequantize the model in case it has been quantized by a quantization method that support
         dequantization.
@@ -1359,7 +1392,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if hf_quantizer is None:
             raise ValueError("You need to first quantize your model in order to dequantize it")
 
-        return hf_quantizer.dequantize(self)
+        return hf_quantizer.dequantize(self, dtype=dtype)
 
     def _backward_compatibility_gradient_checkpointing(self):
         if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
@@ -1367,7 +1400,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             # Remove the attribute now that is has been consumed, so it's no saved in the config.
             delattr(self.config, "gradient_checkpointing")
 
-    def add_model_tags(self, tags:
+    def add_model_tags(self, tags: list[str] | str) -> None:
         r"""
         Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
         not overwrite existing tags in the model.
@@ -1400,7 +1433,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 self.model_tags.append(tag)
 
     @classmethod
-    @restore_default_dtype
     def _from_config(cls, config, **kwargs):
         """
         All context managers that the model should be initialized under go here.
@@ -1409,9 +1441,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             dtype (`torch.dtype`, *optional*):
                 Override the default `dtype` and load the model under this dtype.
         """
-        # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
-        # a warning is raised that dtype should be fp16. Since we never pass dtype from within
-        # modeling code, we can try to infer it here same way as done in `from_pretrained`
         # For BC on the old `torch_dtype`
        dtype = kwargs.pop("dtype", config.dtype)
        if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
@@ -1421,67 +1450,32 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if isinstance(dtype, str):
             dtype = getattr(torch, dtype)
 
-        # override default dtype if needed
-        dtype_orig = None
-        if dtype is not None:
-            dtype_orig = cls._set_default_dtype(dtype)
-
         # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
         if "attn_implementation" in kwargs:
             config._attn_implementation = kwargs.pop("attn_implementation")
 
+        # If passing `experts_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
+        if "experts_implementation" in kwargs:
+            config._experts_implementation = kwargs.pop("experts_implementation")
+
+        init_contexts = []
+        if dtype is not None:
+            init_contexts.append(local_torch_dtype(dtype, cls.__name__))
+
         if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
             # this immediately partitions the model across all gpus, to avoid the overhead in time
             # and memory copying it on CPU or each GPU first
             import deepspeed
 
-            init_contexts
-            with ContextManagers(init_contexts):
-                model = cls(config, **kwargs)
+            init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
 
-
+        # Instantiate the model
+        with ContextManagers(init_contexts):
             model = cls(config, **kwargs)
 
-        # restore default dtype if it was modified
-        if dtype_orig is not None:
-            torch.set_default_dtype(dtype_orig)
-
         return model
 
-    @classmethod
-    def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
-        """
-        Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
-        under specific dtype.
-
-        Args:
-            dtype (`torch.dtype`):
-                a floating dtype to set to.
-
-        Returns:
-            `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
-            modified. If it wasn't, returns `None`.
-
-        Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
-        `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
-        """
-        if isinstance(dtype, str):
-            if hasattr(torch, dtype):
-                dtype = getattr(torch, dtype)
-            else:
-                raise ValueError(f"Received an invalid string dtype: {dtype}")
-        if not dtype.is_floating_point:
-            raise ValueError(
-                f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
-            )
-
-        logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
-        dtype_orig = torch.get_default_dtype()
-        torch.set_default_dtype(dtype)
-        cls._dtype = dtype
-        return dtype_orig
-
     @property
     def base_model(self) -> nn.Module:
         """
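The dtype-scoped init context replaces the old set/restore-default-dtype dance. A standalone context manager in the same spirit as `local_torch_dtype` (a simplified sketch, not the transformers helper):

import contextlib
import torch
from torch import nn

@contextlib.contextmanager
def local_default_dtype(dtype: torch.dtype):
    # Temporarily change the global default dtype, then restore it
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)

with local_default_dtype(torch.bfloat16):
    layer = nn.Linear(4, 4)   # created in bfloat16
print(layer.weight.dtype)     # torch.bfloat16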
@@ -1558,7 +1552,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             return True
 
         if is_torch_xpu_available():
-            logger.info(
+            logger.info(
+                f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU."
+            )
             return True
 
         if importlib.util.find_spec("flash_attn") is None:
@@ -1727,6 +1723,22 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         return True
 
+    def _grouped_mm_can_dispatch(self) -> bool:
+        """
+        Check the availability of Grouped MM for a given model.
+        """
+
+        if not self._can_set_experts_implementation():
+            raise ValueError(f"{self.__class__.__name__} does not support setting experts implementation.")
+
+        if not is_grouped_mm_available():
+            raise ImportError(
+                "PyTorch Grouped MM requirements in Transformers are not met. Please install torch>=2.9.0."
+            )
+
+        # If no error raised by this point, we can return `True`
+        return True
+
     def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
         """
         Check the availability of Flex Attention for a given model.
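For intuition on the experts implementations that `_grouped_mm_can_dispatch` gates, here is a standalone comparison of an "eager" per-expert loop against a `torch.bmm`-style batched dispatch (grouped MM itself needs torch>=2.9.0, so it is left out of this sketch):

import torch

num_experts, tokens_per_expert, hidden = 4, 3, 8
x = torch.randn(num_experts, tokens_per_expert, hidden)
w = torch.randn(num_experts, hidden, hidden)

# eager: loop over experts, one matmul each
eager_out = torch.stack([x[e] @ w[e] for e in range(num_experts)])

# batched_mm: a single batched matmul over the expert dimension
bmm_out = torch.bmm(x, w)

print(torch.allclose(eager_out, bmm_out, atol=1e-5))  # True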
@@ -1755,7 +1767,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         return True
 
     def _check_and_adjust_attn_implementation(
-        self, attn_implementation:
+        self, attn_implementation: str | None, is_init_check: bool = False
     ) -> str:
         """
         Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if
@@ -1790,14 +1802,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             and is_kernels_available()
             and not is_torch_npu_available()
         ):
-
-
-
-
-
-
-            else:
-                applicable_attn_implementation = "kernels-community/vllm-flash-attn3"
+            applicable_attn_implementation = FLASH_ATTN_KERNEL_FALLBACK[attn_implementation.removeprefix("paged|")]
+
+            if is_torch_xpu_available() and attn_implementation.removeprefix("paged|") == "flash_attention_2":
+                # On XPU, kernels library is the native implementation
+                # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
+                requested_original_flash_attn = False
 
             if is_paged:
                 applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
@@ -1837,7 +1847,20 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         return applicable_attn_implementation
 
-    def
+    def _check_and_adjust_experts_implementation(self, experts_implementation: str | None) -> str:
+        """
+        Check that the `experts_implementation` exists and is supported by the models.
+
+        Args:
+            experts_implementation (`str` or `None`):
+                The experts implementation to check for existence/validity.
+        Returns:
+            `str`: The final experts implementation to use.
+        """
+        applicable_experts_implementation = self.get_correct_experts_implementation(experts_implementation)
+        return applicable_experts_implementation
+
+    def get_correct_attn_implementation(self, requested_attention: str | None, is_init_check: bool = False) -> str:
         applicable_attention = "sdpa" if requested_attention is None else requested_attention
         if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
             message = (
@@ -1871,6 +1894,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         return applicable_attention
 
+    def get_correct_experts_implementation(self, requested_experts: str | None) -> str:
+        applicable_experts = "grouped_mm" if requested_experts is None else requested_experts
+        if applicable_experts not in ["eager", "grouped_mm", "batched_mm"]:
+            message = (
+                f'Specified `experts_implementation="{applicable_experts}"` is not supported. The only possible arguments are '
+                '`experts_implementation="eager"`, `"experts_implementation=grouped_mm"` and `"experts_implementation=batched_mm"`.'
+            )
+            raise ValueError(message)
+
+        # Perform relevant checks
+        if applicable_experts == "grouped_mm":
+            try:
+                self._grouped_mm_can_dispatch()
+            except (ValueError, ImportError) as e:
+                if requested_experts == "grouped_mm":
+                    raise e
+                applicable_experts = "eager"
+
+        return applicable_experts
+
     @classmethod
     def _can_set_attn_implementation(cls) -> bool:
         """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
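A condensed sketch of the selection logic in `get_correct_experts_implementation`: `grouped_mm` is the preferred default, an implicit request silently falls back to `eager` when the dispatch check fails, and an explicit request re-raises. The boolean flag below stands in for the real availability checks.

def pick_experts_implementation(requested, grouped_mm_ok: bool) -> str:
    applicable = "grouped_mm" if requested is None else requested
    if applicable not in ("eager", "grouped_mm", "batched_mm"):
        raise ValueError(f"Unsupported experts_implementation: {applicable!r}")
    if applicable == "grouped_mm" and not grouped_mm_ok:
        if requested == "grouped_mm":
            # explicit request: surface the failure instead of silently degrading
            raise ImportError("grouped_mm requested but torch>=2.9.0 is required")
        applicable = "eager"
    return applicable

print(pick_experts_implementation(None, grouped_mm_ok=False))          # eager
print(pick_experts_implementation("batched_mm", grouped_mm_ok=False))  # batched_mm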
@@ -1889,7 +1932,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
         return True
 
-
+    @classmethod
+    def _can_set_experts_implementation(cls) -> bool:
+        """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
+        opening the file, but avoids maintaining yet another property flag.
+        """
+        class_file = sys.modules[cls.__module__].__file__
+        with open(class_file, "r") as f:
+            code = f.read()
+        # heuristic -> if we the use_experts_implementation decorator is used, then we can set it
+        return "@use_experts_implementation" in code
+
+    def set_attn_implementation(self, attn_implementation: str | dict):
         """
         Set the requested `attn_implementation` for this model.
 
@@ -1988,6 +2042,50 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             if hasattr(subconfig, "_attn_was_changed"):
                 del subconfig._attn_was_changed
 
+    def set_experts_implementation(self, experts_implementation: str | dict):
+        """
+        Set the requested `experts_implementation` for this model.
+
+        Args:
+            experts_implementation (`str` or `dict`):
+                The experts implementation to set for this model. It can be either a `str`, in which case it will be
+                dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
+                submodel will dispatch the corresponding value.
+        """
+        requested_implementation = (
+            experts_implementation
+            if not isinstance(experts_implementation, dict)
+            else experts_implementation.get("", self.config._experts_implementation)
+        )
+
+        if requested_implementation != self.config._experts_implementation:
+            requested_implementation = self._check_and_adjust_experts_implementation(requested_implementation)
+            # Apply the change (on the internal attr, to avoid setting it recursively)
+            self.config._experts_implementation_internal = requested_implementation
+
+        # Apply it to all submodels as well
+        for submodule in self.modules():
+            # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
+            # e.g. ForCausalLM has a Model inside, but no need to check it again)
+            if (
+                submodule is not self
+                and isinstance(submodule, PreTrainedModel)
+                and submodule.config.__class__ != self.config.__class__
+            ):
+                # Set the experts on the submodule
+                sub_implementation = requested_implementation
+                if isinstance(experts_implementation, dict):
+                    for subconfig_key in self.config.sub_configs:
+                        # We need to check for exact object match here, with `is`
+                        if getattr(self.config, subconfig_key) is submodule.config:
+                            sub_implementation = experts_implementation.get(
+                                subconfig_key, submodule.config._experts_implementation
+                            )
+                            break
+                # Check the module can use correctly, otherwise we raise an error if requested experts can't be set for submodule
+                sub_implementation = submodule.get_correct_experts_implementation(sub_implementation)
+                submodule.config._experts_implementation_internal = sub_implementation
+
     def enable_input_require_grads(self):
         """
         Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
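Usage sketch for the new setter (the model id below is a placeholder, not a real checkpoint):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("some-org/some-moe-model")  # placeholder id
model.set_experts_implementation("batched_mm")   # one implementation for all submodels
model.set_experts_implementation({"": "eager"})  # dict form, keyed by sub-config name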
@@ -1999,14 +2097,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         hooks = []
         seen_modules = set()
+        found_embeddings = False
 
         for module in self.modules():
             if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
                 continue
 
-
+            try:
+                input_embeddings = module.get_input_embeddings()
+            except NotImplementedError:
+                continue
 
-            if input_embeddings is None:
+            if input_embeddings is None or not hasattr(input_embeddings, "register_forward_hook"):
                 continue
 
             embedding_id = id(input_embeddings)
@@ -2015,11 +2117,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
             seen_modules.add(embedding_id)
             hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))
+            found_embeddings = True
 
         self._require_grads_hooks = hooks
         if hooks:
             # for BC
             self._require_grads_hook = hooks[0]
+        if not found_embeddings:
+            logger.warning_once(
+                f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
+                "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
+                "support those features, or set `_input_embed_layer` to the attribute name that holds the embeddings."
+            )
 
     def disable_input_require_grads(self):
         """
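A standalone sketch of the forward-hook pattern that `enable_input_require_grads` registers: with frozen embeddings (the adapter/PEFT case), the hook turns `requires_grad` back on for the embedding output so gradients can propagate through it.

import torch
from torch import nn

def make_inputs_require_grads(module, inputs, output):
    # Flip requires_grad on the embedding output, even when the weights are frozen
    output.requires_grad_(True)

embedding = nn.Embedding(10, 4)
embedding.weight.requires_grad_(False)   # frozen, e.g. when only adapters are trained
hook = embedding.register_forward_hook(make_inputs_require_grads)

out = embedding(torch.tensor([1, 2, 3]))
print(out.requires_grad)  # True
hook.remove()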
@@ -2036,7 +2145,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if hasattr(self, "_require_grads_hook"):
             del self._require_grads_hook
 
-    def get_encoder(self, modality:
+    def get_encoder(self, modality: str | None = None):
         """
         Best-effort lookup of the *encoder* module. If provided with `modality` argument,
         it looks for a modality-specific encoder in multimodal models (e.g. "image_encoder")
@@ -2068,7 +2177,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         # If this is a base transformer model (no encoder/model attributes), return self
         return self
 
-    def set_encoder(self, encoder, modality:
+    def set_encoder(self, encoder, modality: str | None = None):
         """
         Symmetric setter. Mirrors the lookup logic used in `get_encoder`.
         """
@@ -2154,14 +2263,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
             if getattr(module, "weight", None) is not None:
                 init.normal_(module.weight, mean=0.0, std=std)
-            if
+            if module.bias is not None:
                 init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
-
-
-
-
-                init.zeros_(module.weight[module.padding_idx])
+            init.normal_(module.weight, mean=0.0, std=std)
+            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
+            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
+                init.zeros_(module.weight[module.padding_idx])
         elif isinstance(module, nn.MultiheadAttention):
             # This uses torch's original init
             module._reset_parameters()
@@ -2173,10 +2281,25 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             or "RMSNorm" in module.__class__.__name__
         ):
             # Norms can exist without weights (in which case they are None from torch primitives)
-            if
+            if getattr(module, "weight", None) is not None:
                 init.ones_(module.weight)
-            if
+            if getattr(module, "bias", None) is not None:
                 init.zeros_(module.bias)
+            # And the potential buffers for the BatchNorms
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
+        # This matches all the usual RotaryEmbeddings modules
+        elif "RotaryEmbedding" in module.__class__.__name__ and hasattr(module, "original_inv_freq"):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)
 
     def _initialize_weights(self, module):
         """
@@ -2281,7 +2404,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         tied_mapping = self._tied_weights_keys
         # If the config does not specify any tying, return empty dict
-        if not self.config.tie_word_embeddings
+        if not self.config.tie_word_embeddings:
             return {}
         # If None, return empty dict
         elif tied_mapping is None:
@@ -2327,7 +2450,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         return expanded_tied_weights
 
-    def tie_weights(self, missing_keys:
+    def tie_weights(self, missing_keys: set[str] | None = None, recompute_mapping: bool = True):
         """
         Tie the model weights. If `recompute_mapping=False` (default when called internally), it will rely on the
         `model.all_tied_weights_keys` attribute, containing the `{target: source}` mapping for the tied params.
@@ -2347,30 +2470,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         tied_keys = list(tied_keys.items())
         for i, (target_param_name, source_param_name) in enumerate(tied_keys):
-            # Usually we tie a single target to a single source, but when both are missing we may later tie
-            # both the source and target to a third "backup" parameter that is present in the checkpoint, so we use
-            # a list here
-            target_param_names = [target_param_name]
-
             # This is `from_pretrained` -> let's check symmetrically in case the source key is not present
             if missing_keys is not None:
                 remove_from_missing = True
                 source_is_there = source_param_name not in missing_keys
                 target_is_there = target_param_name not in missing_keys
                 # Both are already present -> it means the config is wrong and do not reflect the actual
-                # checkpoint -> let's raise a warning and
+                # checkpoint -> let's raise a warning and NOT tie them
                 if source_is_there and target_is_there:
                     logger.warning(
                         f"The tied weights mapping and config for this model specifies to tie {source_param_name} to "
                         f"{target_param_name}, but both are present in the checkpoints, so we will NOT tie them. "
                         "You should update the config with `tie_word_embeddings=False` to silence this warning"
                     )
+                    # Remove from internal attribute to correctly reflect actual tied weights
+                    self.all_tied_weights_keys.pop(target_param_name)
                     # Skip to next iteration
                     continue
                 # We're missing the source but we have the target -> we swap them, tying the parameter that exists
                 elif not source_is_there and target_is_there:
                     target_param_name, source_param_name = source_param_name, target_param_name
-                    target_param_names = [target_param_name]
                 # Both are missing -> check other keys in case more than 2 keys are tied to the same weight
                 elif not source_is_there and not target_is_there:
                     for target_backup, source_backup in tied_keys[i + 1 :]:
@@ -2379,10 +2498,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                         if source_backup == source_param_name:
                             target_backup_is_there = target_backup not in missing_keys
                             # If the target is present, we found the correct weight to tie into (we know the source is missing)
+                            # Note here that we do not tie the missing source right now as well, as it will be done anyway when
+                            # the pair (target_backup, source_backup) becomes the main pair (target_param_name, source_param_name)
                             if target_backup_is_there:
                                 source_param_name = target_backup
-                                # Append the source as well, since both are missing we'll tie both
-                                target_param_names.append(source_param_name)
                                 break
                     # If we did not break from the loop, it was impossible to find a source key -> let's raise
                     else:
@@ -2398,19 +2517,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
             # Perform the actual tying
             source_param = self.get_parameter_or_buffer(source_param_name)
-
-
-
-
-
-
-
-
-
-
-
-
-                missing_keys.discard(target_param_name)
+            if "." in target_param_name:
+                parent_name, name = target_param_name.rsplit(".", 1)
+                parent = self.get_submodule(parent_name)
+            else:
+                name = target_param_name
+                parent = self
+            # Tie the weights
+            setattr(parent, name, source_param)
+            self._adjust_bias(parent, source_param)
+            # Remove from missing if necesary
+            if missing_keys is not None and remove_from_missing:
+                missing_keys.discard(target_param_name)
 
     def _adjust_bias(self, output_embeddings, input_embeddings):
         if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"):
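A minimal sketch of the tying step above, outside of transformers: resolve the parent of the target parameter by its dotted name and point it at the source parameter object, so both names share one tensor.

from torch import nn

class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 8)
        self.lm_head = nn.Linear(8, 10, bias=False)

model = TinyLM()
source_param = model.get_parameter("embed_tokens.weight")

target_param_name = "lm_head.weight"
parent_name, name = target_param_name.rsplit(".", 1)
parent = model.get_submodule(parent_name)
setattr(parent, name, source_param)   # tie: both names now point at the same Parameter

print(model.lm_head.weight is model.embed_tokens.weight)  # True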
@@ -2455,8 +2573,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
     def resize_token_embeddings(
         self,
-        new_num_tokens:
-        pad_to_multiple_of:
+        new_num_tokens: int | None = None,
+        pad_to_multiple_of: int | None = None,
         mean_resizing: bool = True,
     ) -> nn.Embedding:
         """
@@ -2557,8 +2675,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     def _get_resized_embeddings(
         self,
         old_embeddings: nn.Embedding,
-        new_num_tokens:
-        pad_to_multiple_of:
+        new_num_tokens: int | None = None,
+        pad_to_multiple_of: int | None = None,
         mean_resizing: bool = True,
     ) -> nn.Embedding:
         """
@@ -2715,7 +2833,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     def _get_resized_lm_head(
         self,
         old_lm_head: nn.Linear,
-        new_num_tokens:
+        new_num_tokens: int | None = None,
         transposed: bool = False,
         mean_resizing: bool = True,
     ) -> nn.Linear:
@@ -2912,7 +3030,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
         )
 
-    def get_position_embeddings(self) ->
+    def get_position_embeddings(self) -> nn.Embedding | tuple[nn.Embedding]:
         raise NotImplementedError(
             f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
             f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
@@ -2923,7 +3041,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         Maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
         initialization logic in `_init_weights`.
         """
-
+        # If we are initializing on meta device, there is no point in trying to run inits
+        if get_torch_context_manager_or_global_device() != torch.device("meta"):
            # Initialize weights
            self.initialize_weights()
            # Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
@@ -2961,7 +3080,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
            )
 
-
+        needs_embedding_grads = self.main_input_name == "input_ids"
+        # we use that also to detect whether or not we have to raise if embeddings are missing (the submodel might not have embeddings at all)
+        enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False)
+        if enable_input_grads:
            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
            # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
@@ -3019,13 +3141,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
     def save_pretrained(
         self,
-        save_directory:
+        save_directory: str | os.PathLike,
         is_main_process: bool = True,
-        state_dict:
+        state_dict: dict | None = None,
         push_to_hub: bool = False,
-        max_shard_size:
-        variant:
-        token:
+        max_shard_size: int | str = "50GB",
+        variant: str | None = None,
+        token: str | bool | None = None,
         save_peft_format: bool = True,
         save_original_format: bool = True,
         **kwargs,
@@ -3092,12 +3214,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                 " the logger on the traceback to understand the reason why the quantized model is not serializable."
             )
 
-        if "save_config" in kwargs:
-            warnings.warn(
-                "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
-            )
-            is_main_process = kwargs.pop("save_config")
-
         # we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
         if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
             raise ImportError(
@@ -3172,29 +3288,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             current_peft_config = self.peft_config[active_adapter]
             current_peft_config.save_pretrained(save_directory)
 
-        #
-        module_map = {}
-
-        # Save the model
+        # Get the model state_dict
         if state_dict is None:
-            # if any model parameters are offloaded, make module map
-            if (
-                hasattr(self, "hf_device_map")
-                and len(set(self.hf_device_map.values())) > 1
-                and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
-            ):
-                warnings.warn(
-                    "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
-                )
-                for name, module in model_to_save.named_modules():
-                    if name == "":
-                        continue
-                    module_state_dict = module.state_dict()
-
-                    for key in module_state_dict:
-                        module_map[name + f".{key}"] = module
             state_dict = model_to_save.state_dict()
 
+        # if any model parameters are offloaded, we need to know it for later
+        is_offloaded = False
+        if (
+            hasattr(self, "hf_device_map")
+            and len(set(self.hf_device_map.values())) > 1
+            and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
+        ):
+            is_offloaded = True
+            warnings.warn(
+                "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory "
+                "exceeds the `shard_size` (50GB default)"
+            )
+
         # Translate state_dict from smp to hf if saving with smp >= 1.10
         if IS_SAGEMAKER_MP_POST_1_10:
             for smp_to_hf, _ in smp.state.module_manager.translate_functions:
@@ -3211,76 +3321,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if self._tp_size is not None:
             state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
 
-        #
-
-        for name, tensor in state_dict.items():
-            if not isinstance(tensor, torch.Tensor):
-                # Sometimes in the state_dict we have non-tensor objects.
-                # e.g. in bitsandbytes we have some `str` objects in the state_dict
-                # In the non-tensor case, fall back to the pointer of the object itself
-                ptrs[id(tensor)].append(name)
-
-            elif tensor.device.type == "meta":
-                # In offloaded cases, there may be meta tensors in the state_dict.
-                # For these cases, key by the pointer of the original tensor object
-                # (state_dict tensors are detached and therefore no longer shared)
-                tensor = self.get_parameter(name)
-                ptrs[id(tensor)].append(name)
-
-            else:
-                ptrs[id_tensor_storage(tensor)].append(name)
-
-        shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-
-        # Recursively descend to find tied weight keys
-        _tied_weights_keys = set(_get_tied_weight_keys(self))
-        error_names = []
-        to_delete_names = set()
-        for names in shared_ptrs.values():
-            # Removing the keys which are declared as known duplicates on
-            # load. This allows to make sure the name which is kept is consistent.
-            if _tied_weights_keys is not None:
-                found = 0
-                for name in sorted(names):
-                    matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
-                    if matches_pattern and name in state_dict:
-                        found += 1
-                        if found < len(names):
-                            to_delete_names.add(name)
-        # We are entering a place where the weights and the transformers configuration do NOT match.
-        shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
-        # Those are actually tensor sharing but disjoint from each other, we can safely clone them
-        # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
-        for name in disjoint_names:
-            state_dict[name] = state_dict[name].clone()
-
-        # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-        # If the link between tensors was done at runtime then `from_pretrained` will not get
-        # the key back leading to random tensor. A proper warning will be shown
-        # during reload (if applicable), but since the file is not necessarily compatible with
-        # the config, better show a proper warning.
-        shared_names, identical_names = _find_identical(shared_names, state_dict)
-        # delete tensors that have identical storage
-        for inames in identical_names:
-            known = inames.intersection(to_delete_names)
-            for name in known:
-                del state_dict[name]
-            unknown = inames.difference(to_delete_names)
-            if len(unknown) > 1:
-                error_names.append(unknown)
-
-        if shared_names:
-            error_names.extend(shared_names)
-
-        if len(error_names) > 0:
-            raise RuntimeError(
-                f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
-                "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
-            )
+        # Remove tied weights as safetensors do not handle them
+        state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)
 
         # Revert all renaming and/or weight operations
         if save_original_format:
-            state_dict = revert_weight_conversion(
+            state_dict = revert_weight_conversion(model_to_save, state_dict)
 
         # Shard the model if it is too big.
         if not _hf_peft_config_loaded:
@@ -3320,47 +3366,39 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                     and reg.fullmatch(filename_no_suffix) is not None
                 ):
                     os.remove(full_filename)
+
         # Save the model
-
-
-
-
-
-        for
-
-
+        for shard_file, tensor_names in logging.tqdm(
+            state_dict_split.filename_to_tensors.items(), desc="Writing model shards"
+        ):
+            filename = os.path.join(save_directory, shard_file)
+            shard_state_dict = {}
+            for tensor_name in tensor_names:
+                # Get the tensor, and remove it from state_dict to avoid keeping the ref
+                tensor = state_dict.pop(tensor_name)
+
+                # In case of TP, get the full parameter back
+                if _is_dtensor_available and isinstance(tensor, DTensor):
+                    tensor = tensor.full_tensor()
                 # to get the correctly ordered tensor we need to repack if packed
-                if _get_parameter_tp_plan(
-
-
-
-
-                #
-
-
-
-
-
-
-
-
-
-
-
-
-            module = module_map[module_name]
-            shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
-
-            # assign shard to be the completed state dict
-            shard = shard_state_dict
-            del shard_state_dict
-            gc.collect()
-
-            # TODO: we should def parallelize this we are otherwise just waiting
-            # too much before scheduling the next write when its in a different file
-            safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
-
-        del state_dict
+                if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
+                    tensor = repack_weights(tensor, -1, self._tp_size, 2)
+
+                # If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
+                # but it would otherwise not be contained in the saved shard if we were to simply move the file
+                # or something
+                if is_offloaded and tensor.device.type == "meta":
+                    tensor = load_offloaded_parameter(model_to_save, tensor_name)
+
+                # only do contiguous after it's permuted correctly in case of TP
+                shard_state_dict[tensor_name] = tensor.contiguous()
+
+            # TODO: it would be very nice to do the writing concurrently, but safetensors never releases the GIL,
+            # so it's not possible for now....
+            # Write the shard to disk
+            safe_save_file(shard_state_dict, filename, metadata=metadata)
+            # Cleanup the data before next loop (important with offloading, so we don't blowup cpu RAM)
+            del shard_state_dict
 
         if index is None:
             path_to_weights = os.path.join(save_directory, weights_name)
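A stripped-down sketch of the shard-writing loop, using the same `huggingface_hub`/`safetensors` helpers that appear in the hunk (TP repacking, offload reloads and index handling omitted; the tiny state dict and output directory are placeholders):

import os
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

state_dict = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)}
save_directory = "./tiny-checkpoint"
os.makedirs(save_directory, exist_ok=True)

split = split_torch_state_dict_into_shards(state_dict, max_shard_size="50GB")
for shard_file, tensor_names in split.filename_to_tensors.items():
    # Build one shard at a time and write it to disk
    shard = {name: state_dict[name].contiguous() for name in tensor_names}
    save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})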
@@ -3537,19 +3575,26 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 return super().float(*args)
 
 @classmethod
-def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
+def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
+# Need to instantiate with correct dtype
+init_contexts = [local_torch_dtype(dtype, cls.__name__)]
 if is_deepspeed_zero3_enabled():
 import deepspeed
 
-init_contexts = [no_init_weights()]
 # We cannot initialize the model on meta device with deepspeed when not quantized
 if not is_quantized and not _is_ds_init_called:
 logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-init_contexts.extend(
+init_contexts.extend(
+[
+init.no_init_weights(),
+deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
+set_zero3_state(),
+]
+)
 elif is_quantized:
-init_contexts.extend([
+init_contexts.extend([torch.device("meta"), set_quantized_state()])
 else:
-init_contexts
+init_contexts.append(torch.device("meta"))
 
 return init_contexts
 
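`get_init_context` now receives the target dtype and, outside of DeepSpeed ZeRO-3, instantiates the model under `torch.device("meta")`, so parameters are created without allocating real storage. A small self-contained illustration of the two standard PyTorch mechanisms involved; the module and dtype are arbitrary examples, and `torch.set_default_dtype` only stands in for the `local_torch_dtype` helper referenced in the diff, whose implementation is not shown here:

import torch
from torch import nn

# Building a module under the meta device: parameters exist but hold no data.
with torch.device("meta"):
    layer = nn.Linear(16, 16)
print(layer.weight.device)  # meta

# Controlling the dtype used at construction time.
previous = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    layer_bf16 = nn.Linear(16, 16)
    print(layer_bf16.weight.dtype)  # torch.bfloat16
finally:
    torch.set_default_dtype(previous)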
@@ -3574,7 +3619,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
 # This is a context manager to override the default kernel mapping
 # We are calling kernelize inside this context manager using the use_kernels setter
-
+# Param inherit_mapping should be False to avoid still loading kernel from remote
+inherit_mapping = not kernel_config.use_local_kernel
+with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
 self.use_kernels = True
 # We use the default kernel mapping in .integrations.hub_kernels
 else:
@@ -3583,19 +3630,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 self.use_kernels = False
 
 @classmethod
-@restore_default_dtype
 def from_pretrained(
 cls: type[SpecificPreTrainedModelType],
-pretrained_model_name_or_path:
+pretrained_model_name_or_path: str | os.PathLike | None,
 *model_args,
-config:
-cache_dir:
+config: PreTrainedConfig | str | os.PathLike | None = None,
+cache_dir: str | os.PathLike | None = None,
 ignore_mismatched_sizes: bool = False,
 force_download: bool = False,
 local_files_only: bool = False,
-token:
+token: str | bool | None = None,
 revision: str = "main",
-use_safetensors:
+use_safetensors: bool | None = True,
 weights_only: bool = True,
 **kwargs,
 ) -> SpecificPreTrainedModelType:
@@ -3692,10 +3738,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 "org/model@main"
 "org/model:custom_kernel"
 "org/model@v1.2.3:custom_kernel"
+experts_implementation (`str`, *optional*):
+The experts implementation to use in the model (if relevant). Can be any of:
+
+- `"eager"` (sequential implementation of the experts matrix multiplications).
+- `"batched_mm"` (using [`torch.bmm`](https://pytorch.org/docs/stable/generated/torch.bmm.html)).
+- `"grouped_mm"` (using [`torch._grouped_mm`](https://docs.pytorch.org/docs/main/generated/torch.nn.functional.grouped_mm.html)).
+
+By default, if available, `grouped_mm` will be used for torch>=2.9.0. The default is otherwise the sequential `"eager"` implementation.
 
 > Parameters for big model inference
 
-dtype (`str` or `torch.dtype`, *optional
+dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"`):
 Override the default `torch_dtype` and load the model under a specific `dtype`. The different options
 are:
 
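The new `experts_implementation` argument documented above is forwarded to the config (it is popped into `config._experts_implementation` in the next hunk). A hedged usage sketch: the checkpoint name is a placeholder for any MoE checkpoint, and `"grouped_mm"` requires torch>=2.9.0 as the docstring notes:

from transformers import AutoModelForCausalLM

# Placeholder checkpoint name; any checkpoint with an experts module would do.
model = AutoModelForCausalLM.from_pretrained(
    "org/some-moe-model",
    dtype="auto",                         # keep the checkpoint dtype (the documented default)
    experts_implementation="grouped_mm",  # per the docstring above, needs torch>=2.9.0
)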
@@ -3915,8 +3969,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 if "attn_implementation" in kwargs:
 config._attn_implementation = kwargs.pop("attn_implementation")
 
-
-config
+if "experts_implementation" in kwargs:
+config._experts_implementation = kwargs.pop("experts_implementation")
+
+hf_quantizer, config, device_map = get_hf_quantizer(
+config, quantization_config, device_map, weights_only, user_agent
 )
 
 if gguf_file:
@@ -3963,33 +4020,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 ]
 
 # Find the correct dtype based on current state
-config, dtype
-
+config, dtype = _get_dtype(
+dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only, hf_quantizer
 )
 
 config.name_or_path = pretrained_model_name_or_path
-model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
+model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called)
 config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
 with ContextManagers(model_init_context):
 # Let's make sure we don't run the init function of buffer modules
 model = cls(config, *model_args, **model_kwargs)
 
+if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
+hf_quantizer.preprocess_model(
+model=model,
+dtype=dtype,
+device_map=device_map,
+checkpoint_files=checkpoint_files,
+use_kernels=use_kernels,
+)
+
 # Obtain the weight conversion mapping for this model if any are registered
 weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
 
-# make sure we use the model's config since the __init__ call might have copied it
-config = model.config
-
-if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
-hf_quantizer.preprocess_model(
-model=model,
-device_map=device_map,
-keep_in_fp32_modules=model._keep_in_fp32_modules,  # TODO prob no longer needed?
-config=config,
-checkpoint_files=checkpoint_files,
-use_kernels=use_kernels,
-)
-
 if _torch_distributed_available and device_mesh is not None:  # add hooks to nn.Modules: no weights
 model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
 
@@ -3997,10 +4050,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 if device_map is not None:
 device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
 
-# restore default dtype
-if dtype_orig is not None:
-torch.set_default_dtype(dtype_orig)
-
 # Finalize model weight initialization
 model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
 model,
@@ -4011,6 +4060,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 sharded_metadata=sharded_metadata,
 device_map=device_map,
 disk_offload_folder=offload_folder,
+offload_buffers=offload_buffers,
 dtype=dtype,
 hf_quantizer=hf_quantizer,
 device_mesh=device_mesh,
@@ -4018,7 +4068,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 weight_mapping=weight_conversions,
 )
 
-model.eval()  # Set model in evaluation mode to deactivate
+model.eval()  # Set model in evaluation mode to deactivate Dropout modules by default
 model.set_use_kernels(use_kernels, kernel_config)
 
 # If it is a model with generation capabilities, attempt to load generation files (generation config,
@@ -4034,13 +4084,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 **kwargs,
 )
 
-#
-if device_map is not None and
+# If the device_map has more than 1 device: dispatch model with hooks on all devices
+if device_map is not None and len(set(device_map.values())) > 1:
 accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)
 
 if hf_quantizer is not None:
 model.hf_quantizer = hf_quantizer
-hf_quantizer.postprocess_model(
+hf_quantizer.postprocess_model(
+model
+)  # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
 
 if _adapter_model_path is not None:
 adapter_kwargs["key_mapping"] = key_mapping
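The dispatch condition rewritten above only calls `accelerate_dispatch` when the device map actually spans more than one device. A tiny illustration of that check on hypothetical device maps (module names and devices are made up):

single_device = {"model.embed_tokens": 0, "model.layers.0": 0, "lm_head": 0}
multi_device = {"model.embed_tokens": 0, "model.layers.0": 1, "lm_head": "cpu"}

for device_map in (single_device, multi_device):
    needs_dispatch = device_map is not None and len(set(device_map.values())) > 1
    print(needs_dispatch)  # False for the single-device map, True for the multi-device one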
@@ -4065,18 +4117,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 def _load_pretrained_model(
 cls,
 model: "PreTrainedModel",
-state_dict:
-checkpoint_files:
-pretrained_model_name_or_path:
+state_dict: dict | None,
+checkpoint_files: list[str] | None,
+pretrained_model_name_or_path: str | None,
 ignore_mismatched_sizes: bool = False,
-sharded_metadata:
-device_map:
-disk_offload_folder:
-
-
+sharded_metadata: dict | None = None,
+device_map: dict | None = None,
+disk_offload_folder: str | None = None,
+offload_buffers: bool = False,
+dtype: torch.dtype | None = None,
+hf_quantizer: HfQuantizer | None = None,
 device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
 weights_only: bool = True,
-weight_mapping:
+weight_mapping: Sequence[WeightConverter | WeightRenaming] | None = None,
 ):
 is_quantized = hf_quantizer is not None
 is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
@@ -4086,6 +4139,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
 # Model's definition arriving here is final (TP hooks added, quantized layers replaces)
 expected_keys = list(model.state_dict().keys())
+
 if logger.level >= logging.WARNING:
 verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))
 
@@ -4108,7 +4162,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 expanded_device_map = expand_device_map(device_map, expected_keys)
 caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
 
-tp_plan = getattr(model, "_tp_plan", None)
 error_msgs = []
 
 if is_deepspeed_zero3_enabled() and not is_quantized:
@@ -4117,9 +4170,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 for ckpt_file in checkpoint_files:
 merged_state_dict.update(load_state_dict(ckpt_file, map_location="cpu", weights_only=weights_only))
 state_dict = merged_state_dict
-error_msgs
+error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict)
 # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
-
+unexpected_keys, mismatched_keys, conversion_errors = set(), set(), set()
 else:
 all_pointer = set()
 # Checkpoints are safetensors
@@ -4143,17 +4196,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
 missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, conversion_errors = (
 convert_and_load_state_dict_in_model(
-model,
-merged_state_dict,
-weight_mapping,
-tp_plan,
-hf_quantizer,
-dtype,
-device_map,
-model.dtype_plan,
-device_mesh,
-disk_offload_index,
-disk_offload_folder,
+model=model,
+state_dict=merged_state_dict,
+weight_mapping=weight_mapping,
+tp_plan=model._tp_plan,
+hf_quantizer=hf_quantizer,
+dtype=dtype,
+device_map=device_map,
+dtype_plan=model.dtype_plan,
+device_mesh=device_mesh,
+disk_offload_index=disk_offload_index,
+disk_offload_folder=disk_offload_folder,
+offload_buffers=offload_buffers,
 )
 )
 
@@ -4164,12 +4218,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
 model.mark_tied_weights_as_initialized()
 
-# Move missing (and potentially mismatched) keys back to
-# loading the weights as they
-
-model.
+# Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
+# meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
+missing_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
+model._move_missing_keys_from_meta_to_device(missing_and_mismatched, device_map, device_mesh, hf_quantizer)
 
-# Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `
+# Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
 model._initialize_missing_keys(is_quantized)
 
 # Tie the weights
@@ -4178,34 +4232,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 # Adjust missing and unexpected keys
 missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys)
 
-# Post-processing for tensor parallelism
-if device_mesh is not None:
-# When using TP, the device map is a single device for all parameters
-tp_device = list(device_map.values())[0]
-# This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
-# not part of the state_dict (persistent=False)
-for buffer in model.buffers():  # TODO to avoid this buffer could be added to the ckpt
-if buffer.device != tp_device:
-buffer.data = buffer.to(tp_device)
-
-# In this case, the top-most task module weights were not moved to device and parallelized as they
-# were not part of the loaded weights: do it now
-if missing_keys:
-state_dict = model.state_dict()
-for name in missing_keys:
-param = state_dict[name]
-# Shard the param
-shard_and_distribute_module(
-model,
-param.to(tp_device),
-param,
-name,
-None,
-False,
-device_mesh.get_local_rank(),
-device_mesh,
-)
-
 log_state_dict_report(
 model=model,
 pretrained_model_name_or_path=pretrained_model_name_or_path,
@@ -4381,7 +4407,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 )
 self._use_kernels = False
 
-def get_compiled_call(self, compile_config:
+def get_compiled_call(self, compile_config: CompileConfig | None) -> Callable:
 """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
 non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
 want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
@@ -4403,33 +4429,54 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 def is_backend_compatible(cls):
 return cls._supports_attention_backend
 
-def
-self,
+def _move_missing_keys_from_meta_to_device(
+self,
+missing_keys: list[str],
+device_map: dict | None,
+device_mesh: "torch.distributed.device_mesh.DeviceMesh | None",
+hf_quantizer: HfQuantizer | None,
 ) -> None:
-"""Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts)
-from meta device to cpu.
+"""Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts)
+back from meta device to their device according to the `device_map` if any, else cpu. Takes care of sharding those
+missing parameters if `device_mesh` is provided, i.e. we are using TP.
+All non-persistent buffers are also moved back to the correct device (they are not part of the state_dict, but are
+not missing either).
 """
 is_quantized = hf_quantizer is not None
+# This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
+if is_deepspeed_zero3_enabled() and not is_quantized:
+return
 
 # In this case we need to move everything back
 if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
-# We only do it for the parameters, as the buffers are not initialized on the meta device by default
 for key, param in self.named_parameters():
-value = torch.empty_like(param,
+value = torch.empty_like(param, device="cpu")
+_load_parameter_into_model(self, key, value)
+for key, buffer in self.named_buffers():
+value = torch.empty_like(buffer, device="cpu")
 _load_parameter_into_model(self, key, value)
 return
 
-model_state_dict = self.state_dict()
 # The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway)
 # This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they
 # will be re-initialized for nothing (which can be quite long)
 for key in missing_keys - self.all_tied_weights_keys.keys():
-param =
-
-
-
-
-
+param = self.get_parameter_or_buffer(key)
+param_device = get_device(device_map, key, valid_torch_device=True)
+value = torch.empty_like(param, device=param_device)
+# For TP, we may need to shard the param
+if device_mesh is not None:
+shard_and_distribute_module(
+self, value, param, key, None, False, device_mesh.get_local_rank(), device_mesh
+)
+# Otherwise, just move it to device
+else:
+_load_parameter_into_model(self, key, value)
+# We need to move back non-persistent buffers as well, as they are not part of loaded weights anyway
+for key, buffer in self.named_non_persistent_buffers():
+buffer_device = get_device(device_map, key, valid_torch_device=True)
+value = torch.empty_like(buffer, device=buffer_device)
+_load_parameter_into_model(self, key, value)
 
 def _initialize_missing_keys(self, is_quantized: bool) -> None:
 """
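`_move_missing_keys_from_meta_to_device` above materializes each missing parameter with `torch.empty_like(..., device=...)` and loads it back into the model. A minimal self-contained illustration of that meta-to-device round trip on a toy module; directly reassigning the parameter here is only a stand-in for the `_load_parameter_into_model` utility used in the diff, whose implementation is not shown:

import torch
from torch import nn

# A toy module created on the meta device: its parameter has a shape/dtype but no storage.
with torch.device("meta"):
    module = nn.Linear(8, 8, bias=False)
print(module.weight.device)  # meta

# Materialize the "missing" parameter on a real device, mirroring torch.empty_like(param, device=...)
value = torch.empty_like(module.weight, device="cpu")

# Stand-in for _load_parameter_into_model: replace the meta parameter with the materialized tensor.
module.weight = nn.Parameter(value)
print(module.weight.device)  # cpu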
@@ -4457,8 +4504,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 ) -> tuple[set[str], set[str]]:
 """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
 raising unneeded warnings/errors.
-Also, set the `_is_hf_initialized` on tied weight keys, to avoid initializing them as they are going to
-be tied anyway.
 """
 # Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model
 # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
@@ -4517,6 +4562,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
 raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")
 
+def named_non_persistent_buffers(
+self, recurse: bool = True, remove_duplicate: bool = True
+) -> Iterator[tuple[str, torch.Tensor]]:
+"""Similar to `named_buffers`, but only yield non-persistent ones. It is handy as it's not perfectly straightforward
+to know if they are persistent or not"""
+for name, tensor in self.named_buffers(recurse=recurse, remove_duplicate=remove_duplicate):
+# We have to grab the parent here, as the attribute `_non_persistent_buffers_set` is on the immediate
+# parent only
+parent, buf_name = name.rsplit(".", 1) if "." in name else ("", name)
+parent = self.get_submodule(parent)
+if buf_name in parent._non_persistent_buffers_set:
+yield name, tensor
+
 def train(self, mode: bool = True):
 out = super().train(mode)
 if self.use_kernels:
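The new `named_non_persistent_buffers` helper relies on PyTorch's per-module `_non_persistent_buffers_set`, which records buffers registered with `persistent=False`. A short standalone sketch of the same lookup on a toy module (the module itself is an example, not from the diff):

import torch
from torch import nn

class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("persistent_buf", torch.zeros(2))
        self.register_buffer("inv_freq", torch.ones(2), persistent=False)

root = nn.Sequential(ToyModule())

for name, tensor in root.named_buffers():
    # The flag lives on the immediate parent module, hence the rsplit on the dotted buffer name.
    parent_name, buf_name = name.rsplit(".", 1) if "." in name else ("", name)
    parent = root.get_submodule(parent_name)
    if buf_name in parent._non_persistent_buffers_set:
        print(f"{name} is non-persistent")  # only "0.inv_freq" is printed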
@@ -4559,7 +4617,7 @@ def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
 return model
 
 
-def is_accelerator_device(device:
+def is_accelerator_device(device: str | int | torch.device) -> bool:
 """Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
 a proper `torch.device`.
 """
@@ -4569,7 +4627,41 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
 return torch.device(device).type not in ["meta", "cpu"]
 
 
-def
+def get_total_byte_count(
+model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: HfQuantizer | None = None
+):
+"""
+This utility function calculates the total bytes count needed to load the model on each device.
+This is useful for caching_allocator_warmup as we want to know how much cache we need to pre-allocate.
+"""
+
+total_byte_count = defaultdict(lambda: 0)
+tied_param_names = model.all_tied_weights_keys.keys()
+tp_plan = model._tp_plan if torch.distributed.is_available() and torch.distributed.is_initialized() else []
+
+for param_name, device in accelerator_device_map.items():
+# Skip if the parameter has already been accounted for (tied weights)
+if param_name in tied_param_names:
+continue
+
+param = model.get_parameter_or_buffer(param_name)
+
+if hf_quantizer is not None:
+dtype_size = hf_quantizer.param_element_size(model, param_name, param)
+else:
+dtype_size = param.element_size()
+
+param_byte_count = param.numel() * dtype_size
+
+if len(tp_plan) > 0:
+is_part_of_plan = _get_parameter_tp_plan(param_name, tp_plan, is_weight=True) is not None
+param_byte_count //= torch.distributed.get_world_size() if is_part_of_plan else 1
+
+total_byte_count[device] += param_byte_count
+return total_byte_count
+
+
+def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: HfQuantizer | None):
 """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
 device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
 the model, which is actually the loading speed bottleneck.
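`get_total_byte_count` above sums `numel() * element_size()` per target device, skipping tied weights. A tiny self-contained version of that accounting over a toy module and a hypothetical device map (no quantizer, no tensor parallelism):

from collections import defaultdict

import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 32, bias=False), nn.Linear(32, 16, bias=False))
# Hypothetical expanded device map: parameter name -> device
device_map = {"0.weight": torch.device("cuda:0"), "1.weight": torch.device("cuda:0")}

total_byte_count = defaultdict(int)
params = dict(model.named_parameters())
for param_name, device in device_map.items():
    param = params[param_name]
    # fp32 parameters: element_size() is 4 bytes
    total_byte_count[device] += param.numel() * param.element_size()

print(dict(total_byte_count))  # {cuda:0: 4096} -> (16*32 + 32*16) params * 4 bytes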
@@ -4588,8 +4680,6 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
 - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
 However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
 """
-factor = 2 if hf_quantizer is None else hf_quantizer.get_accelerator_warm_up_factor()
-
 # Remove disk, cpu and meta devices, and cast to proper torch.device
 accelerator_device_map = {
 param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
@@ -4597,40 +4687,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
 if not accelerator_device_map:
 return
 
-
-tp_plan_regex = (
-re.compile("|".join([re.escape(plan) for plan in tp_plan]))
-if _torch_distributed_available and torch.distributed.is_initialized()
-else None
-)
-total_byte_count = defaultdict(lambda: 0)
-tied_param_names = model.all_tied_weights_keys.keys()
-for param_name, device in accelerator_device_map.items():
-# Skip if the parameter has already been accounted for (tied weights)
-if param_name in tied_param_names:
-continue
-
-# For example in the case of MXFP4 quantization, we need to update the param name to the original param name
-# because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
-if hf_quantizer is not None:
-param_name = hf_quantizer.get_param_name(param_name)
-
-try:
-param = model.get_parameter_or_buffer(param_name)
-except AttributeError:
-# TODO: for now let's skip if we can't find the parameters
-if hf_quantizer is not None:
-continue
-raise AttributeError(f"Parameter {param_name} not found in model")
-
-# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-param_byte_count = param.numel() * param.element_size()
-
-if tp_plan_regex is not None:
-generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
-param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
-
-total_byte_count[device] += param_byte_count
+total_byte_count = get_total_byte_count(model, accelerator_device_map, hf_quantizer)
 
 # This will kick off the caching allocator to avoid having to Malloc afterwards
 for device, byte_count in total_byte_count.items():
@@ -4650,9 +4707,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
 unused_memory = torch_accelerator_module.memory_reserved(
 index
 ) - torch_accelerator_module.memory_allocated(index)
-byte_count = max(0, byte_count - unused_memory)
-#
-_ = torch.empty(byte_count //
+byte_count = int(max(0, byte_count - unused_memory))
+# We divide by 2 here as we allocate in fp16
+_ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
 
 
 class AttentionInterface(GeneralInterface):
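The warmup now allocates the computed byte budget as a single fp16 tensor: since `torch.float16` elements are 2 bytes each, `torch.empty(byte_count // 2, dtype=torch.float16, ...)` reserves roughly `byte_count` bytes in one allocation. A quick check of that arithmetic on CPU (the byte count is an arbitrary example):

import torch

byte_count = 1_000_000  # pretend this is what the warmup computed for a device
warmup = torch.empty(byte_count // 2, dtype=torch.float16, device="cpu", requires_grad=False)
# element_size() is 2 for fp16, so numel * element_size recovers the requested byte count (up to rounding)
print(warmup.numel() * warmup.element_size())  # 1000000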