transformers 5.0.0__py3-none-any.whl → 5.0.0rc0__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registry.
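The per-file `+added -removed` counts below can be approximated locally. A minimal sketch, assuming both wheels have already been downloaded (for example with `pip download`); the two filenames are placeholder paths, and counts for binary files will not be meaningful:

```python
# Sketch: reproduce per-file "+added -removed" counts by diffing two wheels.
import difflib
import zipfile

OLD_WHEEL = "transformers-5.0.0-py3-none-any.whl"     # placeholder path
NEW_WHEEL = "transformers-5.0.0rc0-py3-none-any.whl"  # placeholder path


def read_text_files(path: str) -> dict[str, list[str]]:
    """Map every file inside a wheel (a zip archive) to its decoded lines."""
    with zipfile.ZipFile(path) as zf:
        return {
            info.filename: zf.read(info.filename)
            .decode("utf-8", errors="replace")
            .splitlines()
            for info in zf.infolist()
            if not info.is_dir()
        }


old_files = read_text_files(OLD_WHEEL)
new_files = read_text_files(NEW_WHEEL)

# For files present in both wheels, count added/removed lines from a unified diff.
for name in sorted(old_files.keys() & new_files.keys()):
    added = removed = 0
    for line in difflib.unified_diff(old_files[name], new_files[name], lineterm=""):
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    if added or removed:
        print(f"- {name} +{added} -{removed}")
```

Renamed files (shown below with the `{old → new}` brace notation) would additionally require matching old and new paths before diffing; the sketch above only covers files whose paths are unchanged.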
- transformers/__init__.py +36 -55
- transformers/activations.py +1 -1
- transformers/audio_utils.py +33 -32
- transformers/cache_utils.py +139 -32
- transformers/cli/chat.py +3 -3
- transformers/cli/serve.py +19 -49
- transformers/cli/transformers.py +1 -2
- transformers/configuration_utils.py +155 -129
- transformers/conversion_mapping.py +22 -158
- transformers/convert_slow_tokenizer.py +17 -227
- transformers/core_model_loading.py +185 -528
- transformers/data/data_collator.py +4 -12
- transformers/data/processors/glue.py +1 -0
- transformers/data/processors/utils.py +1 -0
- transformers/data/processors/xnli.py +1 -0
- transformers/dependency_versions_check.py +1 -0
- transformers/dependency_versions_table.py +7 -5
- transformers/distributed/configuration_utils.py +2 -1
- transformers/dynamic_module_utils.py +25 -24
- transformers/feature_extraction_sequence_utils.py +23 -19
- transformers/feature_extraction_utils.py +33 -64
- transformers/file_utils.py +1 -0
- transformers/generation/__init__.py +1 -11
- transformers/generation/candidate_generator.py +33 -80
- transformers/generation/configuration_utils.py +133 -189
- transformers/generation/continuous_batching/__init__.py +1 -4
- transformers/generation/continuous_batching/cache.py +25 -83
- transformers/generation/continuous_batching/cache_manager.py +45 -155
- transformers/generation/continuous_batching/continuous_api.py +147 -270
- transformers/generation/continuous_batching/requests.py +3 -51
- transformers/generation/continuous_batching/scheduler.py +105 -160
- transformers/generation/logits_process.py +128 -0
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/streamers.py +1 -0
- transformers/generation/utils.py +123 -122
- transformers/generation/watermarking.py +6 -8
- transformers/hf_argparser.py +13 -9
- transformers/hyperparameter_search.py +2 -1
- transformers/image_processing_base.py +23 -12
- transformers/image_processing_utils.py +15 -11
- transformers/image_processing_utils_fast.py +75 -85
- transformers/image_transforms.py +42 -73
- transformers/image_utils.py +32 -30
- transformers/initialization.py +0 -37
- transformers/integrations/__init__.py +2 -16
- transformers/integrations/accelerate.py +113 -58
- transformers/integrations/aqlm.py +66 -36
- transformers/integrations/awq.py +516 -45
- transformers/integrations/bitnet.py +105 -47
- transformers/integrations/bitsandbytes.py +202 -91
- transformers/integrations/deepspeed.py +4 -161
- transformers/integrations/eetq.py +82 -84
- transformers/integrations/executorch.py +1 -1
- transformers/integrations/fbgemm_fp8.py +145 -190
- transformers/integrations/finegrained_fp8.py +215 -249
- transformers/integrations/flash_attention.py +3 -3
- transformers/integrations/flex_attention.py +1 -1
- transformers/integrations/fp_quant.py +0 -90
- transformers/integrations/ggml.py +2 -11
- transformers/integrations/higgs.py +62 -37
- transformers/integrations/hub_kernels.py +8 -65
- transformers/integrations/integration_utils.py +3 -47
- transformers/integrations/mistral.py +0 -12
- transformers/integrations/mxfp4.py +80 -33
- transformers/integrations/peft.py +191 -483
- transformers/integrations/quanto.py +56 -77
- transformers/integrations/spqr.py +90 -42
- transformers/integrations/tensor_parallel.py +221 -167
- transformers/integrations/torchao.py +43 -35
- transformers/integrations/vptq.py +59 -40
- transformers/kernels/__init__.py +0 -0
- transformers/{models/pe_audio_video/processing_pe_audio_video.py → kernels/falcon_mamba/__init__.py} +3 -12
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +529 -0
- transformers/loss/loss_utils.py +0 -2
- transformers/masking_utils.py +55 -51
- transformers/model_debugging_utils.py +5 -4
- transformers/modelcard.py +194 -15
- transformers/modeling_attn_mask_utils.py +19 -19
- transformers/modeling_flash_attention_utils.py +27 -27
- transformers/modeling_gguf_pytorch_utils.py +24 -79
- transformers/modeling_layers.py +22 -21
- transformers/modeling_outputs.py +253 -242
- transformers/modeling_rope_utils.py +117 -138
- transformers/modeling_utils.py +739 -850
- transformers/models/__init__.py +0 -27
- transformers/models/afmoe/configuration_afmoe.py +33 -40
- transformers/models/afmoe/modeling_afmoe.py +54 -42
- transformers/models/afmoe/modular_afmoe.py +33 -23
- transformers/models/aimv2/configuration_aimv2.py +10 -2
- transformers/models/aimv2/modeling_aimv2.py +42 -47
- transformers/models/aimv2/modular_aimv2.py +19 -17
- transformers/models/albert/configuration_albert.py +2 -8
- transformers/models/albert/modeling_albert.py +69 -70
- transformers/models/albert/tokenization_albert.py +14 -5
- transformers/models/align/configuration_align.py +6 -8
- transformers/models/align/modeling_align.py +89 -94
- transformers/models/align/processing_align.py +30 -2
- transformers/models/altclip/configuration_altclip.py +7 -4
- transformers/models/altclip/modeling_altclip.py +103 -114
- transformers/models/altclip/processing_altclip.py +15 -2
- transformers/models/apertus/__init__.py +1 -0
- transformers/models/apertus/configuration_apertus.py +28 -23
- transformers/models/apertus/modeling_apertus.py +40 -39
- transformers/models/apertus/modular_apertus.py +38 -37
- transformers/models/arcee/configuration_arcee.py +30 -25
- transformers/models/arcee/modeling_arcee.py +39 -36
- transformers/models/arcee/modular_arcee.py +23 -20
- transformers/models/aria/configuration_aria.py +44 -31
- transformers/models/aria/image_processing_aria.py +27 -25
- transformers/models/aria/modeling_aria.py +106 -110
- transformers/models/aria/modular_aria.py +127 -118
- transformers/models/aria/processing_aria.py +35 -28
- transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +1 -0
- transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +6 -3
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +8 -6
- transformers/models/audioflamingo3/__init__.py +1 -0
- transformers/models/audioflamingo3/configuration_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +49 -58
- transformers/models/audioflamingo3/modular_audioflamingo3.py +43 -53
- transformers/models/audioflamingo3/processing_audioflamingo3.py +30 -33
- transformers/models/auto/auto_factory.py +7 -6
- transformers/models/auto/configuration_auto.py +5 -66
- transformers/models/auto/feature_extraction_auto.py +10 -14
- transformers/models/auto/image_processing_auto.py +41 -32
- transformers/models/auto/modeling_auto.py +188 -46
- transformers/models/auto/processing_auto.py +11 -24
- transformers/models/auto/tokenization_auto.py +588 -171
- transformers/models/auto/video_processing_auto.py +10 -12
- transformers/models/autoformer/configuration_autoformer.py +7 -4
- transformers/models/autoformer/modeling_autoformer.py +101 -104
- transformers/models/aya_vision/configuration_aya_vision.py +1 -4
- transformers/models/aya_vision/modeling_aya_vision.py +102 -71
- transformers/models/aya_vision/modular_aya_vision.py +74 -46
- transformers/models/aya_vision/processing_aya_vision.py +53 -25
- transformers/models/bamba/configuration_bamba.py +39 -34
- transformers/models/bamba/modeling_bamba.py +86 -82
- transformers/models/bamba/modular_bamba.py +72 -70
- transformers/models/bark/configuration_bark.py +8 -6
- transformers/models/bark/generation_configuration_bark.py +5 -3
- transformers/models/bark/modeling_bark.py +57 -54
- transformers/models/bark/processing_bark.py +41 -19
- transformers/models/bart/configuration_bart.py +6 -9
- transformers/models/bart/modeling_bart.py +126 -135
- transformers/models/barthez/tokenization_barthez.py +11 -3
- transformers/models/bartpho/tokenization_bartpho.py +7 -6
- transformers/models/beit/configuration_beit.py +11 -0
- transformers/models/beit/image_processing_beit.py +56 -53
- transformers/models/beit/image_processing_beit_fast.py +12 -10
- transformers/models/beit/modeling_beit.py +60 -69
- transformers/models/bert/configuration_bert.py +2 -12
- transformers/models/bert/modeling_bert.py +122 -114
- transformers/models/bert/tokenization_bert.py +23 -8
- transformers/models/bert/tokenization_bert_legacy.py +5 -3
- transformers/models/bert_generation/configuration_bert_generation.py +2 -17
- transformers/models/bert_generation/modeling_bert_generation.py +49 -49
- transformers/models/bert_generation/tokenization_bert_generation.py +3 -2
- transformers/models/bert_japanese/tokenization_bert_japanese.py +6 -5
- transformers/models/bertweet/tokenization_bertweet.py +3 -1
- transformers/models/big_bird/configuration_big_bird.py +9 -12
- transformers/models/big_bird/modeling_big_bird.py +109 -116
- transformers/models/big_bird/tokenization_big_bird.py +43 -16
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +117 -130
- transformers/models/biogpt/configuration_biogpt.py +2 -8
- transformers/models/biogpt/modeling_biogpt.py +76 -72
- transformers/models/biogpt/modular_biogpt.py +66 -62
- transformers/models/biogpt/tokenization_biogpt.py +5 -3
- transformers/models/bit/configuration_bit.py +1 -0
- transformers/models/bit/image_processing_bit.py +24 -21
- transformers/models/bit/image_processing_bit_fast.py +1 -0
- transformers/models/bit/modeling_bit.py +12 -25
- transformers/models/bitnet/configuration_bitnet.py +28 -23
- transformers/models/bitnet/modeling_bitnet.py +39 -36
- transformers/models/bitnet/modular_bitnet.py +6 -4
- transformers/models/blenderbot/configuration_blenderbot.py +5 -8
- transformers/models/blenderbot/modeling_blenderbot.py +96 -77
- transformers/models/blenderbot/tokenization_blenderbot.py +24 -18
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +5 -8
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +69 -79
- transformers/models/blenderbot_small/tokenization_blenderbot_small.py +3 -1
- transformers/models/blip/configuration_blip.py +10 -9
- transformers/models/blip/image_processing_blip.py +20 -17
- transformers/models/blip/image_processing_blip_fast.py +1 -0
- transformers/models/blip/modeling_blip.py +108 -117
- transformers/models/blip/modeling_blip_text.py +65 -73
- transformers/models/blip/processing_blip.py +36 -5
- transformers/models/blip_2/configuration_blip_2.py +2 -2
- transformers/models/blip_2/modeling_blip_2.py +118 -146
- transformers/models/blip_2/processing_blip_2.py +38 -8
- transformers/models/bloom/configuration_bloom.py +2 -5
- transformers/models/bloom/modeling_bloom.py +104 -77
- transformers/models/blt/configuration_blt.py +86 -94
- transformers/models/blt/modeling_blt.py +81 -238
- transformers/models/blt/modular_blt.py +65 -228
- transformers/models/bridgetower/configuration_bridgetower.py +2 -7
- transformers/models/bridgetower/image_processing_bridgetower.py +35 -34
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +16 -13
- transformers/models/bridgetower/modeling_bridgetower.py +119 -141
- transformers/models/bridgetower/processing_bridgetower.py +16 -2
- transformers/models/bros/configuration_bros.py +18 -24
- transformers/models/bros/modeling_bros.py +80 -90
- transformers/models/bros/processing_bros.py +12 -2
- transformers/models/byt5/tokenization_byt5.py +6 -4
- transformers/models/camembert/configuration_camembert.py +2 -8
- transformers/models/camembert/modeling_camembert.py +195 -196
- transformers/models/camembert/modular_camembert.py +54 -51
- transformers/models/camembert/tokenization_camembert.py +13 -6
- transformers/models/canine/configuration_canine.py +2 -4
- transformers/models/canine/modeling_canine.py +75 -84
- transformers/models/canine/tokenization_canine.py +1 -2
- transformers/models/chameleon/configuration_chameleon.py +34 -29
- transformers/models/chameleon/image_processing_chameleon.py +24 -21
- transformers/models/chameleon/image_processing_chameleon_fast.py +6 -5
- transformers/models/chameleon/modeling_chameleon.py +93 -142
- transformers/models/chameleon/processing_chameleon.py +41 -16
- transformers/models/chinese_clip/configuration_chinese_clip.py +8 -10
- transformers/models/chinese_clip/image_processing_chinese_clip.py +24 -21
- transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +1 -0
- transformers/models/chinese_clip/modeling_chinese_clip.py +92 -96
- transformers/models/chinese_clip/processing_chinese_clip.py +15 -2
- transformers/models/clap/configuration_clap.py +9 -4
- transformers/models/clap/feature_extraction_clap.py +12 -11
- transformers/models/clap/modeling_clap.py +123 -136
- transformers/models/clap/processing_clap.py +15 -2
- transformers/models/clip/configuration_clip.py +2 -4
- transformers/models/clip/image_processing_clip.py +24 -21
- transformers/models/clip/image_processing_clip_fast.py +1 -9
- transformers/models/clip/modeling_clip.py +65 -65
- transformers/models/clip/processing_clip.py +14 -2
- transformers/models/clip/tokenization_clip.py +46 -21
- transformers/models/clipseg/configuration_clipseg.py +2 -4
- transformers/models/clipseg/modeling_clipseg.py +109 -119
- transformers/models/clipseg/processing_clipseg.py +42 -19
- transformers/models/clvp/configuration_clvp.py +5 -15
- transformers/models/clvp/feature_extraction_clvp.py +10 -7
- transformers/models/clvp/modeling_clvp.py +146 -155
- transformers/models/clvp/number_normalizer.py +2 -1
- transformers/models/clvp/processing_clvp.py +20 -3
- transformers/models/clvp/tokenization_clvp.py +64 -1
- transformers/models/code_llama/tokenization_code_llama.py +44 -18
- transformers/models/codegen/configuration_codegen.py +4 -4
- transformers/models/codegen/modeling_codegen.py +53 -63
- transformers/models/codegen/tokenization_codegen.py +47 -17
- transformers/models/cohere/configuration_cohere.py +30 -25
- transformers/models/cohere/modeling_cohere.py +42 -40
- transformers/models/cohere/modular_cohere.py +29 -26
- transformers/models/cohere/tokenization_cohere.py +46 -15
- transformers/models/cohere2/configuration_cohere2.py +32 -31
- transformers/models/cohere2/modeling_cohere2.py +44 -42
- transformers/models/cohere2/modular_cohere2.py +54 -54
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +14 -13
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +58 -59
- transformers/models/cohere2_vision/modular_cohere2_vision.py +46 -45
- transformers/models/cohere2_vision/processing_cohere2_vision.py +36 -6
- transformers/models/colpali/configuration_colpali.py +1 -0
- transformers/models/colpali/modeling_colpali.py +16 -14
- transformers/models/colpali/modular_colpali.py +51 -11
- transformers/models/colpali/processing_colpali.py +52 -14
- transformers/models/colqwen2/modeling_colqwen2.py +28 -28
- transformers/models/colqwen2/modular_colqwen2.py +74 -37
- transformers/models/colqwen2/processing_colqwen2.py +52 -16
- transformers/models/conditional_detr/configuration_conditional_detr.py +2 -1
- transformers/models/conditional_detr/image_processing_conditional_detr.py +70 -67
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +36 -36
- transformers/models/conditional_detr/modeling_conditional_detr.py +87 -99
- transformers/models/conditional_detr/modular_conditional_detr.py +3 -49
- transformers/models/convbert/configuration_convbert.py +8 -11
- transformers/models/convbert/modeling_convbert.py +87 -94
- transformers/models/convbert/tokenization_convbert.py +1 -0
- transformers/models/convnext/configuration_convnext.py +1 -0
- transformers/models/convnext/image_processing_convnext.py +23 -20
- transformers/models/convnext/image_processing_convnext_fast.py +21 -16
- transformers/models/convnext/modeling_convnext.py +12 -9
- transformers/models/convnextv2/configuration_convnextv2.py +1 -0
- transformers/models/convnextv2/modeling_convnextv2.py +12 -9
- transformers/models/cpm/tokenization_cpm.py +7 -6
- transformers/models/cpm/tokenization_cpm_fast.py +5 -3
- transformers/models/cpmant/configuration_cpmant.py +1 -4
- transformers/models/cpmant/modeling_cpmant.py +40 -38
- transformers/models/cpmant/tokenization_cpmant.py +3 -1
- transformers/models/csm/configuration_csm.py +66 -58
- transformers/models/csm/generation_csm.py +35 -31
- transformers/models/csm/modeling_csm.py +85 -85
- transformers/models/csm/modular_csm.py +58 -58
- transformers/models/csm/processing_csm.py +68 -25
- transformers/models/ctrl/configuration_ctrl.py +1 -16
- transformers/models/ctrl/modeling_ctrl.py +44 -54
- transformers/models/ctrl/tokenization_ctrl.py +1 -0
- transformers/models/cvt/configuration_cvt.py +1 -0
- transformers/models/cvt/modeling_cvt.py +16 -20
- transformers/models/cwm/__init__.py +1 -0
- transformers/models/cwm/configuration_cwm.py +12 -8
- transformers/models/cwm/modeling_cwm.py +39 -37
- transformers/models/cwm/modular_cwm.py +12 -10
- transformers/models/d_fine/configuration_d_fine.py +5 -7
- transformers/models/d_fine/modeling_d_fine.py +128 -138
- transformers/models/d_fine/modular_d_fine.py +18 -33
- transformers/models/dab_detr/configuration_dab_detr.py +3 -6
- transformers/models/dab_detr/modeling_dab_detr.py +75 -81
- transformers/models/dac/configuration_dac.py +1 -0
- transformers/models/dac/feature_extraction_dac.py +9 -6
- transformers/models/dac/modeling_dac.py +26 -24
- transformers/models/data2vec/configuration_data2vec_audio.py +2 -4
- transformers/models/data2vec/configuration_data2vec_text.py +3 -11
- transformers/models/data2vec/configuration_data2vec_vision.py +1 -0
- transformers/models/data2vec/modeling_data2vec_audio.py +56 -57
- transformers/models/data2vec/modeling_data2vec_text.py +93 -98
- transformers/models/data2vec/modeling_data2vec_vision.py +45 -49
- transformers/models/data2vec/modular_data2vec_audio.py +1 -6
- transformers/models/data2vec/modular_data2vec_text.py +54 -58
- transformers/models/dbrx/configuration_dbrx.py +22 -36
- transformers/models/dbrx/modeling_dbrx.py +45 -42
- transformers/models/dbrx/modular_dbrx.py +33 -31
- transformers/models/deberta/configuration_deberta.py +1 -6
- transformers/models/deberta/modeling_deberta.py +60 -64
- transformers/models/deberta/tokenization_deberta.py +21 -9
- transformers/models/deberta_v2/configuration_deberta_v2.py +1 -6
- transformers/models/deberta_v2/modeling_deberta_v2.py +65 -71
- transformers/models/deberta_v2/tokenization_deberta_v2.py +29 -11
- transformers/models/decision_transformer/configuration_decision_transformer.py +2 -3
- transformers/models/decision_transformer/modeling_decision_transformer.py +56 -60
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +44 -39
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +43 -43
- transformers/models/deepseek_v2/modular_deepseek_v2.py +49 -48
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +45 -40
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +42 -45
- transformers/models/deepseek_v3/modular_deepseek_v3.py +9 -14
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +3 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +26 -25
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +10 -10
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -57
- transformers/models/deepseek_vl/modular_deepseek_vl.py +43 -14
- transformers/models/deepseek_vl/processing_deepseek_vl.py +41 -10
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +5 -3
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +35 -35
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +24 -20
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +61 -109
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +118 -146
- transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +44 -12
- transformers/models/deformable_detr/configuration_deformable_detr.py +3 -2
- transformers/models/deformable_detr/image_processing_deformable_detr.py +61 -59
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +28 -28
- transformers/models/deformable_detr/modeling_deformable_detr.py +82 -88
- transformers/models/deformable_detr/modular_deformable_detr.py +3 -1
- transformers/models/deit/configuration_deit.py +1 -0
- transformers/models/deit/image_processing_deit.py +21 -18
- transformers/models/deit/image_processing_deit_fast.py +1 -0
- transformers/models/deit/modeling_deit.py +22 -24
- transformers/models/depth_anything/configuration_depth_anything.py +4 -2
- transformers/models/depth_anything/modeling_depth_anything.py +10 -10
- transformers/models/depth_pro/configuration_depth_pro.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro.py +23 -22
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +10 -8
- transformers/models/depth_pro/modeling_depth_pro.py +27 -31
- transformers/models/detr/configuration_detr.py +2 -1
- transformers/models/detr/image_processing_detr.py +66 -64
- transformers/models/detr/image_processing_detr_fast.py +34 -33
- transformers/models/detr/modeling_detr.py +79 -95
- transformers/models/dia/configuration_dia.py +15 -9
- transformers/models/dia/feature_extraction_dia.py +9 -6
- transformers/models/dia/generation_dia.py +50 -48
- transformers/models/dia/modeling_dia.py +69 -78
- transformers/models/dia/modular_dia.py +56 -64
- transformers/models/dia/processing_dia.py +29 -39
- transformers/models/dia/tokenization_dia.py +6 -3
- transformers/models/diffllama/configuration_diffllama.py +30 -25
- transformers/models/diffllama/modeling_diffllama.py +49 -46
- transformers/models/diffllama/modular_diffllama.py +19 -17
- transformers/models/dinat/configuration_dinat.py +1 -0
- transformers/models/dinat/modeling_dinat.py +44 -47
- transformers/models/dinov2/configuration_dinov2.py +1 -0
- transformers/models/dinov2/modeling_dinov2.py +15 -15
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +1 -1
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +15 -16
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +9 -9
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +7 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +6 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +8 -5
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +9 -7
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +18 -19
- transformers/models/dinov3_vit/modular_dinov3_vit.py +15 -16
- transformers/models/distilbert/configuration_distilbert.py +2 -8
- transformers/models/distilbert/modeling_distilbert.py +55 -55
- transformers/models/distilbert/tokenization_distilbert.py +1 -13
- transformers/models/doge/__init__.py +1 -0
- transformers/models/doge/configuration_doge.py +32 -39
- transformers/models/doge/modeling_doge.py +49 -45
- transformers/models/doge/modular_doge.py +63 -71
- transformers/models/donut/configuration_donut_swin.py +1 -0
- transformers/models/donut/image_processing_donut.py +29 -26
- transformers/models/donut/image_processing_donut_fast.py +15 -9
- transformers/models/donut/modeling_donut_swin.py +58 -62
- transformers/models/donut/processing_donut.py +26 -5
- transformers/models/dots1/configuration_dots1.py +33 -41
- transformers/models/dots1/modeling_dots1.py +45 -54
- transformers/models/dots1/modular_dots1.py +4 -5
- transformers/models/dpr/configuration_dpr.py +2 -19
- transformers/models/dpr/modeling_dpr.py +39 -42
- transformers/models/dpr/tokenization_dpr.py +9 -19
- transformers/models/dpr/tokenization_dpr_fast.py +9 -7
- transformers/models/dpt/configuration_dpt.py +2 -1
- transformers/models/dpt/image_processing_dpt.py +66 -65
- transformers/models/dpt/image_processing_dpt_fast.py +20 -18
- transformers/models/dpt/modeling_dpt.py +30 -32
- transformers/models/dpt/modular_dpt.py +17 -15
- transformers/models/edgetam/configuration_edgetam.py +3 -2
- transformers/models/edgetam/modeling_edgetam.py +86 -86
- transformers/models/edgetam/modular_edgetam.py +26 -21
- transformers/models/edgetam_video/__init__.py +1 -0
- transformers/models/edgetam_video/configuration_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modeling_edgetam_video.py +158 -169
- transformers/models/edgetam_video/modular_edgetam_video.py +37 -30
- transformers/models/efficientloftr/configuration_efficientloftr.py +5 -4
- transformers/models/efficientloftr/image_processing_efficientloftr.py +16 -14
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +9 -9
- transformers/models/efficientloftr/modeling_efficientloftr.py +38 -59
- transformers/models/efficientloftr/modular_efficientloftr.py +3 -1
- transformers/models/efficientnet/configuration_efficientnet.py +1 -0
- transformers/models/efficientnet/image_processing_efficientnet.py +32 -28
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +19 -17
- transformers/models/efficientnet/modeling_efficientnet.py +15 -19
- transformers/models/electra/configuration_electra.py +3 -13
- transformers/models/electra/modeling_electra.py +103 -108
- transformers/models/emu3/configuration_emu3.py +17 -13
- transformers/models/emu3/image_processing_emu3.py +39 -44
- transformers/models/emu3/modeling_emu3.py +108 -148
- transformers/models/emu3/modular_emu3.py +73 -115
- transformers/models/emu3/processing_emu3.py +43 -18
- transformers/models/encodec/configuration_encodec.py +4 -2
- transformers/models/encodec/feature_extraction_encodec.py +13 -10
- transformers/models/encodec/modeling_encodec.py +29 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +2 -12
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +43 -37
- transformers/models/eomt/configuration_eomt.py +1 -0
- transformers/models/eomt/image_processing_eomt.py +56 -66
- transformers/models/eomt/image_processing_eomt_fast.py +33 -76
- transformers/models/eomt/modeling_eomt.py +18 -23
- transformers/models/eomt/modular_eomt.py +13 -18
- transformers/models/ernie/configuration_ernie.py +3 -24
- transformers/models/ernie/modeling_ernie.py +132 -127
- transformers/models/ernie/modular_ernie.py +103 -97
- transformers/models/ernie4_5/configuration_ernie4_5.py +27 -23
- transformers/models/ernie4_5/modeling_ernie4_5.py +38 -36
- transformers/models/ernie4_5/modular_ernie4_5.py +4 -3
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +36 -32
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +55 -56
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +46 -18
- transformers/models/esm/configuration_esm.py +15 -11
- transformers/models/esm/modeling_esm.py +34 -38
- transformers/models/esm/modeling_esmfold.py +49 -53
- transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
- transformers/models/esm/openfold_utils/loss.py +2 -1
- transformers/models/esm/openfold_utils/protein.py +16 -15
- transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
- transformers/models/esm/tokenization_esm.py +4 -2
- transformers/models/evolla/configuration_evolla.py +40 -50
- transformers/models/evolla/modeling_evolla.py +66 -71
- transformers/models/evolla/modular_evolla.py +47 -53
- transformers/models/evolla/processing_evolla.py +35 -23
- transformers/models/exaone4/configuration_exaone4.py +25 -23
- transformers/models/exaone4/modeling_exaone4.py +38 -35
- transformers/models/exaone4/modular_exaone4.py +46 -44
- transformers/models/falcon/configuration_falcon.py +26 -31
- transformers/models/falcon/modeling_falcon.py +80 -82
- transformers/models/falcon_h1/configuration_falcon_h1.py +51 -45
- transformers/models/falcon_h1/modeling_falcon_h1.py +82 -85
- transformers/models/falcon_h1/modular_falcon_h1.py +51 -56
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +2 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +82 -75
- transformers/models/falcon_mamba/modular_falcon_mamba.py +45 -28
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +6 -2
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +60 -76
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +3 -2
- transformers/models/flaubert/configuration_flaubert.py +5 -10
- transformers/models/flaubert/modeling_flaubert.py +143 -145
- transformers/models/flaubert/tokenization_flaubert.py +5 -3
- transformers/models/flava/configuration_flava.py +6 -5
- transformers/models/flava/image_processing_flava.py +67 -66
- transformers/models/flava/image_processing_flava_fast.py +49 -46
- transformers/models/flava/modeling_flava.py +136 -153
- transformers/models/flava/processing_flava.py +12 -2
- transformers/models/flex_olmo/__init__.py +1 -0
- transformers/models/flex_olmo/configuration_flex_olmo.py +32 -28
- transformers/models/flex_olmo/modeling_flex_olmo.py +47 -47
- transformers/models/flex_olmo/modular_flex_olmo.py +44 -40
- transformers/models/florence2/configuration_florence2.py +1 -0
- transformers/models/florence2/modeling_florence2.py +69 -111
- transformers/models/florence2/modular_florence2.py +101 -104
- transformers/models/florence2/processing_florence2.py +47 -18
- transformers/models/fnet/configuration_fnet.py +2 -6
- transformers/models/fnet/modeling_fnet.py +80 -83
- transformers/models/fnet/tokenization_fnet.py +1 -0
- transformers/models/focalnet/configuration_focalnet.py +1 -0
- transformers/models/focalnet/modeling_focalnet.py +45 -51
- transformers/models/fsmt/configuration_fsmt.py +17 -12
- transformers/models/fsmt/modeling_fsmt.py +48 -49
- transformers/models/fsmt/tokenization_fsmt.py +5 -3
- transformers/models/funnel/configuration_funnel.py +1 -8
- transformers/models/funnel/modeling_funnel.py +93 -99
- transformers/models/funnel/tokenization_funnel.py +27 -17
- transformers/models/fuyu/configuration_fuyu.py +34 -28
- transformers/models/fuyu/image_processing_fuyu.py +31 -29
- transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
- transformers/models/fuyu/modeling_fuyu.py +53 -53
- transformers/models/fuyu/processing_fuyu.py +34 -23
- transformers/models/gemma/configuration_gemma.py +30 -25
- transformers/models/gemma/modeling_gemma.py +50 -46
- transformers/models/gemma/modular_gemma.py +47 -42
- transformers/models/gemma/tokenization_gemma.py +30 -10
- transformers/models/gemma2/configuration_gemma2.py +35 -30
- transformers/models/gemma2/modeling_gemma2.py +42 -39
- transformers/models/gemma2/modular_gemma2.py +66 -63
- transformers/models/gemma3/configuration_gemma3.py +44 -44
- transformers/models/gemma3/image_processing_gemma3.py +31 -29
- transformers/models/gemma3/image_processing_gemma3_fast.py +13 -11
- transformers/models/gemma3/modeling_gemma3.py +207 -159
- transformers/models/gemma3/modular_gemma3.py +204 -153
- transformers/models/gemma3/processing_gemma3.py +5 -5
- transformers/models/gemma3n/configuration_gemma3n.py +26 -36
- transformers/models/gemma3n/feature_extraction_gemma3n.py +11 -9
- transformers/models/gemma3n/modeling_gemma3n.py +356 -222
- transformers/models/gemma3n/modular_gemma3n.py +207 -230
- transformers/models/gemma3n/processing_gemma3n.py +26 -12
- transformers/models/git/configuration_git.py +8 -5
- transformers/models/git/modeling_git.py +204 -266
- transformers/models/git/processing_git.py +14 -2
- transformers/models/glm/configuration_glm.py +28 -24
- transformers/models/glm/modeling_glm.py +40 -37
- transformers/models/glm/modular_glm.py +7 -4
- transformers/models/glm4/configuration_glm4.py +28 -24
- transformers/models/glm4/modeling_glm4.py +42 -40
- transformers/models/glm4/modular_glm4.py +10 -8
- transformers/models/glm46v/configuration_glm46v.py +1 -0
- transformers/models/glm46v/image_processing_glm46v.py +40 -35
- transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
- transformers/models/glm46v/modeling_glm46v.py +90 -137
- transformers/models/glm46v/modular_glm46v.py +3 -4
- transformers/models/glm46v/processing_glm46v.py +41 -7
- transformers/models/glm46v/video_processing_glm46v.py +11 -9
- transformers/models/glm4_moe/configuration_glm4_moe.py +32 -40
- transformers/models/glm4_moe/modeling_glm4_moe.py +42 -45
- transformers/models/glm4_moe/modular_glm4_moe.py +34 -42
- transformers/models/glm4v/configuration_glm4v.py +20 -18
- transformers/models/glm4v/image_processing_glm4v.py +40 -34
- transformers/models/glm4v/image_processing_glm4v_fast.py +9 -8
- transformers/models/glm4v/modeling_glm4v.py +205 -254
- transformers/models/glm4v/modular_glm4v.py +224 -210
- transformers/models/glm4v/processing_glm4v.py +41 -7
- transformers/models/glm4v/video_processing_glm4v.py +11 -9
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +125 -136
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +368 -377
- transformers/models/glm4v_moe/modular_glm4v_moe.py +169 -83
- transformers/models/glpn/configuration_glpn.py +1 -0
- transformers/models/glpn/image_processing_glpn.py +12 -11
- transformers/models/glpn/image_processing_glpn_fast.py +13 -11
- transformers/models/glpn/modeling_glpn.py +14 -16
- transformers/models/got_ocr2/configuration_got_ocr2.py +12 -4
- transformers/models/got_ocr2/image_processing_got_ocr2.py +24 -22
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +11 -9
- transformers/models/got_ocr2/modeling_got_ocr2.py +80 -77
- transformers/models/got_ocr2/modular_got_ocr2.py +51 -54
- transformers/models/got_ocr2/processing_got_ocr2.py +63 -42
- transformers/models/gpt2/configuration_gpt2.py +2 -13
- transformers/models/gpt2/modeling_gpt2.py +115 -120
- transformers/models/gpt2/tokenization_gpt2.py +46 -15
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +2 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +89 -79
- transformers/models/gpt_neo/configuration_gpt_neo.py +2 -9
- transformers/models/gpt_neo/modeling_gpt_neo.py +67 -83
- transformers/models/gpt_neox/configuration_gpt_neox.py +25 -25
- transformers/models/gpt_neox/modeling_gpt_neox.py +75 -76
- transformers/models/gpt_neox/modular_gpt_neox.py +66 -67
- transformers/models/gpt_neox/tokenization_gpt_neox.py +51 -9
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +19 -24
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +47 -46
- transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +28 -46
- transformers/models/gpt_oss/modeling_gpt_oss.py +121 -83
- transformers/models/gpt_oss/modular_gpt_oss.py +103 -64
- transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +87 -101
- transformers/models/granite/configuration_granite.py +33 -28
- transformers/models/granite/modeling_granite.py +46 -44
- transformers/models/granite/modular_granite.py +31 -29
- transformers/models/granite_speech/configuration_granite_speech.py +1 -0
- transformers/models/granite_speech/feature_extraction_granite_speech.py +3 -1
- transformers/models/granite_speech/modeling_granite_speech.py +52 -82
- transformers/models/granite_speech/processing_granite_speech.py +4 -11
- transformers/models/granitemoe/configuration_granitemoe.py +36 -31
- transformers/models/granitemoe/modeling_granitemoe.py +46 -41
- transformers/models/granitemoe/modular_granitemoe.py +27 -22
- transformers/models/granitemoehybrid/__init__.py +1 -0
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +47 -46
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +93 -97
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +21 -54
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +37 -33
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +61 -54
- transformers/models/granitemoeshared/modular_granitemoeshared.py +21 -19
- transformers/models/grounding_dino/configuration_grounding_dino.py +4 -6
- transformers/models/grounding_dino/image_processing_grounding_dino.py +62 -60
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +29 -28
- transformers/models/grounding_dino/modeling_grounding_dino.py +140 -155
- transformers/models/grounding_dino/modular_grounding_dino.py +3 -2
- transformers/models/grounding_dino/processing_grounding_dino.py +38 -10
- transformers/models/groupvit/configuration_groupvit.py +2 -4
- transformers/models/groupvit/modeling_groupvit.py +93 -107
- transformers/models/helium/configuration_helium.py +29 -25
- transformers/models/helium/modeling_helium.py +40 -38
- transformers/models/helium/modular_helium.py +7 -3
- transformers/models/herbert/tokenization_herbert.py +28 -10
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +1 -0
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -24
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -24
- transformers/models/hiera/configuration_hiera.py +1 -0
- transformers/models/hiera/modeling_hiera.py +66 -72
- transformers/models/hubert/configuration_hubert.py +2 -4
- transformers/models/hubert/modeling_hubert.py +37 -42
- transformers/models/hubert/modular_hubert.py +11 -13
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +31 -26
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +38 -35
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +6 -4
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +36 -31
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +42 -47
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +9 -9
- transformers/models/ibert/configuration_ibert.py +2 -4
- transformers/models/ibert/modeling_ibert.py +62 -82
- transformers/models/ibert/quant_modules.py +1 -0
- transformers/models/idefics/configuration_idefics.py +8 -5
- transformers/models/idefics/image_processing_idefics.py +15 -13
- transformers/models/idefics/modeling_idefics.py +82 -75
- transformers/models/idefics/perceiver.py +3 -1
- transformers/models/idefics/processing_idefics.py +48 -32
- transformers/models/idefics/vision.py +25 -24
- transformers/models/idefics2/configuration_idefics2.py +3 -1
- transformers/models/idefics2/image_processing_idefics2.py +32 -31
- transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
- transformers/models/idefics2/modeling_idefics2.py +101 -127
- transformers/models/idefics2/processing_idefics2.py +68 -10
- transformers/models/idefics3/configuration_idefics3.py +4 -1
- transformers/models/idefics3/image_processing_idefics3.py +43 -42
- transformers/models/idefics3/image_processing_idefics3_fast.py +15 -40
- transformers/models/idefics3/modeling_idefics3.py +90 -115
- transformers/models/idefics3/processing_idefics3.py +69 -15
- transformers/models/ijepa/configuration_ijepa.py +1 -0
- transformers/models/ijepa/modeling_ijepa.py +11 -10
- transformers/models/ijepa/modular_ijepa.py +7 -5
- transformers/models/imagegpt/configuration_imagegpt.py +2 -9
- transformers/models/imagegpt/image_processing_imagegpt.py +18 -17
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +16 -11
- transformers/models/imagegpt/modeling_imagegpt.py +65 -76
- transformers/models/informer/configuration_informer.py +9 -6
- transformers/models/informer/modeling_informer.py +86 -88
- transformers/models/informer/modular_informer.py +16 -14
- transformers/models/instructblip/configuration_instructblip.py +2 -2
- transformers/models/instructblip/modeling_instructblip.py +63 -103
- transformers/models/instructblip/processing_instructblip.py +36 -10
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +139 -157
- transformers/models/instructblipvideo/modular_instructblipvideo.py +64 -73
- transformers/models/instructblipvideo/processing_instructblipvideo.py +33 -14
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +8 -6
- transformers/models/internvl/configuration_internvl.py +1 -0
- transformers/models/internvl/modeling_internvl.py +106 -85
- transformers/models/internvl/modular_internvl.py +67 -47
- transformers/models/internvl/processing_internvl.py +45 -12
- transformers/models/internvl/video_processing_internvl.py +12 -10
- transformers/models/jamba/configuration_jamba.py +8 -5
- transformers/models/jamba/modeling_jamba.py +66 -68
- transformers/models/jamba/modular_jamba.py +55 -54
- transformers/models/janus/configuration_janus.py +1 -0
- transformers/models/janus/image_processing_janus.py +37 -35
- transformers/models/janus/image_processing_janus_fast.py +20 -18
- transformers/models/janus/modeling_janus.py +191 -115
- transformers/models/janus/modular_janus.py +84 -133
- transformers/models/janus/processing_janus.py +43 -17
- transformers/models/jetmoe/configuration_jetmoe.py +26 -24
- transformers/models/jetmoe/modeling_jetmoe.py +46 -43
- transformers/models/jetmoe/modular_jetmoe.py +33 -31
- transformers/models/kosmos2/configuration_kosmos2.py +9 -10
- transformers/models/kosmos2/modeling_kosmos2.py +173 -208
- transformers/models/kosmos2/processing_kosmos2.py +55 -40
- transformers/models/kosmos2_5/__init__.py +1 -0
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +9 -8
- transformers/models/kosmos2_5/image_processing_kosmos2_5.py +12 -10
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +13 -4
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +118 -132
- transformers/models/kosmos2_5/processing_kosmos2_5.py +29 -8
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +28 -31
- transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +14 -12
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +100 -110
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +22 -28
- transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +8 -2
- transformers/models/layoutlm/configuration_layoutlm.py +2 -14
- transformers/models/layoutlm/modeling_layoutlm.py +72 -77
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +17 -14
- transformers/models/layoutlmv2/image_processing_layoutlmv2.py +21 -18
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +9 -7
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +50 -64
- transformers/models/layoutlmv2/processing_layoutlmv2.py +44 -14
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +126 -73
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +19 -16
- transformers/models/layoutlmv3/image_processing_layoutlmv3.py +26 -24
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +11 -9
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +56 -82
- transformers/models/layoutlmv3/processing_layoutlmv3.py +46 -14
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +134 -74
- transformers/models/layoutxlm/configuration_layoutxlm.py +17 -14
- transformers/models/layoutxlm/modular_layoutxlm.py +1 -0
- transformers/models/layoutxlm/processing_layoutxlm.py +44 -14
- transformers/models/layoutxlm/tokenization_layoutxlm.py +113 -77
- transformers/models/led/configuration_led.py +12 -8
- transformers/models/led/modeling_led.py +266 -124
- transformers/models/levit/configuration_levit.py +1 -0
- transformers/models/levit/image_processing_levit.py +21 -19
- transformers/models/levit/image_processing_levit_fast.py +5 -4
- transformers/models/levit/modeling_levit.py +19 -38
- transformers/models/lfm2/configuration_lfm2.py +30 -27
- transformers/models/lfm2/modeling_lfm2.py +50 -47
- transformers/models/lfm2/modular_lfm2.py +30 -29
- transformers/models/lfm2_moe/__init__.py +1 -0
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +9 -6
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +53 -61
- transformers/models/lfm2_moe/modular_lfm2_moe.py +37 -13
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +1 -4
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +12 -41
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +66 -84
- transformers/models/lfm2_vl/modular_lfm2_vl.py +56 -70
- transformers/models/lfm2_vl/processing_lfm2_vl.py +76 -96
- transformers/models/lightglue/image_processing_lightglue.py +15 -16
- transformers/models/lightglue/image_processing_lightglue_fast.py +9 -9
- transformers/models/lightglue/modeling_lightglue.py +31 -31
- transformers/models/lightglue/modular_lightglue.py +28 -29
- transformers/models/lilt/configuration_lilt.py +2 -6
- transformers/models/lilt/modeling_lilt.py +70 -76
- transformers/models/llama/configuration_llama.py +31 -26
- transformers/models/llama/modeling_llama.py +39 -36
- transformers/models/llama/tokenization_llama.py +44 -14
- transformers/models/llama4/configuration_llama4.py +30 -27
- transformers/models/llama4/image_processing_llama4_fast.py +14 -12
- transformers/models/llama4/modeling_llama4.py +113 -120
- transformers/models/llama4/processing_llama4.py +57 -33
- transformers/models/llava/configuration_llava.py +1 -10
- transformers/models/llava/image_processing_llava.py +28 -25
- transformers/models/llava/image_processing_llava_fast.py +11 -9
- transformers/models/llava/modeling_llava.py +109 -85
- transformers/models/llava/processing_llava.py +51 -18
- transformers/models/llava_next/configuration_llava_next.py +2 -2
- transformers/models/llava_next/image_processing_llava_next.py +45 -43
- transformers/models/llava_next/image_processing_llava_next_fast.py +13 -11
- transformers/models/llava_next/modeling_llava_next.py +107 -110
- transformers/models/llava_next/processing_llava_next.py +47 -18
- transformers/models/llava_next_video/configuration_llava_next_video.py +7 -4
- transformers/models/llava_next_video/modeling_llava_next_video.py +158 -175
- transformers/models/llava_next_video/modular_llava_next_video.py +150 -155
- transformers/models/llava_next_video/processing_llava_next_video.py +63 -21
- transformers/models/llava_next_video/video_processing_llava_next_video.py +1 -0
- transformers/models/llava_onevision/configuration_llava_onevision.py +7 -4
- transformers/models/llava_onevision/image_processing_llava_onevision.py +42 -40
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +15 -14
- transformers/models/llava_onevision/modeling_llava_onevision.py +169 -177
- transformers/models/llava_onevision/modular_llava_onevision.py +156 -163
- transformers/models/llava_onevision/processing_llava_onevision.py +53 -21
- transformers/models/llava_onevision/video_processing_llava_onevision.py +1 -0
- transformers/models/longcat_flash/__init__.py +1 -0
- transformers/models/longcat_flash/configuration_longcat_flash.py +42 -37
- transformers/models/longcat_flash/modeling_longcat_flash.py +36 -36
- transformers/models/longcat_flash/modular_longcat_flash.py +21 -21
- transformers/models/longformer/configuration_longformer.py +5 -5
- transformers/models/longformer/modeling_longformer.py +101 -105
- transformers/models/longt5/configuration_longt5.py +7 -9
- transformers/models/longt5/modeling_longt5.py +49 -49
- transformers/models/luke/configuration_luke.py +2 -8
- transformers/models/luke/modeling_luke.py +181 -188
- transformers/models/luke/tokenization_luke.py +140 -107
- transformers/models/lxmert/configuration_lxmert.py +1 -16
- transformers/models/lxmert/modeling_lxmert.py +74 -65
- transformers/models/m2m_100/configuration_m2m_100.py +9 -7
- transformers/models/m2m_100/modeling_m2m_100.py +71 -83
- transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
- transformers/models/mamba/configuration_mamba.py +2 -1
- transformers/models/mamba/modeling_mamba.py +66 -58
- transformers/models/mamba2/configuration_mamba2.py +8 -5
- transformers/models/mamba2/modeling_mamba2.py +69 -68
- transformers/models/marian/configuration_marian.py +5 -10
- transformers/models/marian/modeling_marian.py +87 -93
- transformers/models/marian/tokenization_marian.py +6 -6
- transformers/models/markuplm/configuration_markuplm.py +7 -4
- transformers/models/markuplm/feature_extraction_markuplm.py +2 -1
- transformers/models/markuplm/modeling_markuplm.py +70 -69
- transformers/models/markuplm/processing_markuplm.py +38 -31
- transformers/models/markuplm/tokenization_markuplm.py +136 -93
- transformers/models/mask2former/configuration_mask2former.py +8 -5
- transformers/models/mask2former/image_processing_mask2former.py +85 -84
- transformers/models/mask2former/image_processing_mask2former_fast.py +40 -37
- transformers/models/mask2former/modeling_mask2former.py +103 -118
- transformers/models/mask2former/modular_mask2former.py +8 -6
- transformers/models/maskformer/configuration_maskformer.py +9 -6
- transformers/models/maskformer/configuration_maskformer_swin.py +1 -0
- transformers/models/maskformer/image_processing_maskformer.py +85 -84
- transformers/models/maskformer/image_processing_maskformer_fast.py +40 -36
- transformers/models/maskformer/modeling_maskformer.py +65 -79
- transformers/models/maskformer/modeling_maskformer_swin.py +32 -36
- transformers/models/mbart/configuration_mbart.py +4 -9
- transformers/models/mbart/modeling_mbart.py +116 -131
- transformers/models/mbart/tokenization_mbart.py +54 -11
- transformers/models/mbart50/tokenization_mbart50.py +13 -8
- transformers/models/megatron_bert/configuration_megatron_bert.py +3 -13
- transformers/models/megatron_bert/modeling_megatron_bert.py +150 -148
- transformers/models/metaclip_2/configuration_metaclip_2.py +1 -4
- transformers/models/metaclip_2/modeling_metaclip_2.py +84 -91
- transformers/models/metaclip_2/modular_metaclip_2.py +45 -61
- transformers/models/mgp_str/configuration_mgp_str.py +1 -0
- transformers/models/mgp_str/modeling_mgp_str.py +18 -20
- transformers/models/mgp_str/processing_mgp_str.py +20 -3
- transformers/models/mgp_str/tokenization_mgp_str.py +3 -1
- transformers/models/mimi/configuration_mimi.py +40 -42
- transformers/models/mimi/modeling_mimi.py +113 -142
- transformers/models/minimax/__init__.py +1 -0
- transformers/models/minimax/configuration_minimax.py +43 -37
- transformers/models/minimax/modeling_minimax.py +51 -61
- transformers/models/minimax/modular_minimax.py +62 -68
- transformers/models/ministral/configuration_ministral.py +29 -25
- transformers/models/ministral/modeling_ministral.py +38 -36
- transformers/models/ministral/modular_ministral.py +37 -32
- transformers/models/ministral3/configuration_ministral3.py +27 -24
- transformers/models/ministral3/modeling_ministral3.py +37 -36
- transformers/models/ministral3/modular_ministral3.py +5 -4
- transformers/models/mistral/configuration_mistral.py +29 -24
- transformers/models/mistral/modeling_mistral.py +37 -36
- transformers/models/mistral/modular_mistral.py +12 -11
- transformers/models/mistral3/configuration_mistral3.py +1 -4
- transformers/models/mistral3/modeling_mistral3.py +86 -89
- transformers/models/mistral3/modular_mistral3.py +68 -69
- transformers/models/mixtral/configuration_mixtral.py +34 -29
- transformers/models/mixtral/modeling_mixtral.py +45 -50
- transformers/models/mixtral/modular_mixtral.py +31 -32
- transformers/models/mlcd/configuration_mlcd.py +1 -0
- transformers/models/mlcd/modeling_mlcd.py +14 -20
- transformers/models/mlcd/modular_mlcd.py +13 -17
- transformers/models/mllama/configuration_mllama.py +15 -10
- transformers/models/mllama/image_processing_mllama.py +25 -23
- transformers/models/mllama/image_processing_mllama_fast.py +11 -11
- transformers/models/mllama/modeling_mllama.py +94 -105
- transformers/models/mllama/processing_mllama.py +55 -6
- transformers/models/mluke/tokenization_mluke.py +107 -101
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +3 -5
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +140 -155
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +3 -5
- transformers/models/mobilebert/configuration_mobilebert.py +2 -4
- transformers/models/mobilebert/modeling_mobilebert.py +85 -77
- transformers/models/mobilebert/tokenization_mobilebert.py +1 -0
- transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +1 -0
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +23 -20
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +1 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +16 -15
- transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +1 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +51 -48
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +15 -13
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +22 -24
- transformers/models/mobilevit/configuration_mobilevit.py +1 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +49 -46
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +14 -12
- transformers/models/mobilevit/modeling_mobilevit.py +21 -28
- transformers/models/mobilevitv2/configuration_mobilevitv2.py +1 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +22 -28
- transformers/models/modernbert/configuration_modernbert.py +42 -44
- transformers/models/modernbert/modeling_modernbert.py +133 -145
- transformers/models/modernbert/modular_modernbert.py +170 -186
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +40 -40
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +57 -62
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +86 -94
- transformers/models/moonshine/configuration_moonshine.py +31 -34
- transformers/models/moonshine/modeling_moonshine.py +71 -71
- transformers/models/moonshine/modular_moonshine.py +83 -88
- transformers/models/moshi/configuration_moshi.py +23 -46
- transformers/models/moshi/modeling_moshi.py +187 -157
- transformers/models/mpnet/configuration_mpnet.py +2 -6
- transformers/models/mpnet/modeling_mpnet.py +57 -62
- transformers/models/mpnet/tokenization_mpnet.py +15 -4
- transformers/models/mpt/configuration_mpt.py +9 -5
- transformers/models/mpt/modeling_mpt.py +60 -60
- transformers/models/mra/configuration_mra.py +2 -8
- transformers/models/mra/modeling_mra.py +57 -64
- transformers/models/mt5/configuration_mt5.py +8 -10
- transformers/models/mt5/modeling_mt5.py +95 -87
- transformers/models/musicgen/configuration_musicgen.py +8 -12
- transformers/models/musicgen/modeling_musicgen.py +122 -118
- transformers/models/musicgen/processing_musicgen.py +21 -3
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +8 -15
- transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +9 -8
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +123 -117
- transformers/models/musicgen_melody/processing_musicgen_melody.py +22 -3
- transformers/models/mvp/configuration_mvp.py +5 -8
- transformers/models/mvp/modeling_mvp.py +123 -135
- transformers/models/myt5/tokenization_myt5.py +10 -8
- transformers/models/nanochat/configuration_nanochat.py +8 -5
- transformers/models/nanochat/modeling_nanochat.py +40 -37
- transformers/models/nanochat/modular_nanochat.py +14 -12
- transformers/models/nemotron/configuration_nemotron.py +30 -25
- transformers/models/nemotron/modeling_nemotron.py +57 -56
- transformers/models/nllb/tokenization_nllb.py +28 -12
- transformers/models/nllb_moe/configuration_nllb_moe.py +9 -7
- transformers/models/nllb_moe/modeling_nllb_moe.py +69 -77
- transformers/models/nougat/image_processing_nougat.py +32 -29
- transformers/models/nougat/image_processing_nougat_fast.py +14 -12
- transformers/models/nougat/processing_nougat.py +39 -37
- transformers/models/nougat/tokenization_nougat.py +73 -18
- transformers/models/nystromformer/configuration_nystromformer.py +2 -8
- transformers/models/nystromformer/modeling_nystromformer.py +63 -74
- transformers/models/olmo/configuration_olmo.py +28 -23
- transformers/models/olmo/modeling_olmo.py +39 -36
- transformers/models/olmo/modular_olmo.py +11 -7
- transformers/models/olmo2/configuration_olmo2.py +28 -23
- transformers/models/olmo2/modeling_olmo2.py +41 -37
- transformers/models/olmo2/modular_olmo2.py +32 -29
- transformers/models/olmo3/__init__.py +1 -0
- transformers/models/olmo3/configuration_olmo3.py +30 -26
- transformers/models/olmo3/modeling_olmo3.py +39 -36
- transformers/models/olmo3/modular_olmo3.py +40 -37
- transformers/models/olmoe/configuration_olmoe.py +33 -29
- transformers/models/olmoe/modeling_olmoe.py +46 -52
- transformers/models/olmoe/modular_olmoe.py +15 -16
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +4 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +47 -53
- transformers/models/omdet_turbo/processing_omdet_turbo.py +67 -19
- transformers/models/oneformer/configuration_oneformer.py +8 -5
- transformers/models/oneformer/image_processing_oneformer.py +84 -83
- transformers/models/oneformer/image_processing_oneformer_fast.py +42 -41
- transformers/models/oneformer/modeling_oneformer.py +171 -147
- transformers/models/oneformer/processing_oneformer.py +43 -28
- transformers/models/openai/configuration_openai.py +1 -16
- transformers/models/openai/modeling_openai.py +51 -65
- transformers/models/openai/tokenization_openai.py +47 -8
- transformers/models/opt/configuration_opt.py +7 -6
- transformers/models/opt/modeling_opt.py +76 -78
- transformers/models/ovis2/__init__.py +1 -0
- transformers/models/ovis2/configuration_ovis2.py +1 -0
- transformers/models/ovis2/image_processing_ovis2.py +24 -22
- transformers/models/ovis2/image_processing_ovis2_fast.py +11 -9
- transformers/models/ovis2/modeling_ovis2.py +142 -111
- transformers/models/ovis2/modular_ovis2.py +45 -90
- transformers/models/ovis2/processing_ovis2.py +40 -12
- transformers/models/owlv2/configuration_owlv2.py +2 -4
- transformers/models/owlv2/image_processing_owlv2.py +21 -20
- transformers/models/owlv2/image_processing_owlv2_fast.py +15 -12
- transformers/models/owlv2/modeling_owlv2.py +117 -133
- transformers/models/owlv2/modular_owlv2.py +14 -11
- transformers/models/owlv2/processing_owlv2.py +49 -20
- transformers/models/owlvit/configuration_owlvit.py +2 -4
- transformers/models/owlvit/image_processing_owlvit.py +22 -21
- transformers/models/owlvit/image_processing_owlvit_fast.py +3 -2
- transformers/models/owlvit/modeling_owlvit.py +116 -132
- transformers/models/owlvit/processing_owlvit.py +48 -20
- transformers/models/paligemma/configuration_paligemma.py +1 -4
- transformers/models/paligemma/modeling_paligemma.py +93 -103
- transformers/models/paligemma/processing_paligemma.py +66 -13
- transformers/models/parakeet/configuration_parakeet.py +14 -7
- transformers/models/parakeet/feature_extraction_parakeet.py +12 -10
- transformers/models/parakeet/modeling_parakeet.py +28 -32
- transformers/models/parakeet/modular_parakeet.py +20 -23
- transformers/models/parakeet/processing_parakeet.py +5 -13
- transformers/models/parakeet/{tokenization_parakeet.py → tokenization_parakeet_fast.py} +7 -5
- transformers/models/patchtsmixer/configuration_patchtsmixer.py +8 -5
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +62 -70
- transformers/models/patchtst/configuration_patchtst.py +9 -6
- transformers/models/patchtst/modeling_patchtst.py +80 -97
- transformers/models/pegasus/configuration_pegasus.py +5 -8
- transformers/models/pegasus/modeling_pegasus.py +66 -72
- transformers/models/pegasus/tokenization_pegasus.py +45 -15
- transformers/models/pegasus_x/configuration_pegasus_x.py +4 -5
- transformers/models/pegasus_x/modeling_pegasus_x.py +52 -55
- transformers/models/perceiver/configuration_perceiver.py +1 -0
- transformers/models/perceiver/image_processing_perceiver.py +25 -22
- transformers/models/perceiver/image_processing_perceiver_fast.py +9 -7
- transformers/models/perceiver/modeling_perceiver.py +146 -165
- transformers/models/perceiver/tokenization_perceiver.py +6 -3
- transformers/models/perception_lm/configuration_perception_lm.py +1 -0
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +10 -8
- transformers/models/perception_lm/modeling_perception_lm.py +70 -71
- transformers/models/perception_lm/modular_perception_lm.py +61 -65
- transformers/models/perception_lm/processing_perception_lm.py +47 -13
- transformers/models/perception_lm/video_processing_perception_lm.py +1 -0
- transformers/models/persimmon/configuration_persimmon.py +28 -23
- transformers/models/persimmon/modeling_persimmon.py +45 -43
- transformers/models/phi/configuration_phi.py +28 -23
- transformers/models/phi/modeling_phi.py +43 -40
- transformers/models/phi/modular_phi.py +24 -23
- transformers/models/phi3/configuration_phi3.py +33 -28
- transformers/models/phi3/modeling_phi3.py +38 -36
- transformers/models/phi3/modular_phi3.py +17 -13
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +33 -30
- transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +9 -7
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +78 -95
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +80 -98
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +44 -7
- transformers/models/phimoe/configuration_phimoe.py +36 -31
- transformers/models/phimoe/modeling_phimoe.py +45 -50
- transformers/models/phimoe/modular_phimoe.py +4 -3
- transformers/models/phobert/tokenization_phobert.py +6 -4
- transformers/models/pix2struct/configuration_pix2struct.py +10 -12
- transformers/models/pix2struct/image_processing_pix2struct.py +19 -15
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +15 -12
- transformers/models/pix2struct/modeling_pix2struct.py +52 -58
- transformers/models/pix2struct/processing_pix2struct.py +30 -5
- transformers/models/pixtral/configuration_pixtral.py +14 -11
- transformers/models/pixtral/image_processing_pixtral.py +28 -26
- transformers/models/pixtral/image_processing_pixtral_fast.py +11 -10
- transformers/models/pixtral/modeling_pixtral.py +34 -28
- transformers/models/pixtral/processing_pixtral.py +53 -21
- transformers/models/plbart/configuration_plbart.py +5 -8
- transformers/models/plbart/modeling_plbart.py +106 -119
- transformers/models/plbart/modular_plbart.py +33 -39
- transformers/models/plbart/tokenization_plbart.py +7 -4
- transformers/models/poolformer/configuration_poolformer.py +1 -0
- transformers/models/poolformer/image_processing_poolformer.py +24 -21
- transformers/models/poolformer/image_processing_poolformer_fast.py +15 -13
- transformers/models/poolformer/modeling_poolformer.py +13 -23
- transformers/models/pop2piano/configuration_pop2piano.py +8 -7
- transformers/models/pop2piano/feature_extraction_pop2piano.py +9 -6
- transformers/models/pop2piano/modeling_pop2piano.py +24 -26
- transformers/models/pop2piano/processing_pop2piano.py +33 -25
- transformers/models/pop2piano/tokenization_pop2piano.py +23 -15
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +3 -3
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +21 -20
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +13 -16
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +13 -16
- transformers/models/prophetnet/configuration_prophetnet.py +38 -37
- transformers/models/prophetnet/modeling_prophetnet.py +131 -114
- transformers/models/prophetnet/tokenization_prophetnet.py +16 -14
- transformers/models/pvt/configuration_pvt.py +1 -0
- transformers/models/pvt/image_processing_pvt.py +27 -24
- transformers/models/pvt/image_processing_pvt_fast.py +2 -1
- transformers/models/pvt/modeling_pvt.py +21 -21
- transformers/models/pvt_v2/configuration_pvt_v2.py +4 -2
- transformers/models/pvt_v2/modeling_pvt_v2.py +25 -28
- transformers/models/qwen2/configuration_qwen2.py +25 -32
- transformers/models/qwen2/modeling_qwen2.py +38 -36
- transformers/models/qwen2/modular_qwen2.py +12 -11
- transformers/models/qwen2/tokenization_qwen2.py +23 -12
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +26 -32
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +277 -340
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +211 -278
- transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +49 -41
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +35 -29
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +148 -203
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +118 -93
- transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +43 -7
- transformers/models/qwen2_audio/configuration_qwen2_audio.py +1 -0
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +40 -40
- transformers/models/qwen2_audio/processing_qwen2_audio.py +42 -13
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +35 -42
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +46 -51
- transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +34 -29
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +42 -41
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +15 -12
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +153 -199
- transformers/models/qwen2_vl/processing_qwen2_vl.py +44 -7
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +18 -38
- transformers/models/qwen3/configuration_qwen3.py +27 -34
- transformers/models/qwen3/modeling_qwen3.py +39 -36
- transformers/models/qwen3/modular_qwen3.py +6 -4
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +32 -39
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +46 -51
- transformers/models/qwen3_moe/modular_qwen3_moe.py +13 -10
- transformers/models/qwen3_next/configuration_qwen3_next.py +35 -45
- transformers/models/qwen3_next/modeling_qwen3_next.py +51 -47
- transformers/models/qwen3_next/modular_qwen3_next.py +35 -34
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +101 -135
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +252 -355
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +196 -250
- transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +48 -40
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +29 -27
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +155 -233
- transformers/models/qwen3_vl/modular_qwen3_vl.py +179 -206
- transformers/models/qwen3_vl/processing_qwen3_vl.py +42 -6
- transformers/models/qwen3_vl/video_processing_qwen3_vl.py +12 -10
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +30 -23
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +303 -358
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +124 -87
- transformers/models/rag/configuration_rag.py +15 -6
- transformers/models/rag/modeling_rag.py +130 -127
- transformers/models/rag/retrieval_rag.py +5 -3
- transformers/models/rag/tokenization_rag.py +50 -0
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +30 -29
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +42 -53
- transformers/models/reformer/configuration_reformer.py +8 -7
- transformers/models/reformer/modeling_reformer.py +69 -80
- transformers/models/reformer/tokenization_reformer.py +31 -11
- transformers/models/regnet/configuration_regnet.py +1 -0
- transformers/models/regnet/modeling_regnet.py +8 -15
- transformers/models/rembert/configuration_rembert.py +2 -8
- transformers/models/rembert/modeling_rembert.py +111 -121
- transformers/models/rembert/tokenization_rembert.py +12 -2
- transformers/models/resnet/configuration_resnet.py +1 -0
- transformers/models/resnet/modeling_resnet.py +13 -27
- transformers/models/roberta/configuration_roberta.py +3 -11
- transformers/models/roberta/modeling_roberta.py +93 -94
- transformers/models/roberta/modular_roberta.py +58 -58
- transformers/models/roberta/tokenization_roberta.py +29 -17
- transformers/models/roberta/tokenization_roberta_old.py +4 -2
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +3 -11
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +93 -94
- transformers/models/roc_bert/configuration_roc_bert.py +2 -8
- transformers/models/roc_bert/modeling_roc_bert.py +121 -122
- transformers/models/roc_bert/tokenization_roc_bert.py +94 -88
- transformers/models/roformer/configuration_roformer.py +3 -13
- transformers/models/roformer/modeling_roformer.py +81 -85
- transformers/models/roformer/tokenization_roformer.py +412 -74
- transformers/models/roformer/tokenization_roformer_fast.py +160 -0
- transformers/models/roformer/tokenization_utils.py +1 -0
- transformers/models/rt_detr/configuration_rt_detr.py +2 -1
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +1 -0
- transformers/models/rt_detr/image_processing_rt_detr.py +55 -54
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +26 -26
- transformers/models/rt_detr/modeling_rt_detr.py +90 -99
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +6 -13
- transformers/models/rt_detr/modular_rt_detr.py +16 -16
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +4 -6
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +90 -101
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +12 -19
- transformers/models/rwkv/configuration_rwkv.py +4 -2
- transformers/models/rwkv/modeling_rwkv.py +32 -31
- transformers/models/sam/configuration_sam.py +1 -3
- transformers/models/sam/image_processing_sam.py +60 -59
- transformers/models/sam/image_processing_sam_fast.py +27 -25
- transformers/models/sam/modeling_sam.py +41 -47
- transformers/models/sam/processing_sam.py +27 -39
- transformers/models/sam2/configuration_sam2.py +3 -2
- transformers/models/sam2/image_processing_sam2_fast.py +15 -14
- transformers/models/sam2/modeling_sam2.py +90 -96
- transformers/models/sam2/modular_sam2.py +91 -86
- transformers/models/sam2/processing_sam2.py +47 -31
- transformers/models/sam2_video/configuration_sam2_video.py +1 -0
- transformers/models/sam2_video/modeling_sam2_video.py +144 -151
- transformers/models/sam2_video/modular_sam2_video.py +104 -101
- transformers/models/sam2_video/processing_sam2_video.py +66 -49
- transformers/models/sam2_video/video_processing_sam2_video.py +4 -1
- transformers/models/sam3/configuration_sam3.py +2 -21
- transformers/models/sam3/image_processing_sam3_fast.py +20 -17
- transformers/models/sam3/modeling_sam3.py +170 -184
- transformers/models/sam3/modular_sam3.py +8 -3
- transformers/models/sam3/processing_sam3.py +52 -37
- transformers/models/sam3_tracker/__init__.py +1 -0
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +3 -1
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +77 -82
- transformers/models/sam3_tracker/modular_sam3_tracker.py +3 -8
- transformers/models/sam3_tracker/processing_sam3_tracker.py +48 -31
- transformers/models/sam3_tracker_video/__init__.py +1 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +1 -25
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +122 -135
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +26 -35
- transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +66 -50
- transformers/models/sam3_video/configuration_sam3_video.py +1 -14
- transformers/models/sam3_video/modeling_sam3_video.py +34 -33
- transformers/models/sam3_video/processing_sam3_video.py +46 -26
- transformers/models/sam_hq/__init__.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -3
- transformers/models/sam_hq/modeling_sam_hq.py +69 -74
- transformers/models/sam_hq/modular_sam_hq.py +25 -23
- transformers/models/sam_hq/{processing_sam_hq.py → processing_samhq.py} +29 -41
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +10 -8
- transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +11 -8
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +194 -212
- transformers/models/seamless_m4t/processing_seamless_m4t.py +39 -18
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +77 -40
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +10 -8
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +196 -204
- transformers/models/seed_oss/configuration_seed_oss.py +32 -28
- transformers/models/seed_oss/modeling_seed_oss.py +35 -33
- transformers/models/seed_oss/modular_seed_oss.py +4 -3
- transformers/models/segformer/configuration_segformer.py +10 -0
- transformers/models/segformer/image_processing_segformer.py +42 -39
- transformers/models/segformer/image_processing_segformer_fast.py +12 -10
- transformers/models/segformer/modeling_segformer.py +31 -34
- transformers/models/segformer/modular_segformer.py +10 -8
- transformers/models/seggpt/configuration_seggpt.py +1 -0
- transformers/models/seggpt/image_processing_seggpt.py +41 -38
- transformers/models/seggpt/modeling_seggpt.py +38 -50
- transformers/models/sew/configuration_sew.py +2 -4
- transformers/models/sew/modeling_sew.py +36 -38
- transformers/models/sew/modular_sew.py +13 -13
- transformers/models/sew_d/configuration_sew_d.py +2 -4
- transformers/models/sew_d/modeling_sew_d.py +30 -31
- transformers/models/shieldgemma2/configuration_shieldgemma2.py +1 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +17 -16
- transformers/models/shieldgemma2/processing_shieldgemma2.py +5 -3
- transformers/models/siglip/configuration_siglip.py +2 -4
- transformers/models/siglip/image_processing_siglip.py +20 -17
- transformers/models/siglip/image_processing_siglip_fast.py +1 -0
- transformers/models/siglip/modeling_siglip.py +75 -84
- transformers/models/siglip/processing_siglip.py +14 -2
- transformers/models/siglip/tokenization_siglip.py +7 -6
- transformers/models/siglip2/configuration_siglip2.py +2 -5
- transformers/models/siglip2/image_processing_siglip2.py +16 -15
- transformers/models/siglip2/image_processing_siglip2_fast.py +7 -6
- transformers/models/siglip2/modeling_siglip2.py +129 -143
- transformers/models/siglip2/modular_siglip2.py +46 -47
- transformers/models/siglip2/processing_siglip2.py +14 -2
- transformers/models/smollm3/configuration_smollm3.py +32 -29
- transformers/models/smollm3/modeling_smollm3.py +39 -36
- transformers/models/smollm3/modular_smollm3.py +35 -33
- transformers/models/smolvlm/configuration_smolvlm.py +4 -2
- transformers/models/smolvlm/image_processing_smolvlm.py +43 -42
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +15 -41
- transformers/models/smolvlm/modeling_smolvlm.py +94 -126
- transformers/models/smolvlm/modular_smolvlm.py +39 -50
- transformers/models/smolvlm/processing_smolvlm.py +83 -15
- transformers/models/smolvlm/video_processing_smolvlm.py +18 -16
- transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +1 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +27 -26
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
- transformers/models/speech_to_text/feature_extraction_speech_to_text.py +13 -10
- transformers/models/speech_to_text/modeling_speech_to_text.py +54 -66
- transformers/models/speech_to_text/processing_speech_to_text.py +30 -4
- transformers/models/speech_to_text/tokenization_speech_to_text.py +6 -5
- transformers/models/speecht5/configuration_speecht5.py +9 -7
- transformers/models/speecht5/feature_extraction_speecht5.py +37 -16
- transformers/models/speecht5/modeling_speecht5.py +175 -213
- transformers/models/speecht5/number_normalizer.py +1 -0
- transformers/models/speecht5/processing_speecht5.py +37 -3
- transformers/models/speecht5/tokenization_speecht5.py +5 -4
- transformers/models/splinter/configuration_splinter.py +7 -6
- transformers/models/splinter/modeling_splinter.py +59 -71
- transformers/models/splinter/tokenization_splinter.py +30 -9
- transformers/models/squeezebert/configuration_squeezebert.py +2 -14
- transformers/models/squeezebert/modeling_squeezebert.py +62 -68
- transformers/models/squeezebert/tokenization_squeezebert.py +1 -0
- transformers/models/stablelm/configuration_stablelm.py +29 -24
- transformers/models/stablelm/modeling_stablelm.py +45 -44
- transformers/models/starcoder2/configuration_starcoder2.py +27 -30
- transformers/models/starcoder2/modeling_starcoder2.py +41 -39
- transformers/models/starcoder2/modular_starcoder2.py +16 -14
- transformers/models/superglue/configuration_superglue.py +3 -7
- transformers/models/superglue/image_processing_superglue.py +15 -15
- transformers/models/superglue/image_processing_superglue_fast.py +10 -9
- transformers/models/superglue/modeling_superglue.py +37 -42
- transformers/models/superpoint/image_processing_superpoint.py +15 -15
- transformers/models/superpoint/image_processing_superpoint_fast.py +11 -8
- transformers/models/superpoint/modeling_superpoint.py +16 -18
- transformers/models/swiftformer/configuration_swiftformer.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +14 -18
- transformers/models/swin/configuration_swin.py +1 -0
- transformers/models/swin/modeling_swin.py +86 -86
- transformers/models/swin2sr/configuration_swin2sr.py +1 -0
- transformers/models/swin2sr/image_processing_swin2sr.py +13 -10
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +8 -4
- transformers/models/swin2sr/modeling_swin2sr.py +63 -81
- transformers/models/swinv2/configuration_swinv2.py +1 -0
- transformers/models/swinv2/modeling_swinv2.py +104 -108
- transformers/models/switch_transformers/configuration_switch_transformers.py +7 -11
- transformers/models/switch_transformers/modeling_switch_transformers.py +44 -37
- transformers/models/switch_transformers/modular_switch_transformers.py +41 -34
- transformers/models/t5/configuration_t5.py +8 -14
- transformers/models/t5/modeling_t5.py +92 -88
- transformers/models/t5/tokenization_t5.py +9 -3
- transformers/models/t5gemma/configuration_t5gemma.py +41 -43
- transformers/models/t5gemma/modeling_t5gemma.py +107 -104
- transformers/models/t5gemma/modular_t5gemma.py +120 -124
- transformers/models/t5gemma2/configuration_t5gemma2.py +120 -80
- transformers/models/t5gemma2/modeling_t5gemma2.py +125 -141
- transformers/models/t5gemma2/modular_t5gemma2.py +104 -393
- transformers/models/table_transformer/configuration_table_transformer.py +2 -1
- transformers/models/table_transformer/modeling_table_transformer.py +49 -51
- transformers/models/tapas/configuration_tapas.py +2 -12
- transformers/models/tapas/modeling_tapas.py +67 -68
- transformers/models/tapas/tokenization_tapas.py +153 -115
- transformers/models/textnet/configuration_textnet.py +1 -0
- transformers/models/textnet/image_processing_textnet.py +25 -22
- transformers/models/textnet/image_processing_textnet_fast.py +10 -8
- transformers/models/textnet/modeling_textnet.py +16 -28
- transformers/models/time_series_transformer/configuration_time_series_transformer.py +8 -5
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +81 -83
- transformers/models/timesfm/configuration_timesfm.py +1 -0
- transformers/models/timesfm/modeling_timesfm.py +22 -33
- transformers/models/timesfm/modular_timesfm.py +21 -32
- transformers/models/timesformer/configuration_timesformer.py +1 -0
- transformers/models/timesformer/modeling_timesformer.py +16 -15
- transformers/models/timm_backbone/configuration_timm_backbone.py +1 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +15 -17
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -5
- transformers/models/timm_wrapper/image_processing_timm_wrapper.py +5 -4
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +29 -34
- transformers/models/trocr/configuration_trocr.py +8 -11
- transformers/models/trocr/modeling_trocr.py +44 -45
- transformers/models/trocr/processing_trocr.py +25 -5
- transformers/models/tvp/configuration_tvp.py +2 -5
- transformers/models/tvp/image_processing_tvp.py +52 -50
- transformers/models/tvp/image_processing_tvp_fast.py +15 -15
- transformers/models/tvp/modeling_tvp.py +27 -27
- transformers/models/tvp/processing_tvp.py +14 -2
- transformers/models/udop/configuration_udop.py +7 -16
- transformers/models/udop/modeling_udop.py +73 -71
- transformers/models/udop/processing_udop.py +26 -7
- transformers/models/udop/tokenization_udop.py +105 -84
- transformers/models/umt5/configuration_umt5.py +7 -8
- transformers/models/umt5/modeling_umt5.py +90 -94
- transformers/models/unispeech/configuration_unispeech.py +2 -4
- transformers/models/unispeech/modeling_unispeech.py +49 -51
- transformers/models/unispeech/modular_unispeech.py +22 -22
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +2 -4
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +65 -69
- transformers/models/unispeech_sat/modular_unispeech_sat.py +23 -23
- transformers/models/univnet/feature_extraction_univnet.py +14 -14
- transformers/models/univnet/modeling_univnet.py +8 -8
- transformers/models/upernet/configuration_upernet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +13 -11
- transformers/models/vaultgemma/__init__.py +1 -0
- transformers/models/vaultgemma/configuration_vaultgemma.py +33 -29
- transformers/models/vaultgemma/modeling_vaultgemma.py +41 -39
- transformers/models/vaultgemma/modular_vaultgemma.py +31 -29
- transformers/models/video_llama_3/configuration_video_llama_3.py +0 -4
- transformers/models/video_llama_3/image_processing_video_llama_3.py +42 -43
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +14 -12
- transformers/models/video_llama_3/modeling_video_llama_3.py +109 -157
- transformers/models/video_llama_3/modular_video_llama_3.py +146 -155
- transformers/models/video_llama_3/processing_video_llama_3.py +39 -5
- transformers/models/video_llama_3/video_processing_video_llama_3.py +23 -42
- transformers/models/video_llava/configuration_video_llava.py +1 -4
- transformers/models/video_llava/image_processing_video_llava.py +38 -35
- transformers/models/video_llava/modeling_video_llava.py +146 -146
- transformers/models/video_llava/processing_video_llava.py +78 -38
- transformers/models/video_llava/video_processing_video_llava.py +1 -0
- transformers/models/videomae/configuration_videomae.py +1 -0
- transformers/models/videomae/image_processing_videomae.py +34 -31
- transformers/models/videomae/modeling_videomae.py +17 -14
- transformers/models/videomae/video_processing_videomae.py +1 -0
- transformers/models/vilt/configuration_vilt.py +4 -6
- transformers/models/vilt/image_processing_vilt.py +30 -29
- transformers/models/vilt/image_processing_vilt_fast.py +16 -15
- transformers/models/vilt/modeling_vilt.py +90 -116
- transformers/models/vilt/processing_vilt.py +14 -2
- transformers/models/vipllava/configuration_vipllava.py +1 -4
- transformers/models/vipllava/modeling_vipllava.py +70 -99
- transformers/models/vipllava/modular_vipllava.py +54 -78
- transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +1 -0
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +27 -28
- transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +1 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +41 -46
- transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +16 -2
- transformers/models/visual_bert/configuration_visual_bert.py +2 -6
- transformers/models/visual_bert/modeling_visual_bert.py +92 -98
- transformers/models/vit/configuration_vit.py +1 -0
- transformers/models/vit/image_processing_vit.py +22 -19
- transformers/models/vit/image_processing_vit_fast.py +1 -0
- transformers/models/vit/modeling_vit.py +17 -17
- transformers/models/vit_mae/configuration_vit_mae.py +1 -0
- transformers/models/vit_mae/modeling_vit_mae.py +27 -29
- transformers/models/vit_msn/configuration_vit_msn.py +1 -0
- transformers/models/vit_msn/modeling_vit_msn.py +16 -18
- transformers/models/vitdet/configuration_vitdet.py +1 -0
- transformers/models/vitdet/modeling_vitdet.py +14 -14
- transformers/models/vitmatte/configuration_vitmatte.py +5 -2
- transformers/models/vitmatte/image_processing_vitmatte.py +18 -15
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +18 -16
- transformers/models/vitmatte/modeling_vitmatte.py +11 -14
- transformers/models/vitpose/configuration_vitpose.py +7 -4
- transformers/models/vitpose/image_processing_vitpose.py +25 -24
- transformers/models/vitpose/image_processing_vitpose_fast.py +11 -9
- transformers/models/vitpose/modeling_vitpose.py +14 -14
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +1 -0
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +10 -8
- transformers/models/vits/configuration_vits.py +1 -4
- transformers/models/vits/modeling_vits.py +42 -44
- transformers/models/vits/tokenization_vits.py +4 -3
- transformers/models/vivit/configuration_vivit.py +1 -0
- transformers/models/vivit/image_processing_vivit.py +39 -36
- transformers/models/vivit/modeling_vivit.py +8 -6
- transformers/models/vjepa2/__init__.py +1 -0
- transformers/models/vjepa2/configuration_vjepa2.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +32 -31
- transformers/models/vjepa2/video_processing_vjepa2.py +1 -0
- transformers/models/voxtral/__init__.py +1 -0
- transformers/models/voxtral/configuration_voxtral.py +2 -0
- transformers/models/voxtral/modeling_voxtral.py +47 -40
- transformers/models/voxtral/modular_voxtral.py +40 -37
- transformers/models/voxtral/processing_voxtral.py +48 -25
- transformers/models/wav2vec2/configuration_wav2vec2.py +2 -4
- transformers/models/wav2vec2/feature_extraction_wav2vec2.py +10 -7
- transformers/models/wav2vec2/modeling_wav2vec2.py +121 -73
- transformers/models/wav2vec2/processing_wav2vec2.py +35 -6
- transformers/models/wav2vec2/tokenization_wav2vec2.py +332 -20
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +2 -4
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +62 -70
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +48 -57
- transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +35 -6
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +2 -4
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +77 -90
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +30 -37
- transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +17 -16
- transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +55 -36
- transformers/models/wavlm/configuration_wavlm.py +2 -4
- transformers/models/wavlm/modeling_wavlm.py +48 -50
- transformers/models/wavlm/modular_wavlm.py +5 -4
- transformers/models/whisper/configuration_whisper.py +5 -6
- transformers/models/whisper/english_normalizer.py +4 -3
- transformers/models/whisper/feature_extraction_whisper.py +24 -9
- transformers/models/whisper/generation_whisper.py +48 -26
- transformers/models/whisper/modeling_whisper.py +73 -79
- transformers/models/whisper/processing_whisper.py +20 -3
- transformers/models/whisper/tokenization_whisper.py +43 -11
- transformers/models/x_clip/configuration_x_clip.py +2 -4
- transformers/models/x_clip/modeling_x_clip.py +93 -96
- transformers/models/x_clip/processing_x_clip.py +14 -2
- transformers/models/xcodec/configuration_xcodec.py +6 -4
- transformers/models/xcodec/modeling_xcodec.py +17 -20
- transformers/models/xglm/configuration_xglm.py +8 -9
- transformers/models/xglm/modeling_xglm.py +55 -60
- transformers/models/xglm/tokenization_xglm.py +11 -3
- transformers/models/xlm/configuration_xlm.py +8 -10
- transformers/models/xlm/modeling_xlm.py +144 -144
- transformers/models/xlm/tokenization_xlm.py +5 -3
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +3 -11
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +194 -195
- transformers/models/xlm_roberta/modular_xlm_roberta.py +53 -50
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +18 -8
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +2 -10
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +93 -94
- transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +70 -67
- transformers/models/xlnet/configuration_xlnet.py +12 -3
- transformers/models/xlnet/modeling_xlnet.py +163 -152
- transformers/models/xlnet/tokenization_xlnet.py +9 -2
- transformers/models/xlstm/configuration_xlstm.py +12 -8
- transformers/models/xlstm/modeling_xlstm.py +65 -62
- transformers/models/xmod/configuration_xmod.py +3 -11
- transformers/models/xmod/modeling_xmod.py +110 -108
- transformers/models/yolos/configuration_yolos.py +1 -0
- transformers/models/yolos/image_processing_yolos.py +62 -60
- transformers/models/yolos/image_processing_yolos_fast.py +45 -42
- transformers/models/yolos/modeling_yolos.py +16 -16
- transformers/models/yolos/modular_yolos.py +19 -17
- transformers/models/yoso/configuration_yoso.py +2 -8
- transformers/models/yoso/modeling_yoso.py +63 -70
- transformers/models/zamba/configuration_zamba.py +8 -5
- transformers/models/zamba/modeling_zamba.py +78 -81
- transformers/models/zamba2/configuration_zamba2.py +50 -44
- transformers/models/zamba2/modeling_zamba2.py +97 -97
- transformers/models/zamba2/modular_zamba2.py +48 -46
- transformers/models/zoedepth/configuration_zoedepth.py +2 -1
- transformers/models/zoedepth/image_processing_zoedepth.py +29 -28
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +24 -21
- transformers/models/zoedepth/modeling_zoedepth.py +18 -26
- transformers/pipelines/__init__.py +114 -57
- transformers/pipelines/any_to_any.py +22 -14
- transformers/pipelines/audio_utils.py +2 -1
- transformers/pipelines/automatic_speech_recognition.py +12 -20
- transformers/pipelines/base.py +27 -15
- transformers/{models/pe_audio/processing_pe_audio.py → pipelines/deprecated/__init__.py} +3 -10
- transformers/pipelines/deprecated/text2text_generation.py +408 -0
- transformers/pipelines/document_question_answering.py +2 -4
- transformers/pipelines/image_text_to_text.py +1 -0
- transformers/pipelines/image_to_text.py +229 -0
- transformers/pipelines/question_answering.py +44 -5
- transformers/pipelines/text_classification.py +14 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/pipelines/token_classification.py +22 -1
- transformers/pipelines/video_classification.py +9 -1
- transformers/pipelines/zero_shot_audio_classification.py +1 -0
- transformers/pipelines/zero_shot_classification.py +6 -0
- transformers/pipelines/zero_shot_image_classification.py +7 -0
- transformers/processing_utils.py +145 -230
- transformers/quantizers/auto.py +4 -2
- transformers/quantizers/base.py +173 -53
- transformers/quantizers/quantizer_aqlm.py +23 -2
- transformers/quantizers/quantizer_auto_round.py +12 -2
- transformers/quantizers/quantizer_awq.py +89 -20
- transformers/quantizers/quantizer_bitnet.py +14 -4
- transformers/quantizers/quantizer_bnb_4bit.py +155 -18
- transformers/quantizers/quantizer_bnb_8bit.py +110 -24
- transformers/quantizers/quantizer_compressed_tensors.py +9 -2
- transformers/quantizers/quantizer_eetq.py +74 -16
- transformers/quantizers/quantizer_fbgemm_fp8.py +138 -38
- transformers/quantizers/quantizer_finegrained_fp8.py +113 -26
- transformers/quantizers/quantizer_fp_quant.py +82 -52
- transformers/quantizers/quantizer_gptq.py +28 -8
- transformers/quantizers/quantizer_higgs.py +60 -42
- transformers/quantizers/quantizer_hqq.py +153 -144
- transformers/quantizers/quantizer_mxfp4.py +194 -14
- transformers/quantizers/quantizer_quanto.py +79 -35
- transformers/quantizers/quantizer_quark.py +18 -36
- transformers/quantizers/quantizer_spqr.py +12 -4
- transformers/quantizers/quantizer_torchao.py +325 -50
- transformers/quantizers/quantizer_vptq.py +27 -4
- transformers/quantizers/quantizers_utils.py +0 -20
- transformers/safetensors_conversion.py +3 -9
- transformers/testing_utils.py +82 -326
- transformers/tokenization_mistral_common.py +903 -568
- transformers/tokenization_utils_base.py +340 -220
- transformers/tokenization_utils_sentencepiece.py +6 -5
- transformers/tokenization_utils_tokenizers.py +113 -226
- transformers/trainer.py +53 -60
- transformers/trainer_callback.py +0 -8
- transformers/trainer_seq2seq.py +1 -5
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +41 -77
- transformers/utils/__init__.py +4 -8
- transformers/utils/attention_visualizer.py +5 -5
- transformers/utils/auto_docstring.py +37 -599
- transformers/utils/doc.py +36 -4
- transformers/utils/dummy_pt_objects.py +42 -0
- transformers/utils/generic.py +28 -111
- transformers/utils/hub.py +15 -5
- transformers/utils/import_utils.py +32 -165
- transformers/utils/kernel_config.py +19 -74
- transformers/utils/loading_report.py +15 -25
- transformers/utils/quantization_config.py +241 -72
- transformers/video_processing_utils.py +39 -41
- transformers/video_utils.py +22 -18
- {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/METADATA +236 -284
- transformers-5.0.0rc0.dist-info/RECORD +1987 -0
- {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/WHEEL +1 -1
- transformers/integrations/moe.py +0 -360
- transformers/integrations/quark.py +0 -53
- transformers/loss/loss_lw_detr.py +0 -356
- transformers/models/ernie4_5_vl_moe/__init__.py +0 -31
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +0 -340
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +0 -455
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +0 -231
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +0 -1936
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +0 -1925
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +0 -249
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +0 -593
- transformers/models/fast_vlm/__init__.py +0 -27
- transformers/models/fast_vlm/configuration_fast_vlm.py +0 -137
- transformers/models/fast_vlm/modeling_fast_vlm.py +0 -432
- transformers/models/fast_vlm/modular_fast_vlm.py +0 -373
- transformers/models/glm4_moe_lite/__init__.py +0 -28
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +0 -233
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +0 -740
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +0 -302
- transformers/models/glm_image/__init__.py +0 -31
- transformers/models/glm_image/configuration_glm_image.py +0 -351
- transformers/models/glm_image/image_processing_glm_image.py +0 -503
- transformers/models/glm_image/image_processing_glm_image_fast.py +0 -294
- transformers/models/glm_image/modeling_glm_image.py +0 -1642
- transformers/models/glm_image/modular_glm_image.py +0 -1531
- transformers/models/glm_image/processing_glm_image.py +0 -217
- transformers/models/glmasr/__init__.py +0 -29
- transformers/models/glmasr/configuration_glmasr.py +0 -196
- transformers/models/glmasr/modeling_glmasr.py +0 -517
- transformers/models/glmasr/modular_glmasr.py +0 -443
- transformers/models/glmasr/processing_glmasr.py +0 -331
- transformers/models/jais2/__init__.py +0 -27
- transformers/models/jais2/configuration_jais2.py +0 -148
- transformers/models/jais2/modeling_jais2.py +0 -484
- transformers/models/jais2/modular_jais2.py +0 -194
- transformers/models/lasr/__init__.py +0 -29
- transformers/models/lasr/configuration_lasr.py +0 -244
- transformers/models/lasr/feature_extraction_lasr.py +0 -275
- transformers/models/lasr/modeling_lasr.py +0 -727
- transformers/models/lasr/modular_lasr.py +0 -574
- transformers/models/lasr/processing_lasr.py +0 -100
- transformers/models/lasr/tokenization_lasr.py +0 -184
- transformers/models/lighton_ocr/__init__.py +0 -28
- transformers/models/lighton_ocr/configuration_lighton_ocr.py +0 -128
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +0 -463
- transformers/models/lighton_ocr/modular_lighton_ocr.py +0 -404
- transformers/models/lighton_ocr/processing_lighton_ocr.py +0 -229
- transformers/models/lw_detr/__init__.py +0 -27
- transformers/models/lw_detr/configuration_lw_detr.py +0 -374
- transformers/models/lw_detr/modeling_lw_detr.py +0 -1702
- transformers/models/lw_detr/modular_lw_detr.py +0 -1615
- transformers/models/minimax_m2/__init__.py +0 -28
- transformers/models/minimax_m2/configuration_minimax_m2.py +0 -188
- transformers/models/minimax_m2/modeling_minimax_m2.py +0 -704
- transformers/models/minimax_m2/modular_minimax_m2.py +0 -346
- transformers/models/paddleocr_vl/__init__.py +0 -31
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +0 -335
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +0 -503
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +0 -209
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +0 -1683
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +0 -1380
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +0 -133
- transformers/models/pe_audio/__init__.py +0 -29
- transformers/models/pe_audio/configuration_pe_audio.py +0 -204
- transformers/models/pe_audio/feature_extraction_pe_audio.py +0 -160
- transformers/models/pe_audio/modeling_pe_audio.py +0 -819
- transformers/models/pe_audio/modular_pe_audio.py +0 -298
- transformers/models/pe_audio_video/__init__.py +0 -28
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +0 -223
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +0 -971
- transformers/models/pe_audio_video/modular_pe_audio_video.py +0 -763
- transformers/models/pe_video/__init__.py +0 -29
- transformers/models/pe_video/configuration_pe_video.py +0 -209
- transformers/models/pe_video/modeling_pe_video.py +0 -647
- transformers/models/pe_video/modular_pe_video.py +0 -231
- transformers/models/pe_video/processing_pe_video.py +0 -10
- transformers/models/pe_video/video_processing_pe_video.py +0 -64
- transformers/models/pixio/__init__.py +0 -29
- transformers/models/pixio/configuration_pixio.py +0 -150
- transformers/models/pixio/modeling_pixio.py +0 -507
- transformers/models/pixio/modular_pixio.py +0 -403
- transformers/models/solar_open/__init__.py +0 -27
- transformers/models/solar_open/configuration_solar_open.py +0 -184
- transformers/models/solar_open/modeling_solar_open.py +0 -642
- transformers/models/solar_open/modular_solar_open.py +0 -224
- transformers/trainer_jit_checkpoint.py +0 -125
- transformers-5.0.0.dist-info/RECORD +0 -2068
- {transformers-5.0.0.dist-info/licenses → transformers-5.0.0rc0.dist-info}/LICENSE +0 -0
- {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/top_level.txt +0 -0
transformers/modeling_utils.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# coding=utf-8
|
|
1
2
|
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
|
|
2
3
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
3
4
|
#
|
|
@@ -15,6 +16,7 @@
|
|
|
15
16
|
import collections
|
|
16
17
|
import copy
|
|
17
18
|
import functools
|
|
19
|
+
import gc
|
|
18
20
|
import importlib.metadata
|
|
19
21
|
import inspect
|
|
20
22
|
import json
|
|
@@ -24,18 +26,17 @@ import sys
|
|
|
24
26
|
import warnings
|
|
25
27
|
from abc import abstractmethod
|
|
26
28
|
from collections import defaultdict
|
|
27
|
-
from collections.abc import Callable,
|
|
29
|
+
from collections.abc import Callable, Sequence
|
|
28
30
|
from contextlib import contextmanager
|
|
29
|
-
from dataclasses import dataclass, field, replace
|
|
30
31
|
from enum import Enum
|
|
31
32
|
from functools import partial, wraps
|
|
32
33
|
from itertools import cycle
|
|
33
34
|
from threading import Thread
|
|
34
|
-
from typing import Optional, TypeVar, get_type_hints
|
|
35
|
+
from typing import Optional, TypeVar, Union, get_type_hints
|
|
35
36
|
from zipfile import is_zipfile
|
|
36
37
|
|
|
37
38
|
import torch
|
|
38
|
-
from huggingface_hub import create_repo,
|
|
39
|
+
from huggingface_hub import create_repo, split_torch_state_dict_into_shards
|
|
39
40
|
from packaging import version
|
|
40
41
|
from safetensors import safe_open
|
|
41
42
|
from safetensors.torch import save_file as safe_save_file
|
|
@@ -62,8 +63,7 @@ from .integrations.accelerate import (
|
|
|
62
63
|
accelerate_dispatch,
|
|
63
64
|
check_and_set_device_map,
|
|
64
65
|
expand_device_map,
|
|
65
|
-
|
|
66
|
-
load_offloaded_parameter,
|
|
66
|
+
init_empty_weights,
|
|
67
67
|
)
|
|
68
68
|
from .integrations.deepspeed import _load_state_dict_into_zero3_model
|
|
69
69
|
from .integrations.eager_paged import eager_paged_attention_forward
|
|
@@ -85,8 +85,7 @@ from .integrations.tensor_parallel import (
|
|
|
85
85
|
verify_tp_plan,
|
|
86
86
|
)
|
|
87
87
|
from .loss.loss_utils import LOSS_MAPPING
|
|
88
|
-
from .modeling_flash_attention_utils import lazy_import_flash_attention
|
|
89
|
-
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
|
88
|
+
from .modeling_flash_attention_utils import lazy_import_flash_attention
|
|
90
89
|
from .pytorch_utils import id_tensor_storage
|
|
91
90
|
from .quantizers import HfQuantizer
|
|
92
91
|
from .quantizers.auto import get_hf_quantizer
|
|
@@ -94,6 +93,7 @@ from .quantizers.quantizers_utils import get_module_from_name
|
|
|
94
93
|
from .safetensors_conversion import auto_conversion
|
|
95
94
|
from .utils import (
|
|
96
95
|
ADAPTER_SAFE_WEIGHTS_NAME,
|
|
96
|
+
ADAPTER_WEIGHTS_NAME,
|
|
97
97
|
DUMMY_INPUTS,
|
|
98
98
|
SAFE_WEIGHTS_INDEX_NAME,
|
|
99
99
|
SAFE_WEIGHTS_NAME,
|
|
@@ -107,12 +107,10 @@ from .utils import (
|
|
|
107
107
|
copy_func,
|
|
108
108
|
has_file,
|
|
109
109
|
is_accelerate_available,
|
|
110
|
-
is_bitsandbytes_available,
|
|
111
|
-
is_env_variable_true,
|
|
112
110
|
is_flash_attn_2_available,
|
|
113
111
|
is_flash_attn_3_available,
|
|
114
|
-
is_grouped_mm_available,
|
|
115
112
|
is_kernels_available,
|
|
113
|
+
is_offline_mode,
|
|
116
114
|
is_torch_flex_attn_available,
|
|
117
115
|
is_torch_greater_or_equal,
|
|
118
116
|
is_torch_mlu_available,
|
|
@@ -120,7 +118,7 @@ from .utils import (
|
|
|
120
118
|
is_torch_xpu_available,
|
|
121
119
|
logging,
|
|
122
120
|
)
|
|
123
|
-
from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
|
|
121
|
+
from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
|
|
124
122
|
from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
|
|
125
123
|
from .utils.import_utils import (
|
|
126
124
|
is_huggingface_hub_greater_or_equal,
|
|
@@ -134,6 +132,7 @@ from .utils.quantization_config import QuantizationMethod
|
|
|
134
132
|
if is_accelerate_available():
|
|
135
133
|
from accelerate.hooks import add_hook_to_module
|
|
136
134
|
from accelerate.utils import extract_model_from_parallel
|
|
135
|
+
from accelerate.utils.modeling import get_state_dict_from_offload
|
|
137
136
|
|
|
138
137
|
|
|
139
138
|
_torch_distributed_available = torch.distributed.is_available()
|
|
@@ -155,63 +154,62 @@ logger = logging.get_logger(__name__)
|
|
|
155
154
|
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
|
|
156
155
|
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
|
|
157
156
|
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
|
|
157
|
+
_init_weights = True
|
|
158
158
|
_is_quantized = False
|
|
159
159
|
_is_ds_init_called = False
|
|
160
160
|
|
|
161
|
-
# Mapping from flash attention implementations to their kernel fallback repositories
|
|
162
|
-
FLASH_ATTN_KERNEL_FALLBACK = {
|
|
163
|
-
"flash_attention_2": "kernels-community/flash-attn2",
|
|
164
|
-
"flash_attention_3": "kernels-community/vllm-flash-attn3",
|
|
165
|
-
}
|
|
166
|
-
|
|
167
161
|
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
162
|
+
def is_local_dist_rank_0():
|
|
163
|
+
return (
|
|
164
|
+
torch.distributed.is_available()
|
|
165
|
+
and torch.distributed.is_initialized()
|
|
166
|
+
and int(os.environ.get("LOCAL_RANK", "-1")) == 0
|
|
167
|
+
)
|
|
174
168
|
|
|
175
|
-
pretrained_model_name_or_path: str | None = None
|
|
176
|
-
download_kwargs: DownloadKwargs | None = field(default_factory=DownloadKwargs)
|
|
177
|
-
use_safetensors: bool = True
|
|
178
|
-
ignore_mismatched_sizes: bool = False
|
|
179
|
-
sharded_metadata: dict | None = None
|
|
180
|
-
device_map: dict | None = None
|
|
181
|
-
disk_offload_folder: str | None = None
|
|
182
|
-
offload_buffers: bool = False
|
|
183
|
-
dtype: torch.dtype | None = None
|
|
184
|
-
hf_quantizer: HfQuantizer | None = None
|
|
185
|
-
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None
|
|
186
|
-
weights_only: bool = True
|
|
187
|
-
weight_mapping: list[WeightConverter | WeightRenaming] | None = None
|
|
188
169
|
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
170
|
+
TORCH_INIT_FUNCTIONS = {
|
|
171
|
+
"uniform_": nn.init.uniform_,
|
|
172
|
+
"normal_": nn.init.normal_,
|
|
173
|
+
"trunc_normal_": nn.init.trunc_normal_,
|
|
174
|
+
"constant_": nn.init.constant_,
|
|
175
|
+
"xavier_uniform_": nn.init.xavier_uniform_,
|
|
176
|
+
"xavier_normal_": nn.init.xavier_normal_,
|
|
177
|
+
"kaiming_uniform_": nn.init.kaiming_uniform_,
|
|
178
|
+
"kaiming_normal_": nn.init.kaiming_normal_,
|
|
179
|
+
"uniform": nn.init.uniform,
|
|
180
|
+
"normal": nn.init.normal,
|
|
181
|
+
"xavier_uniform": nn.init.xavier_uniform,
|
|
182
|
+
"xavier_normal": nn.init.xavier_normal,
|
|
183
|
+
"kaiming_uniform": nn.init.kaiming_uniform,
|
|
184
|
+
"kaiming_normal": nn.init.kaiming_normal,
|
|
185
|
+
"orthogonal_": nn.init.orthogonal_,
|
|
186
|
+
}
|
|
192
187
|
|
|
193
188
|
|
|
194
|
-
@
|
|
195
|
-
|
|
189
|
+
@contextmanager
|
|
190
|
+
def no_init_weights():
|
|
196
191
|
"""
|
|
197
|
-
|
|
198
|
-
This simplifies the code a bit.
|
|
192
|
+
Context manager to globally disable weight initialization to speed up loading large models.
|
|
199
193
|
"""
|
|
194
|
+
global _init_weights
|
|
195
|
+
old_init_weights = _init_weights
|
|
200
196
|
|
|
201
|
-
|
|
202
|
-
unexpected_keys: set[str]
|
|
203
|
-
mismatched_keys: set[tuple[str, torch.Size]]
|
|
204
|
-
disk_offload_index: dict[str, str] | None
|
|
205
|
-
error_msgs: list[str]
|
|
206
|
-
conversion_errors: set[str]
|
|
197
|
+
_init_weights = False
|
|
207
198
|
|
|
199
|
+
def _skip_init(*args, **kwargs):
|
|
200
|
+
pass
|
|
208
201
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
torch.
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
202
|
+
# Save the original initialization functions
|
|
203
|
+
for name, init_func in TORCH_INIT_FUNCTIONS.items():
|
|
204
|
+
setattr(torch.nn.init, name, _skip_init)
|
|
205
|
+
|
|
206
|
+
try:
|
|
207
|
+
yield
|
|
208
|
+
finally:
|
|
209
|
+
_init_weights = old_init_weights
|
|
210
|
+
# Restore the original initialization functions
|
|
211
|
+
for name, init_func in TORCH_INIT_FUNCTIONS.items():
|
|
212
|
+
setattr(torch.nn.init, name, init_func)
|
|
215
213
|
|
|
216
214
|
|
|
217
215
|
@contextmanager
|
|
@@ -237,28 +235,23 @@ def set_zero3_state():
|
|
|
237
235
|
_is_ds_init_called = False
|
|
238
236
|
|
|
239
237
|
|
|
240
|
-
|
|
241
|
-
def local_torch_dtype(dtype: torch.dtype, model_class_name: str | None = None):
|
|
238
|
+
def restore_default_dtype(func):
|
|
242
239
|
"""
|
|
243
|
-
|
|
244
|
-
|
|
240
|
+
Decorator to restore the default torch dtype
|
|
241
|
+
at the end of the function. Serves
|
|
242
|
+
as a backup in case calling the function raises
|
|
243
|
+
an error after the function has changed the default dtype but before it could restore it.
|
|
245
244
|
"""
|
|
246
|
-
# Just a more helping error before we set `torch.set_default_dtype` later on which would crash in this case
|
|
247
|
-
if not dtype.is_floating_point:
|
|
248
|
-
if model_class_name is not None:
|
|
249
|
-
error_message = (
|
|
250
|
-
f"{model_class_name} cannot be instantiated under `dtype={dtype}` as it's not a floating-point dtype"
|
|
251
|
-
)
|
|
252
|
-
else:
|
|
253
|
-
error_message = f"Cannot set `{dtype}` as torch's default as it's not a floating-point dtype"
|
|
254
|
-
raise ValueError(error_message)
|
|
255
245
|
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
torch.
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
246
|
+
@wraps(func)
|
|
247
|
+
def _wrapper(*args, **kwargs):
|
|
248
|
+
old_dtype = torch.get_default_dtype()
|
|
249
|
+
try:
|
|
250
|
+
return func(*args, **kwargs)
|
|
251
|
+
finally:
|
|
252
|
+
torch.set_default_dtype(old_dtype)
|
|
253
|
+
|
|
254
|
+
return _wrapper
|
|
262
255
|
|
|
263
256
|
|
|
264
257
|
def get_torch_context_manager_or_global_device():
|
|
@@ -286,9 +279,7 @@ def get_state_dict_dtype(state_dict):
|
|
|
286
279
|
return t.dtype
|
|
287
280
|
|
|
288
281
|
# if no floating dtype was found return whatever the first dtype is
|
|
289
|
-
|
|
290
|
-
return torch.float32
|
|
291
|
-
return next(iter(state_dict.values())).dtype
|
|
282
|
+
return next(state_dict.values()).dtype
|
|
292
283
|
|
|
293
284
|
|
|
294
285
|
str_to_torch_dtype = {
|
|
@@ -314,7 +305,7 @@ if is_torch_greater_or_equal("2.3.0"):
|
|
|
314
305
|
|
|
315
306
|
|
|
316
307
|
def load_state_dict(
|
|
317
|
-
checkpoint_file: str
|
|
308
|
+
checkpoint_file: Union[str, os.PathLike], map_location: Union[str, torch.device] = "cpu", weights_only: bool = True
|
|
318
309
|
) -> dict[str, torch.Tensor]:
|
|
319
310
|
"""
|
|
320
311
|
Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
|
|
@@ -414,97 +405,14 @@ def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]
     return shared_tensors, identical


-def remove_tied_weights_from_state_dict(
-    state_dict: dict[str, torch.Tensor], model: "PreTrainedModel"
-) -> dict[str, torch.Tensor]:
-    """
-    Remove all tied weights from the given `state_dict`, making sure to keep only the main weight that `model`
-    will expect when reloading (even if we know tie weights symmetrically, it's better to keep the intended one).
-    This is because `safetensors` does not allow tensor aliasing - so we're going to remove aliases before saving.
-    """
-    # To avoid any potential mistakes and mismatches between config and actual tied weights, here we check the pointers
-    # of the Tensors themselves -> we are guaranteed to find all the actual tied weights
-    ptrs = collections.defaultdict(list)
-    for name, tensor in state_dict.items():
-        if not isinstance(tensor, torch.Tensor):
-            # Sometimes in the state_dict we have non-tensor objects.
-            # e.g. in bitsandbytes we have some `str` objects in the state_dict
-            # In the non-tensor case, fall back to the pointer of the object itself
-            ptrs[id(tensor)].append(name)
-
-        elif tensor.device.type == "meta":
-            # In offloaded cases, there may be meta tensors in the state_dict.
-            # For these cases, key by the pointer of the original tensor object
-            # (state_dict tensors are detached and therefore no longer shared)
-            tensor = model.get_parameter(name)
-            ptrs[id(tensor)].append(name)
-
-        else:
-            ptrs[id_tensor_storage(tensor)].append(name)
-
-    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-
-    # Recursively descend to find tied weight keys
-    all_potential_tied_weights_keys = set(_get_tied_weight_keys(model))
-    error_names = []
-    to_delete_names = set()
-    # Removing the keys which are declared as known duplicates on load. This allows to make sure the name which is
-    # kept is consistent
-    if all_potential_tied_weights_keys is not None:
-        for names in shared_ptrs.values():
-            found = 0
-            for name in sorted(names):
-                matches_pattern = any(re.search(pat, name) for pat in all_potential_tied_weights_keys)
-                if matches_pattern and name in state_dict:
-                    found += 1
-                    if found < len(names):
-                        to_delete_names.add(name)
-    # We are entering a place where the weights and the transformers configuration do NOT match.
-    shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
-    # Those are actually tensor sharing but disjoint from each other, we can safely clone them
-    # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
-    for name in disjoint_names:
-        state_dict[name] = state_dict[name].clone()
-
-    # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-    # If the link between tensors was done at runtime then `from_pretrained` will not get
-    # the key back leading to random tensor. A proper warning will be shown
-    # during reload (if applicable), but since the file is not necessarily compatible with
-    # the config, better show a proper warning.
-    shared_names, identical_names = _find_identical(shared_names, state_dict)
-    # delete tensors that have identical storage
-    for inames in identical_names:
-        known = inames.intersection(to_delete_names)
-        for name in known:
-            del state_dict[name]
-        unknown = inames.difference(to_delete_names)
-        if len(unknown) > 1:
-            error_names.append(unknown)
-
-    if shared_names:
-        error_names.extend(shared_names)
-
-    if len(error_names) > 0:
-        raise RuntimeError(
-            f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. "
-            f"We found all the potential target tied weights keys to be: {all_potential_tied_weights_keys}.\n"
-            "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
-        )
-
-    return state_dict
-
-
 def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
-    """Cast a single parameter
-
-
-
-    # We need to use setattr here, as we set non-persistent buffers as well with this function (`load_state_dict`
-    # does not allow to do it)
-    setattr(parent, param_type, tensor)
+    """Cast a single parameter `param_name` into the `model`, with value `tensor`."""
+    module, param_type = get_module_from_name(model, param_name)
+    # This will check potential shape mismatch if skipped before
+    module.load_state_dict({param_type: tensor}, strict=False, assign=True)


-def _add_variant(weights_name: str, variant: str
+def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     if variant is not None:
         path, name = weights_name.rsplit(".", 1)
         weights_name = f"{path}.{variant}.{name}"
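Not part of the diff: a minimal re-implementation of the `_add_variant` logic shown above, to make the variant-naming rule concrete.

```python
from typing import Optional

def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    # Mirrors the body shown in the hunk: insert the variant before the extension.
    if variant is not None:
        path, name = weights_name.rsplit(".", 1)
        weights_name = f"{path}.{variant}.{name}"
    return weights_name

assert add_variant("model.safetensors", "fp16") == "model.fp16.safetensors"
assert add_variant("model.safetensors") == "model.safetensors"
```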
@@ -512,20 +420,19 @@ def _add_variant(weights_name: str, variant: str | None = None) -> str:


 def _get_resolved_checkpoint_files(
-    pretrained_model_name_or_path: str
-    variant: str
-    gguf_file: str
-    use_safetensors: bool
-
+    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+    variant: Optional[str],
+    gguf_file: Optional[str],
+    use_safetensors: Optional[bool],
+    download_kwargs: DownloadKwargs,
+    user_agent: dict,
     is_remote_code: bool,  # Because we can't determine this inside this function, we need it to be passed in
-    transformers_explicit_filename: str
-
-) -> tuple[list[str] | None, dict | None]:
+    transformers_explicit_filename: Optional[str] = None,
+) -> tuple[Optional[list[str]], Optional[dict]]:
     """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
     checkpoints are sharded.
     This function will download the data if necessary.
     """
-    download_kwargs = download_kwargs or DownloadKwargs()
     cache_dir = download_kwargs.get("cache_dir")
     force_download = download_kwargs.get("force_download", False)
     proxies = download_kwargs.get("proxies")

@@ -538,19 +445,17 @@ def _get_resolved_checkpoint_files(
         if not transformers_explicit_filename.endswith(".safetensors") and not transformers_explicit_filename.endswith(
             ".safetensors.index.json"
         ):
-
-
-
-
-
-            )
+            raise ValueError(
+                "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
+                "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
+                f"{transformers_explicit_filename}"
+            )

     is_sharded = False

     if pretrained_model_name_or_path is not None and gguf_file is None:
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         is_local = os.path.isdir(pretrained_model_name_or_path)
-        # If the file is a local folder (but not in the HF_HOME cache, even if it's technically local)
         if is_local:
             if transformers_explicit_filename is not None:
                 # If the filename is explicitly defined, load this by default.

@@ -609,38 +514,25 @@ def _get_resolved_checkpoint_files(
                 else:
                     filename = _add_variant(WEIGHTS_NAME, variant)

-            # Prepare set of kwargs for hub functions
-            has_file_kwargs = {
-                "revision": revision,
-                "proxies": proxies,
-                "token": token,
-                "cache_dir": cache_dir,
-                "local_files_only": local_files_only,
-            }
-            cached_file_kwargs = {
-                "force_download": force_download,
-                "user_agent": user_agent,
-                "subfolder": subfolder,
-                "_raise_exceptions_for_gated_repo": False,
-                "_raise_exceptions_for_missing_entries": False,
-                "_commit_hash": commit_hash,
-                **has_file_kwargs,
-            }
-            can_auto_convert = (
-                not is_offline_mode()  # for obvious reasons
-                # If we are in a CI environment or in a pytest run, we prevent the conversion
-                and not is_env_variable_true("DISABLE_SAFETENSORS_CONVERSION")
-                and not is_remote_code  # converter bot does not work on remote code
-                and subfolder == ""  # converter bot does not work on subfolders
-            )
-
             try:
                 # Load from URL or cache if already cached
-
-
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "local_files_only": local_files_only,
+                    "token": token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_gated_repo": False,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
                 resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

-                #
+                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+                # result when internet is up, the repo and revision exist, but the file does not.
                 if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
                     # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                     resolved_archive_file = cached_file(

@@ -651,7 +543,7 @@ def _get_resolved_checkpoint_files(
                     if resolved_archive_file is not None:
                         is_sharded = True
                 elif use_safetensors:
-                    if revision == "main" and
+                    if revision == "main" and not is_offline_mode():
                         resolved_archive_file, revision, is_sharded = auto_conversion(
                             pretrained_model_name_or_path, **cached_file_kwargs
                         )

@@ -660,7 +552,8 @@ def _get_resolved_checkpoint_files(
                         raise OSError(
                             f"{pretrained_model_name_or_path} does not appear to have a file named"
                             f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
-                            "and thus cannot be loaded with `safetensors`. Please
+                            "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
+                            "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
                         )
                 else:
                     # This repo has no safetensors file of any kind, we switch to PyTorch.

@@ -668,8 +561,6 @@ def _get_resolved_checkpoint_files(
                     resolved_archive_file = cached_file(
                         pretrained_model_name_or_path, filename, **cached_file_kwargs
                     )
-
-                    # Then try `.bin` files
                     if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
                         # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                         resolved_archive_file = cached_file(

@@ -679,38 +570,67 @@ def _get_resolved_checkpoint_files(
                         )
                         if resolved_archive_file is not None:
                             is_sharded = True
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                if not local_files_only and not is_offline_mode():
+                    if resolved_archive_file is not None:
+                        # In a CI environment (CircleCI / Github Actions workflow runs) or in a pytest run,
+                        # we set `DISABLE_SAFETENSORS_CONVERSION=true` to prevent the conversion.
+                        if (
+                            filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
+                            and os.getenv("DISABLE_SAFETENSORS_CONVERSION", None) != "true"
+                        ):
+                            # If the PyTorch file was found, check if there is a safetensors file on the repository
+                            # If there is no safetensors file on the repositories, start an auto conversion
+                            safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
+                            has_file_kwargs = {
+                                "revision": revision,
+                                "proxies": proxies,
+                                "token": token,
+                                "cache_dir": cache_dir,
+                                "local_files_only": local_files_only,
+                            }
+                            cached_file_kwargs = {
+                                "cache_dir": cache_dir,
+                                "force_download": force_download,
+                                "local_files_only": local_files_only,
+                                "user_agent": user_agent,
+                                "subfolder": subfolder,
+                                "_raise_exceptions_for_gated_repo": False,
+                                "_raise_exceptions_for_missing_entries": False,
+                                "_commit_hash": commit_hash,
+                                **has_file_kwargs,
+                            }
+                            if (
+                                not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
+                                and not is_remote_code
+                            ):
+                                Thread(
+                                    target=auto_conversion,
+                                    args=(pretrained_model_name_or_path,),
+                                    kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
+                                    name="Thread-auto_conversion",
+                                ).start()
                 else:
-
-
-
-
+                    # Otherwise, no PyTorch file was found
+                    has_file_kwargs = {
+                        "revision": revision,
+                        "proxies": proxies,
+                        "token": token,
+                        "cache_dir": cache_dir,
+                        "local_files_only": local_files_only,
+                    }
+                    if variant is not None and has_file(
+                        pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
+                    ):
+                        raise OSError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
+                            f" {variant}. Use `variant=None` to load this model from those weights."
+                        )
+                    else:
+                        raise OSError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
+                        )

             except OSError:
                 # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
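Not part of the diff: a sketch of the background-conversion pattern the new code above uses, where `auto_conversion` is launched on a `Thread` so weight loading is not blocked. `convert_repo` below is a hypothetical stand-in for that function, and the repo id is a placeholder.

```python
from threading import Thread

def convert_repo(repo_id, **kwargs):
    # hypothetical stand-in for `auto_conversion`
    print(f"converting {repo_id} with {kwargs}")

Thread(
    target=convert_repo,
    args=("org/some-model",),  # placeholder repo id
    kwargs={"ignore_errors_during_conversion": True},
    name="Thread-auto_conversion",
).start()
```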
@@ -777,20 +697,22 @@ def _get_resolved_checkpoint_files(


 def _get_dtype(
-
-
+    cls,
+    dtype: Optional[Union[str, torch.dtype, dict]],
+    checkpoint_files: Optional[list[str]],
     config: PreTrainedConfig,
-    sharded_metadata: dict
-    state_dict: dict
+    sharded_metadata: Optional[dict],
+    state_dict: Optional[dict],
     weights_only: bool,
-
-) -> tuple[PreTrainedConfig, torch.dtype]:
+) -> tuple[PreTrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
     """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
     inferred dtype. We do the following:
-    1. If dtype is
-
-
+    1. If dtype is not None, we use that dtype
+    2. If dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
+       weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
+       we also may have config.dtype available, but we won't rely on it till v5
     """
+    dtype_orig = None
     is_sharded = sharded_metadata is not None

     if dtype is not None:

@@ -815,46 +737,43 @@ def _get_dtype(
             )
         elif hasattr(torch, dtype):
             dtype = getattr(torch, dtype)
-
-
-
-
-
-
-
-
+            config.dtype = dtype
+            for sub_config_key in config.sub_configs:
+                if (sub_config := getattr(config, sub_config_key)) is not None:
+                    sub_config.dtype = dtype
+        elif isinstance(dtype, torch.dtype):
+            config.dtype = dtype
+            for sub_config_key in config.sub_configs:
+                if (sub_config := getattr(config, sub_config_key)) is not None:
+                    sub_config.dtype = dtype
+        elif isinstance(dtype, dict):
+            for key, curr_dtype in dtype.items():
+                if hasattr(config, key):
+                    value = getattr(config, key)
+                    curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
+                    value.dtype = curr_dtype
+            # main torch dtype for modules that aren't part of any sub-config
+            dtype = dtype.get("")
+            dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
+            config.dtype = dtype
+            if dtype is None:
+                dtype = torch.float32
+        else:
             raise ValueError(
                 f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
                 f"for each sub-config in composite configs, but received {dtype}"
             )
-        else:
-            # set torch.get_default_dtype() (usually fp32) as the default dtype if `None` is provided
-            dtype = torch.get_default_dtype()
-
-        if hf_quantizer is not None:
-            hf_quantizer.update_dtype(dtype)
-
-        # Get the main dtype
-        if isinstance(dtype, dict):
-            main_dtype = dtype.get("", torch.get_default_dtype())
-            main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
-
-            logger.warning_once(
-                "Using different dtypes per module is deprecated and will be removed in future versions "
-                "Setting different dtypes per backbone model might cause device errors downstream, therefore "
-                f"setting the dtype={main_dtype} for all modules."
-            )

+        dtype_orig = cls._set_default_dtype(dtype)
     else:
-
-
-
-
-
-
-            sub_config.dtype = main_dtype
+        # set fp32 as the default dtype for BC
+        default_dtype = torch.get_default_dtype()
+        config.dtype = default_dtype
+        for key in config.sub_configs:
+            if (sub_config := getattr(config, key)) is not None:
+                sub_config.dtype = default_dtype

-    return config,
+    return config, dtype, dtype_orig


 class PipelineParallel(Enum):
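Not part of the diff: a minimal sketch of the dtype normalization performed by `_get_dtype` above — strings map to `torch` dtypes and a dict may carry per-sub-config entries with `""` as the catch-all key. The `"auto"` and quantizer branches are omitted.

```python
import torch

def resolve_main_dtype(dtype):
    # Sketch of the branches shown above, simplified.
    if isinstance(dtype, str) and dtype != "auto":
        return getattr(torch, dtype)
    if isinstance(dtype, dict):
        main = dtype.get("")
        main = getattr(torch, main) if isinstance(main, str) else main
        return main if main is not None else torch.float32
    return dtype

assert resolve_main_dtype("bfloat16") is torch.bfloat16
assert resolve_main_dtype({"text_config": "float16", "": "float32"}) is torch.float32
```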
@@ -905,8 +824,13 @@ class ModuleUtilsMixin:
         return encoder_extended_attention_mask

     @staticmethod
-    def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
-        device
+    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
+        if device is not None:
+            warnings.warn(
+                "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+            )
+        else:
+            device = attention_mask.device
         batch_size, seq_length = input_shape
         seq_ids = torch.arange(seq_length, device=device)
         causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

@@ -930,7 +854,8 @@ class ModuleUtilsMixin:
         self,
         attention_mask: Tensor,
         input_shape: tuple[int, ...],
-
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ) -> Tensor:
         """
         Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

@@ -947,6 +872,12 @@ class ModuleUtilsMixin:
         if dtype is None:
             dtype = self.dtype

+        if not (attention_mask.dim() == 2 and self.config.is_decoder):
+            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
+            if device is not None:
+                warnings.warn(
+                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+                )
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:

@@ -955,9 +886,9 @@ class ModuleUtilsMixin:
             # Provided a padding mask of dimensions [batch_size, seq_length]
             # - if the model is a decoder, apply a causal mask in addition to the padding mask
             # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-            if
+            if self.config.is_decoder:
                 extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
-                    input_shape, attention_mask
+                    input_shape, attention_mask, device
                 )
             else:
                 extended_attention_mask = attention_mask[:, None, None, :]
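Not part of the diff: the causal-mask construction from `create_extended_attention_mask_for_decoder`, shown standalone.

```python
import torch

batch_size, seq_length = 1, 4
seq_ids = torch.arange(seq_length)
# position j is visible from position i only when j <= i
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
print(causal_mask[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
```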
@@ -1038,52 +969,54 @@ class EmbeddingAccessMixin:
             `nn.Module`: A torch module mapping vocabulary to hidden states.
         """

+        # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
+        # for most NLP models), and if so, return it.
+
         name = getattr(self, "_input_embed_layer", "embed_tokens")

-        # 1) Direct attribute (most NLP models).
         if (default_embedding := getattr(self, name, None)) is not None:
             return default_embedding
-        # 2)
-        if hasattr(self, "embeddings") and hasattr(self.embeddings, name):
-            return getattr(self.embeddings, name)
-        # 3) Encoder/decoder wrappers (e.g., `self.model.embed_tokens` or similar overrides).
-        if hasattr(self, "model") and hasattr(self.model, name):
-            return getattr(self.model, name)
+        # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`

-        if hasattr(self, "
-
-        if base_model is not None and base_model is not self:
-            return base_model.get_input_embeddings()
+        if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
+            return self.model.embed_tokens

-
-
-
+        # 3) vanilla decoder‑only architectures
+        elif hasattr(self, "embed_tokens"):
+            return self.embed_tokens
+        else:
+            base_model = getattr(self, "base_model_prefix", None)
+            if base_model is not None:
+                base_model = getattr(self, base_model, None)
+            if base_model is not None and base_model is not self:
+                return base_model.get_input_embeddings()
+            raise NotImplementedError(
+                f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
+                "please override in the subclass."
+            )

     def set_input_embeddings(self, value: nn.Module):
         """Fallback setter that handles **~70%** of models in the code-base.

         Order of attempts:
-        1. `self
-        2. `self.
-        3.
-        4.
-        5. otherwise raise `NotImplementedError` so subclasses still can (and
+        1. `self.model.embed_tokens`
+        2. `self.embed_tokens`
+        3. delegate to the *base model* if one exists
+        4. otherwise raise `NotImplementedError` so subclasses still can (and
            should) override for exotic layouts.
         """

+        # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
         name = getattr(self, "_input_embed_layer", "embed_tokens")
-
-        if hasattr(self, name):
-            setattr(self, name, value)
-        # 2) Nested embeddings (e.g., self.embeddings.patch_embedding for vision models)
-        elif hasattr(self, "embeddings") and hasattr(self.embeddings, name):
-            setattr(self.embeddings, name, value)
-        # 3) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
-        elif hasattr(self, "model") and hasattr(self.model, name):
+        if hasattr(self, "model") and hasattr(self.model, name):
             setattr(self.model, name, value)
-        #
-        elif hasattr(self,
-            self
+        # 2) as well as vanilla decoder‑only architectures
+        elif hasattr(self, name):
+            setattr(self, name, value)
+        # 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
+        elif getattr(self, self.base_model_prefix, self) is not self:
+            base_model = getattr(self, self.base_model_prefix, self)
+            base_model.set_input_embeddings(value)
         else:
             raise NotImplementedError(
                 f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
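Not part of the diff: a toy module layout illustrating the fallback order of the `get_input_embeddings` implementation above. The class names are illustrative, not real transformers models.

```python
import torch.nn as nn

class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)

class ToyForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyBackbone()

    def get_input_embeddings(self):
        # 1) direct attribute, 2) `self.model.embed_tokens`, otherwise give up
        name = getattr(self, "_input_embed_layer", "embed_tokens")
        if (emb := getattr(self, name, None)) is not None:
            return emb
        if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
            return self.model.embed_tokens
        raise NotImplementedError

print(ToyForCausalLM().get_input_embeddings())  # Embedding(10, 4)
```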
@@ -1144,7 +1077,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
     _keep_in_fp32_modules_strict = None

-    dtype_plan: dict[str, torch.dtype]
+    dtype_plan: Optional[dict[str, torch.dtype]] = None

     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
     # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.

@@ -1204,7 +1137,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

     # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
     # Possible values are: text, image, video, audio and time
-    input_modalities: str
+    input_modalities: Union[str, list[str]] = "text"  # most models are text

     @property
     @torch._dynamo.allow_in_graph

@@ -1295,11 +1228,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
             self.config._attn_implementation, is_init_check=True
         )
-        # Check the experts implementation is supported, or set it if not yet set (on the internal attr, to avoid
-        # setting it recursively)
-        self.config._experts_implementation_internal = self._check_and_adjust_experts_implementation(
-            self.config._experts_implementation
-        )
         if self.can_generate():
             self.generation_config = GenerationConfig.from_model_config(config)

@@ -1415,7 +1343,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     def pp_plan(self, plan: dict[str, tuple[str, str]]):
         self._pp_plan = plan

-    def dequantize(self
+    def dequantize(self):
         """
         Potentially dequantize the model in case it has been quantized by a quantization method that support
         dequantization.

@@ -1425,7 +1353,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if hf_quantizer is None:
             raise ValueError("You need to first quantize your model in order to dequantize it")

-        return hf_quantizer.dequantize(self
+        return hf_quantizer.dequantize(self)

     def _backward_compatibility_gradient_checkpointing(self):
         if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):

@@ -1433,7 +1361,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             # Remove the attribute now that is has been consumed, so it's no saved in the config.
             delattr(self.config, "gradient_checkpointing")

-    def add_model_tags(self, tags: list[str]
+    def add_model_tags(self, tags: Union[list[str], str]) -> None:
         r"""
         Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
         not overwrite existing tags in the model.

@@ -1466,6 +1394,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             self.model_tags.append(tag)

     @classmethod
+    @restore_default_dtype
     def _from_config(cls, config, **kwargs):
         """
         All context managers that the model should be initialized under go here.

@@ -1474,6 +1403,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             dtype (`torch.dtype`, *optional*):
                 Override the default `dtype` and load the model under this dtype.
         """
+        # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
+        # a warning is raised that dtype should be fp16. Since we never pass dtype from within
+        # modeling code, we can try to infer it here same way as done in `from_pretrained`
         # For BC on the old `torch_dtype`
         dtype = kwargs.pop("dtype", config.dtype)
         if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:

@@ -1483,32 +1415,61 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if isinstance(dtype, str):
             dtype = getattr(torch, dtype)

+        # override default dtype if needed
+        dtype_orig = None
+        if dtype is not None:
+            dtype_orig = cls._set_default_dtype(dtype)
+
         # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
         if "attn_implementation" in kwargs:
             config._attn_implementation = kwargs.pop("attn_implementation")

-        # If passing `experts_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
-        if "experts_implementation" in kwargs:
-            config._experts_implementation = kwargs.pop("experts_implementation")
-
-        init_contexts = []
-        if dtype is not None:
-            init_contexts.append(local_torch_dtype(dtype, cls.__name__))
-
         if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
             # this immediately partitions the model across all gpus, to avoid the overhead in time
             # and memory copying it on CPU or each GPU first
             import deepspeed

-            init_contexts
+            init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
+            with ContextManagers(init_contexts):
+                model = cls(config, **kwargs)

-
-        with ContextManagers(init_contexts):
+        else:
             model = cls(config, **kwargs)

+        # restore default dtype if it was modified
+        if dtype_orig is not None:
+            torch.set_default_dtype(dtype_orig)
+
         return model

+    @classmethod
+    def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
+        """
+        Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
+        under specific dtype.
+
+        Args:
+            dtype (`torch.dtype`):
+                a floating dtype to set to.
+
+        Returns:
+            `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
+            modified. If it wasn't, returns `None`.
+
+        Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
+        `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
+        """
+        if not dtype.is_floating_point:
+            raise ValueError(
+                f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
+            )
+
+        logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
+        dtype_orig = torch.get_default_dtype()
+        torch.set_default_dtype(dtype)
+        return dtype_orig
+
     @property
     def base_model(self) -> nn.Module:
         """
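Not part of the diff: the save/set/restore dance that `_from_config` performs around the default dtype, shown with plain `torch` calls.

```python
import torch

dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    layer = torch.nn.Linear(8, 8)        # parameters are created in bfloat16
    print(layer.weight.dtype)            # torch.bfloat16
finally:
    torch.set_default_dtype(dtype_orig)  # restore, as `_from_config` does
```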
@@ -1585,9 +1546,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             return True

         if is_torch_xpu_available():
-            logger.info(
-                f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU."
-            )
+            logger.info("Detect using FlashAttention2 (via kernel `kernels-community/flash-attn2`) on XPU.")
             return True

         if importlib.util.find_spec("flash_attn") is None:

@@ -1756,22 +1715,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         return True

-    def _grouped_mm_can_dispatch(self) -> bool:
-        """
-        Check the availability of Grouped MM for a given model.
-        """
-
-        if not self._can_set_experts_implementation():
-            raise ValueError(f"{self.__class__.__name__} does not support setting experts implementation.")
-
-        if not is_grouped_mm_available():
-            raise ImportError(
-                "PyTorch Grouped MM requirements in Transformers are not met. Please install torch>=2.9.0."
-            )
-
-        # If no error raised by this point, we can return `True`
-        return True
-
     def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
         """
         Check the availability of Flex Attention for a given model.

@@ -1800,7 +1743,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         return True

     def _check_and_adjust_attn_implementation(
-        self, attn_implementation: str
+        self, attn_implementation: Optional[str], is_init_check: bool = False
     ) -> str:
         """
         Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if

@@ -1821,12 +1764,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         """
         applicable_attn_implementation = attn_implementation

-        is_paged = attn_implementation is not None and attn_implementation.startswith("paged|")
-
         # If FA not installed, do not fail but use kernels instead
         requested_original_flash_attn = attn_implementation is not None and (
-            attn_implementation
-            or attn_implementation.removeprefix("paged|") == "flash_attention_3"
+            attn_implementation == "flash_attention_2" or attn_implementation == "flash_attention_3"
         )
         if (
             requested_original_flash_attn

@@ -1835,23 +1775,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             and is_kernels_available()
             and not is_torch_npu_available()
         ):
-
-
-
-
-
-
-
-
-            applicable_attn_implementation = f"paged|{applicable_attn_implementation}"
+            if attn_implementation.endswith("2"):
+                applicable_attn_implementation = "kernels-community/flash-attn2"
+                if is_torch_xpu_available():
+                    # On XPU, kernels library is the native implementation
+                    # Disabling this flag to avoid giving wrong fallbacks on errors and warnings
+                    requested_original_flash_attn = False
+            else:
+                applicable_attn_implementation = "kernels-community/vllm-flash-attn3"

         if is_kernel(applicable_attn_implementation):
             try:
                 # preload flash attention here to allow compile with fullgraph
-
-                    lazy_import_paged_flash_attention(applicable_attn_implementation)
-                else:
-                    lazy_import_flash_attention(applicable_attn_implementation)
+                lazy_import_flash_attention(applicable_attn_implementation)

                 # log that we used kernel fallback if successful
                 if requested_original_flash_attn:

@@ -1875,25 +1811,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             )

         # preload flash attention here to allow compile with fullgraph
-        if
+        if "flash" in applicable_attn_implementation:
             lazy_import_flash_attention(applicable_attn_implementation)

         return applicable_attn_implementation

-    def
-        """
-        Check that the `experts_implementation` exists and is supported by the models.
-
-        Args:
-            experts_implementation (`str` or `None`):
-                The experts implementation to check for existence/validity.
-        Returns:
-            `str`: The final experts implementation to use.
-        """
-        applicable_experts_implementation = self.get_correct_experts_implementation(experts_implementation)
-        return applicable_experts_implementation
-
-    def get_correct_attn_implementation(self, requested_attention: str | None, is_init_check: bool = False) -> str:
+    def get_correct_attn_implementation(self, requested_attention: Optional[str], is_init_check: bool = False) -> str:
         applicable_attention = "sdpa" if requested_attention is None else requested_attention
         if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
             message = (

@@ -1927,33 +1850,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         return applicable_attention

-    def get_correct_experts_implementation(self, requested_experts: str | None) -> str:
-        applicable_experts = "grouped_mm" if requested_experts is None else requested_experts
-        if applicable_experts not in ["eager", "grouped_mm", "batched_mm"]:
-            message = (
-                f'Specified `experts_implementation="{applicable_experts}"` is not supported. The only possible arguments are '
-                '`experts_implementation="eager"`, `"experts_implementation=grouped_mm"` and `"experts_implementation=batched_mm"`.'
-            )
-            raise ValueError(message)
-
-        # Perform relevant checks
-        if applicable_experts == "grouped_mm":
-            try:
-                self._grouped_mm_can_dispatch()
-            except (ValueError, ImportError) as e:
-                if requested_experts == "grouped_mm":
-                    raise e
-                applicable_experts = "eager"
-
-        return applicable_experts
-
     @classmethod
     def _can_set_attn_implementation(cls) -> bool:
         """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
         opening the file, but avoids maintaining yet another property flag.
         """
         class_file = sys.modules[cls.__module__].__file__
-        with open(class_file, "r"
+        with open(class_file, "r") as f:
             code = f.read()
         # heuristic -> if we find those patterns, the model uses the correct interface
         if re.search(r"class \w+Attention\(nn.Module\)", code):

@@ -1965,18 +1868,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
         return True

-
-    def _can_set_experts_implementation(cls) -> bool:
-        """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
-        opening the file, but avoids maintaining yet another property flag.
-        """
-        class_file = sys.modules[cls.__module__].__file__
-        with open(class_file, "r", encoding="utf-8") as f:
-            code = f.read()
-        # heuristic -> if we the use_experts_implementation decorator is used, then we can set it
-        return "@use_experts_implementation" in code
-
-    def set_attn_implementation(self, attn_implementation: str | dict):
+    def set_attn_implementation(self, attn_implementation: Union[str, dict]):
         """
         Set the requested `attn_implementation` for this model.

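Not part of the diff: the source-scanning heuristic used by `_can_set_attn_implementation`, demonstrated on a small string instead of a modeling file.

```python
import re

code = "class LlamaAttention(nn.Module):\n    ..."
print(bool(re.search(r"class \w+Attention\(nn.Module\)", code)))  # True
```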
@@ -2075,50 +1967,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                     if hasattr(subconfig, "_attn_was_changed"):
                         del subconfig._attn_was_changed

-    def set_experts_implementation(self, experts_implementation: str | dict):
-        """
-        Set the requested `experts_implementation` for this model.
-
-        Args:
-            experts_implementation (`str` or `dict`):
-                The experts implementation to set for this model. It can be either a `str`, in which case it will be
-                dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
-                submodel will dispatch the corresponding value.
-        """
-        requested_implementation = (
-            experts_implementation
-            if not isinstance(experts_implementation, dict)
-            else experts_implementation.get("", self.config._experts_implementation)
-        )
-
-        if requested_implementation != self.config._experts_implementation:
-            requested_implementation = self._check_and_adjust_experts_implementation(requested_implementation)
-            # Apply the change (on the internal attr, to avoid setting it recursively)
-            self.config._experts_implementation_internal = requested_implementation
-
-        # Apply it to all submodels as well
-        for submodule in self.modules():
-            # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
-            # e.g. ForCausalLM has a Model inside, but no need to check it again)
-            if (
-                submodule is not self
-                and isinstance(submodule, PreTrainedModel)
-                and submodule.config.__class__ != self.config.__class__
-            ):
-                # Set the experts on the submodule
-                sub_implementation = requested_implementation
-                if isinstance(experts_implementation, dict):
-                    for subconfig_key in self.config.sub_configs:
-                        # We need to check for exact object match here, with `is`
-                        if getattr(self.config, subconfig_key) is submodule.config:
-                            sub_implementation = experts_implementation.get(
-                                subconfig_key, submodule.config._experts_implementation
-                            )
-                            break
-                # Check the module can use correctly, otherwise we raise an error if requested experts can't be set for submodule
-                sub_implementation = submodule.get_correct_experts_implementation(sub_implementation)
-                submodule.config._experts_implementation_internal = sub_implementation
-
     def enable_input_require_grads(self):
         """
         Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping

@@ -2130,18 +1978,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         hooks = []
         seen_modules = set()
-        found_embeddings = False

         for module in self.modules():
             if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
                 continue

-
-                input_embeddings = module.get_input_embeddings()
-            except NotImplementedError:
-                continue
+            input_embeddings = module.get_input_embeddings()

-            if input_embeddings is None
+            if input_embeddings is None:
                 continue

             embedding_id = id(input_embeddings)

@@ -2150,18 +1994,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

             seen_modules.add(embedding_id)
             hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))
-            found_embeddings = True

         self._require_grads_hooks = hooks
         if hooks:
             # for BC
             self._require_grads_hook = hooks[0]
-        if not found_embeddings:
-            logger.warning_once(
-                f"{self.__class__.__name__} does not expose input embeddings. Gradients cannot flow back to the token "
-                "embeddings when using adapters or gradient checkpointing. Override `get_input_embeddings` to fully "
-                "support those features, or set `_input_embed_layer` to the attribute name that holds the embeddings."
-            )

     def disable_input_require_grads(self):
         """

@@ -2178,7 +2015,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if hasattr(self, "_require_grads_hook"):
             del self._require_grads_hook

-    def get_encoder(self, modality: str
+    def get_encoder(self, modality: Optional[str] = None):
         """
         Best-effort lookup of the *encoder* module. If provided with `modality` argument,
         it looks for a modality-specific encoder in multimodal models (e.g. "image_encoder")

@@ -2210,7 +2047,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         # If this is a base transformer model (no encoder/model attributes), return self
         return self

-    def set_encoder(self, encoder, modality: str
+    def set_encoder(self, encoder, modality: Optional[str] = None):
         """
         Symmetric setter. Mirrors the lookup logic used in `get_encoder`.
         """

@@ -2267,6 +2104,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         possible_module_names = ["language_model", "text_model", "decoder"]
         for name in possible_module_names:
             if hasattr(self, name):
+                print(name)
                 setattr(self, name, decoder)
                 return

@@ -2296,13 +2134,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
             if getattr(module, "weight", None) is not None:
                 init.normal_(module.weight, mean=0.0, std=std)
-            if module
+            if getattr(module, "bias", None) is not None:
                 init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
-
-
-
-
+            if getattr(module, "weight", None) is not None:
+                init.normal_(module.weight, mean=0.0, std=std)
+                # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
+                if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
+                    init.zeros_(module.weight[module.padding_idx])
         elif isinstance(module, nn.MultiheadAttention):
             # This uses torch's original init
             module._reset_parameters()
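Not part of the diff: the Linear/Embedding initialization pattern from the hunk above, applied to standalone modules.

```python
import torch.nn as nn
from torch.nn import init

std = 0.02
linear = nn.Linear(4, 4)
embedding = nn.Embedding(10, 4, padding_idx=0)

init.normal_(linear.weight, mean=0.0, std=std)
if linear.bias is not None:
    init.zeros_(linear.bias)

init.normal_(embedding.weight, mean=0.0, std=std)
if embedding.padding_idx is not None:
    init.zeros_(embedding.weight[embedding.padding_idx])  # zero the padding row, as above
```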
@@ -2314,25 +2153,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             or "RMSNorm" in module.__class__.__name__
         ):
             # Norms can exist without weights (in which case they are None from torch primitives)
-            if
+            if hasattr(module, "weight") and module.weight is not None:
                 init.ones_(module.weight)
-            if
+            if hasattr(module, "bias") and module.bias is not None:
                 init.zeros_(module.bias)
-            # And the potential buffers for the BatchNorms
-            if getattr(module, "running_mean", None) is not None:
-                init.zeros_(module.running_mean)
-                init.ones_(module.running_var)
-                init.zeros_(module.num_batches_tracked)
-        # This matches all the usual RotaryEmbeddings modules
-        elif "RotaryEmbedding" in module.__class__.__name__ and hasattr(module, "original_inv_freq"):
-            rope_fn = (
-                ROPE_INIT_FUNCTIONS[module.rope_type]
-                if module.rope_type != "default"
-                else module.compute_default_rope_parameters
-            )
-            buffer_value, _ = rope_fn(module.config)
-            init.copy_(module.inv_freq, buffer_value)
-            init.copy_(module.original_inv_freq, buffer_value)

     def _initialize_weights(self, module):
         """

@@ -2437,10 +2261,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         tied_mapping = self._tied_weights_keys
         # If the config does not specify any tying, return empty dict
-
-        # modules do not have any word embeddings!
-        tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False)
-        if not tie_word_embeddings:
+        if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder:
             return {}
         # If None, return empty dict
         elif tied_mapping is None:

@@ -2486,7 +2307,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         return expanded_tied_weights

-    def tie_weights(self, missing_keys: set[str]
+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
         """
         Tie the model weights. If `recompute_mapping=False` (default when called internally), it will rely on the
         `model.all_tied_weights_keys` attribute, containing the `{target: source}` mapping for the tied params.

@@ -2506,26 +2327,30 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         tied_keys = list(tied_keys.items())
         for i, (target_param_name, source_param_name) in enumerate(tied_keys):
+            # Usually we tie a single target to a single source, but when both are missing we may later tie
+            # both the source and target to a third "backup" parameter that is present in the checkpoint, so we use
+            # a list here
+            target_param_names = [target_param_name]
+
             # This is `from_pretrained` -> let's check symmetrically in case the source key is not present
             if missing_keys is not None:
                 remove_from_missing = True
                 source_is_there = source_param_name not in missing_keys
                 target_is_there = target_param_name not in missing_keys
                 # Both are already present -> it means the config is wrong and do not reflect the actual
-                # checkpoint -> let's raise a warning and
+                # checkpoint -> let's raise a warning and do nothing
                 if source_is_there and target_is_there:
                     logger.warning(
                         f"The tied weights mapping and config for this model specifies to tie {source_param_name} to "
                         f"{target_param_name}, but both are present in the checkpoints, so we will NOT tie them. "
                         "You should update the config with `tie_word_embeddings=False` to silence this warning"
                     )
-                    # Remove from internal attribute to correctly reflect actual tied weights
-                    self.all_tied_weights_keys.pop(target_param_name)
                     # Skip to next iteration
                     continue
                 # We're missing the source but we have the target -> we swap them, tying the parameter that exists
                 elif not source_is_there and target_is_there:
                     target_param_name, source_param_name = source_param_name, target_param_name
+                    target_param_names = [target_param_name]
                 # Both are missing -> check other keys in case more than 2 keys are tied to the same weight
                 elif not source_is_there and not target_is_there:
                     for target_backup, source_backup in tied_keys[i + 1 :]:

@@ -2534,10 +2359,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                         if source_backup == source_param_name:
                             target_backup_is_there = target_backup not in missing_keys
                             # If the target is present, we found the correct weight to tie into (we know the source is missing)
-                            # Note here that we do not tie the missing source right now as well, as it will be done anyway when
-                            # the pair (target_backup, source_backup) becomes the main pair (target_param_name, source_param_name)
                             if target_backup_is_there:
                                 source_param_name = target_backup
+                                # Append the source as well, since both are missing we'll tie both
+                                target_param_names.append(source_param_name)
                                 break
                     # If we did not break from the loop, it was impossible to find a source key -> let's raise
                     else:

@@ -2553,18 +2378,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

             # Perform the actual tying
             source_param = self.get_parameter_or_buffer(source_param_name)
-
-
-
-
-
-
-
-
-
-
-
-            missing_keys
+            for target_param_name in target_param_names:
+                if "." in target_param_name:
+                    parent_name, name = target_param_name.rsplit(".", 1)
+                    parent = self.get_submodule(parent_name)
+                else:
+                    name = target_param_name
+                    parent = self
+                # Tie the weights
+                setattr(parent, name, source_param)
+                self._adjust_bias(parent, source_param)
+                # Remove from missing if necesary
+                if missing_keys is not None and remove_from_missing:
+                    missing_keys.discard(target_param_name)

     def _adjust_bias(self, output_embeddings, input_embeddings):
         if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"):
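Not part of the diff: a toy illustration of the tying step above — the target parameter is replaced by the source parameter object via `setattr` on its parent module, so both names share storage.

```python
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)

model = Toy()
source_param = model.get_parameter("embed.weight")
parent = model.get_submodule("lm_head")
setattr(parent, "weight", source_param)   # tie lm_head.weight -> embed.weight
assert model.lm_head.weight is model.embed.weight
```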
@@ -2609,8 +2435,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

    def resize_token_embeddings(
        self,
-       new_num_tokens: int
-       pad_to_multiple_of: int
+       new_num_tokens: Optional[int] = None,
+       pad_to_multiple_of: Optional[int] = None,
        mean_resizing: bool = True,
    ) -> nn.Embedding:
        """
@@ -2690,7 +2516,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        new_num_tokens = new_embeddings.weight.shape[0]

        # if word embeddings are not tied, make sure that lm head is resized as well
-       if
+       if (
+           self.get_output_embeddings() is not None
+           and not self.config.get_text_config(decoder=True).tie_word_embeddings
+       ):
            old_lm_head = self.get_output_embeddings()
            if isinstance(old_lm_head, torch.nn.Embedding):
                new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
@@ -2708,8 +2537,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    def _get_resized_embeddings(
        self,
        old_embeddings: nn.Embedding,
-       new_num_tokens: int
-       pad_to_multiple_of: int
+       new_num_tokens: Optional[int] = None,
+       pad_to_multiple_of: Optional[int] = None,
        mean_resizing: bool = True,
    ) -> nn.Embedding:
        """
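A typical call to the resizing API above, for reference; the checkpoint name is only an example:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_tokens(["<custom_tag>"])
# Pad the vocabulary up to a multiple of 64, which tends to be friendlier to tensor cores
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)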
@@ -2866,7 +2695,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
2866
2695
|
def _get_resized_lm_head(
|
|
2867
2696
|
self,
|
|
2868
2697
|
old_lm_head: nn.Linear,
|
|
2869
|
-
new_num_tokens: int
|
|
2698
|
+
new_num_tokens: Optional[int] = None,
|
|
2870
2699
|
transposed: bool = False,
|
|
2871
2700
|
mean_resizing: bool = True,
|
|
2872
2701
|
) -> nn.Linear:
|
|
@@ -3063,7 +2892,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3063
2892
|
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
|
|
3064
2893
|
)
|
|
3065
2894
|
|
|
3066
|
-
def get_position_embeddings(self) -> nn.Embedding
|
|
2895
|
+
def get_position_embeddings(self) -> Union[nn.Embedding, tuple[nn.Embedding]]:
|
|
3067
2896
|
raise NotImplementedError(
|
|
3068
2897
|
f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
|
|
3069
2898
|
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
|
|
@@ -3074,8 +2903,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3074
2903
|
Maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
|
|
3075
2904
|
initialization logic in `_init_weights`.
|
|
3076
2905
|
"""
|
|
3077
|
-
|
|
3078
|
-
if get_torch_context_manager_or_global_device() != torch.device("meta"):
|
|
2906
|
+
if _init_weights:
|
|
3079
2907
|
# Initialize weights
|
|
3080
2908
|
self.initialize_weights()
|
|
3081
2909
|
# Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
|
|
@@ -3096,7 +2924,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3096
2924
|
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
|
3097
2925
|
|
|
3098
2926
|
if gradient_checkpointing_kwargs is None:
|
|
3099
|
-
gradient_checkpointing_kwargs = {"use_reentrant":
|
|
2927
|
+
gradient_checkpointing_kwargs = {"use_reentrant": True}
|
|
3100
2928
|
|
|
3101
2929
|
gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
|
|
3102
2930
|
|
|
@@ -3113,10 +2941,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3113
2941
|
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
|
|
3114
2942
|
)
|
|
3115
2943
|
|
|
3116
|
-
|
|
3117
|
-
# we use that also to detect whether or not we have to raise if embeddings are missing (the submodel might not have embeddings at all)
|
|
3118
|
-
enable_input_grads = needs_embedding_grads or getattr(self, "_hf_peft_config_loaded", False)
|
|
3119
|
-
if enable_input_grads:
|
|
2944
|
+
if getattr(self, "_hf_peft_config_loaded", False):
|
|
3120
2945
|
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
|
|
3121
2946
|
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
|
|
3122
2947
|
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
|
|
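As a small usage note for the default shown above, the checkpointing kwargs can be overridden when enabling the feature; this is a hedged example, assuming a loaded `model`:

# Opt out of the reentrant checkpoint implementation instead of relying on the default
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})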
@@ -3174,13 +2999,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

    def save_pretrained(
        self,
-       save_directory: str
+       save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
-       state_dict: dict
+       state_dict: Optional[dict] = None,
+       save_function: Callable = torch.save,
        push_to_hub: bool = False,
-       max_shard_size: int
-
-
+       max_shard_size: Union[int, str] = "5GB",
+       safe_serialization: bool = True,
+       variant: Optional[str] = None,
+       token: Optional[Union[str, bool]] = None,
        save_peft_format: bool = True,
        save_original_format: bool = True,
        **kwargs,
@@ -3200,13 +3027,18 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
                save parts of the model or if special precautions need to be taken when recovering the state dictionary
                of a model (like when using model parallelism).
+           save_function (`Callable`):
+               The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+               need to replace `torch.save` by another method.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
-           max_shard_size (`int` or `str`, *optional*, defaults to `"
+           max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
                The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
+               We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
+               without CPU OOM issues.

                <Tip warning={true}>

@@ -3215,8 +3047,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

                </Tip>

+           safe_serialization (`bool`, *optional*, defaults to `True`):
+               Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            variant (`str`, *optional*):
-               If specified, weights are saved in the format
+               If specified, weights are saved in the format pytorch_model.<variant>.bin.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                the token generated when running `hf auth login` (stored in `~/.huggingface`).
@@ -3238,7 +3072,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

        hf_quantizer = getattr(self, "hf_quantizer", None)
        quantization_serializable = (
-           hf_quantizer is not None
+           hf_quantizer is not None
+           and isinstance(hf_quantizer, HfQuantizer)
+           and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
        )

        if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
@@ -3247,6 +3083,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                " the logger on the traceback to understand the reason why the quantized model is not serializable."
            )

+       if "save_config" in kwargs:
+           warnings.warn(
+               "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
+           )
+           is_main_process = kwargs.pop("save_config")
+
        # we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
        if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
            raise ImportError(
@@ -3268,7 +3110,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

        metadata = {}
        if hf_quantizer is not None:
-           state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self)
+           state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
            metadata["format"] = "pt"

        # Only save the model itself if we are using distributed training
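For reference, a call exercising the arguments restored in the signature above; the output directory and variant are placeholder values:

model.save_pretrained(
    "./my-model",
    safe_serialization=True,   # write .safetensors shards (the default)
    max_shard_size="2GB",      # shard any file that would exceed this size
    variant="fp16",            # weights end up as pytorch_model.fp16.bin / *.fp16.safetensors style names
)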
@@ -3321,22 +3163,28 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
            current_peft_config = self.peft_config[active_adapter]
            current_peft_config.save_pretrained(save_directory)

-       #
+       # for offloaded modules
+       module_map = {}
+
+       # Save the model
        if state_dict is None:
-
+           # if any model parameters are offloaded, make module map
+           if (
+               hasattr(self, "hf_device_map")
+               and len(set(self.hf_device_map.values())) > 1
+               and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
+           ):
+               warnings.warn(
+                   "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
+               )
+               for name, module in model_to_save.named_modules():
+                   if name == "":
+                       continue
+                   module_state_dict = module.state_dict()

-
-
-
-           hasattr(self, "hf_device_map")
-           and len(set(self.hf_device_map.values())) > 1
-           and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
-       ):
-           is_offloaded = True
-           warnings.warn(
-               "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory "
-               "exceeds the `shard_size` (50GB default)"
-           )
+                   for key in module_state_dict:
+                       module_map[name + f".{key}"] = module
+           state_dict = model_to_save.state_dict()

        # Translate state_dict from smp to hf if saving with smp >= 1.10
        if IS_SAGEMAKER_MP_POST_1_10:
@@ -3354,19 +3202,86 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        if self._tp_size is not None:
            state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)

-
-
+       if safe_serialization:
+           # TODO: fix safe_serialization for tied weights
+           # Safetensors does not allow tensor aliasing.
+           # We're going to remove aliases before saving
+           ptrs = collections.defaultdict(list)
+           for name, tensor in state_dict.items():
+               if not isinstance(tensor, torch.Tensor):
+                   # Sometimes in the state_dict we have non-tensor objects.
+                   # e.g. in bitsandbytes we have some `str` objects in the state_dict
+                   # In the non-tensor case, fall back to the pointer of the object itself
+                   ptrs[id(tensor)].append(name)
+
+               elif tensor.device.type == "meta":
+                   # In offloaded cases, there may be meta tensors in the state_dict.
+                   # For these cases, key by the pointer of the original tensor object
+                   # (state_dict tensors are detached and therefore no longer shared)
+                   tensor = self.get_parameter(name)
+                   ptrs[id(tensor)].append(name)
+
+               else:
+                   ptrs[id_tensor_storage(tensor)].append(name)
+
+           shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
+           # Recursively descend to find tied weight keys
+           _tied_weights_keys = set(_get_tied_weight_keys(self))
+           error_names = []
+           to_delete_names = set()
+           for names in shared_ptrs.values():
+               # Removing the keys which are declared as known duplicates on
+               # load. This allows to make sure the name which is kept is consistent.
+               if _tied_weights_keys is not None:
+                   found = 0
+                   for name in sorted(names):
+                       matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
+                       if matches_pattern and name in state_dict:
+                           found += 1
+                           if found < len(names):
+                               to_delete_names.add(name)
+           # We are entering a place where the weights and the transformers configuration do NOT match.
+           shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+           # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+           # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+           for name in disjoint_names:
+               state_dict[name] = state_dict[name].clone()
+
+           # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+           # If the link between tensors was done at runtime then `from_pretrained` will not get
+           # the key back leading to random tensor. A proper warning will be shown
+           # during reload (if applicable), but since the file is not necessarily compatible with
+           # the config, better show a proper warning.
+           shared_names, identical_names = _find_identical(shared_names, state_dict)
+           # delete tensors that have identical storage
+           for inames in identical_names:
+               known = inames.intersection(to_delete_names)
+               for name in known:
+                   del state_dict[name]
+               unknown = inames.difference(to_delete_names)
+               if len(unknown) > 1:
+                   error_names.append(unknown)
+
+           if shared_names:
+               error_names.extend(shared_names)
+
+           if len(error_names) > 0:
+               raise RuntimeError(
+                   f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
+                   "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
+               )

        # Revert all renaming and/or weight operations
-       if save_original_format
-           state_dict = revert_weight_conversion(
+       if save_original_format:
+           state_dict = revert_weight_conversion(self, state_dict)

        # Shard the model if it is too big.
        if not _hf_peft_config_loaded:
-           weights_name = SAFE_WEIGHTS_NAME
+           weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
            weights_name = _add_variant(weights_name, variant)
        else:
-           weights_name = ADAPTER_SAFE_WEIGHTS_NAME
+           weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME

        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
        state_dict_split = split_torch_state_dict_into_shards(
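The aliasing check added above exists because two state-dict entries can point at the same storage (tied or viewed weights), which the safetensors format rejects. A minimal sketch, using plain PyTorch, of how shared storage can be detected:

import torch

a = torch.zeros(4, 4)
b = a.view(-1)   # a view: same underlying storage, different shape
c = a.clone()    # a copy: independent storage
shares_ab = a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
shares_ac = a.untyped_storage().data_ptr() == c.untyped_storage().data_ptr()
print(shares_ab, shares_ac)  # True False -- aliased entries must be dropped or cloned before export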
@@ -3399,45 +3314,57 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                    and reg.fullmatch(filename_no_suffix) is not None
                ):
                    os.remove(full_filename)
-
        # Save the model
-
-
-
-
-
-       for
-
-
-
-           # In case of TP, get the full parameter back
-           if _is_dtensor_available and isinstance(tensor, DTensor):
-               tensor = tensor.full_tensor()
+       filename_to_tensors = state_dict_split.filename_to_tensors.items()
+       if module_map:
+           filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
+       for shard_file, tensors in filename_to_tensors:
+           shard = {}
+           for tensor in tensors:
+               if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
+                   full_tensor = state_dict[tensor].full_tensor()
                    # to get the correctly ordered tensor we need to repack if packed
-           if _get_parameter_tp_plan(
-
-
-
-           #
-
-
-
-
+                   if _get_parameter_tp_plan(tensor, self._tp_plan) == "local_packed_rowwise":
+                       full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
+                   shard[tensor] = full_tensor.contiguous()  # only do contiguous after it's permuted correctly
+               else:
+                   shard[tensor] = state_dict[tensor].contiguous()
+               # delete reference, see https://github.com/huggingface/transformers/pull/34890
+               del state_dict[tensor]
+
+           # remake shard with onloaded parameters if necessary
+           if module_map:
+               # init state_dict for this shard
+               shard_state_dict = dict.fromkeys(shard, "")
+               for module_name in shard:
+                   # note that get_state_dict_from_offload can update with meta tensors
+                   # if both a parent module and its descendant are offloaded
+                   tensor = shard_state_dict[module_name]
+                   if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
+                       # update state dict with onloaded parameters
+                       module = module_map[module_name]
+                       shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
+
+               # assign shard to be the completed state dict
+               shard = shard_state_dict
+               del shard_state_dict
+               gc.collect()
+
+           if safe_serialization:
+               # At some point we will need to deal better with save_function (used for TPU and other distributed
+               # joyfulness), but for now this enough. # TODO: we should def parallelize this we are otherwise just waiting
+               # too much before scheduling the next write when its in a different file
+               safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
+           else:
+               save_function(shard, os.path.join(save_directory, shard_file))

-
-       # so it's not possible for now....
-       # Write the shard to disk
-       safe_save_file(shard_state_dict, filename, metadata=metadata)
-       # Cleanup the data before next loop (important with offloading, so we don't blowup cpu RAM)
-       del shard_state_dict
+       del state_dict

        if index is None:
            path_to_weights = os.path.join(save_directory, weights_name)
            logger.info(f"Model weights saved in {path_to_weights}")
        else:
-           save_index_file = SAFE_WEIGHTS_INDEX_NAME
+           save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
            # Save the index as well
            with open(save_index_file, "w", encoding="utf-8") as f:
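Each shard written above is just a flat dict of contiguous tensors. A hedged, stand-alone sketch of the per-shard write using the safetensors library directly (shard name and content are placeholders):

import torch
from safetensors.torch import save_file

shard = {"linear.weight": torch.randn(4, 4).contiguous()}
# safetensors refuses aliased or non-contiguous tensors, hence the .contiguous() calls in the diff above
save_file(shard, "model-00001-of-00001.safetensors", metadata={"format": "pt"})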
@@ -3574,9 +3501,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
            " desired `dtype` by passing the correct `dtype` argument."
        )

-       if getattr(self, "is_loaded_in_8bit", False)
+       if getattr(self, "is_loaded_in_8bit", False):
            raise ValueError(
-               "
+               "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+               " model has already been set to the correct devices and casted to the correct `dtype`."
            )
        elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
            if dtype_present_in_args:
@@ -3607,38 +3535,23 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        return super().float(*args)

    @classmethod
-   def get_init_context(cls,
-       # Need to instantiate with correct dtype
-       init_contexts = [local_torch_dtype(dtype, cls.__name__)]
+   def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
        if is_deepspeed_zero3_enabled():
            import deepspeed

+           init_contexts = [no_init_weights()]
            # We cannot initialize the model on meta device with deepspeed when not quantized
            if not is_quantized and not _is_ds_init_called:
                logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-               init_contexts.extend(
-                   [
-                       init.no_init_weights(),
-                       deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
-                       set_zero3_state(),
-                   ]
-               )
+               init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
        elif is_quantized:
-           init_contexts.extend([
+           init_contexts.extend([init_empty_weights(), set_quantized_state()])
        else:
-           init_contexts
+           init_contexts = [no_init_weights(), init_empty_weights()]

        return init_contexts

-   def set_use_kernels(self, use_kernels, kernel_config
-       """
-       Set whether or not to use the `kernels` library to kernelize some layers of the model.
-       Args:
-           use_kernels (`bool`):
-               Whether or not to use the `kernels` library to kernelize some layers of the model.
-           kernel_config (`KernelConfig`, *optional*):
-               The kernel configuration to use to kernelize the model. If `None`, the default kernel mapping will be used.
-       """
+   def set_use_kernels(self, use_kernels, kernel_config):
        if use_kernels:
            if not is_kernels_available():
                raise ValueError(
@@ -3659,9 +3572,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

            # This is a context manager to override the default kernel mapping
            # We are calling kernelize inside this context manager using the use_kernels setter
-
-           inherit_mapping = not kernel_config.use_local_kernel
-           with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
+           with use_kernel_mapping(kernel_config.kernel_mapping):
                self.use_kernels = True
            # We use the default kernel mapping in .integrations.hub_kernels
        else:
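The "empty weights" contexts assembled by `get_init_context` rely on creating parameters without allocating real storage. A minimal sketch of the underlying idea, using only plain PyTorch (not the transformers helpers):

import torch
from torch import nn

with torch.device("meta"):        # parameters are created with no real memory behind them
    skeleton = nn.Linear(4096, 4096)
print(skeleton.weight.device)     # meta
# Real storage is only materialized later, e.g. when checkpoint weights are loaded into the module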
@@ -3670,18 +3581,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
            self.use_kernels = False

    @classmethod
+   @restore_default_dtype
    def from_pretrained(
        cls: type[SpecificPreTrainedModelType],
-       pretrained_model_name_or_path: str
+       pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
-       config: PreTrainedConfig
-       cache_dir: str
+       config: Optional[Union[PreTrainedConfig, str, os.PathLike]] = None,
+       cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
-       token: str
+       token: Optional[Union[str, bool]] = None,
        revision: str = "main",
-       use_safetensors: bool
+       use_safetensors: Optional[bool] = True,
        weights_only: bool = True,
        **kwargs,
    ) -> SpecificPreTrainedModelType:
@@ -3778,18 +3690,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                "org/model@main"
                "org/model:custom_kernel"
                "org/model@v1.2.3:custom_kernel"
-           experts_implementation (`str`, *optional*):
-               The experts implementation to use in the model (if relevant). Can be any of:
-
-               - `"eager"` (sequential implementation of the experts matrix multiplications).
-               - `"batched_mm"` (using [`torch.bmm`](https://pytorch.org/docs/stable/generated/torch.bmm.html)).
-               - `"grouped_mm"` (using [`torch._grouped_mm`](https://docs.pytorch.org/docs/main/generated/torch.nn.functional.grouped_mm.html)).
-
-               By default, if available, `grouped_mm` will be used for torch>=2.9.0. The default is otherwise the sequential `"eager"` implementation.

            > Parameters for big model inference

-           dtype (`str` or `torch.dtype`, *optional
+           dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch_dtype` and load the model under a specific `dtype`. The different options
                are:

@@ -3931,8 +3835,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        # For BC on torch_dtype argument
        if torch_dtype is not None:
            dtype = dtype if dtype is not None else torch_dtype
-       if dtype is None:
-           dtype = "auto"

        if is_offline_mode() and not local_files_only:
            local_files_only = True
@@ -4009,11 +3911,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        if "attn_implementation" in kwargs:
            config._attn_implementation = kwargs.pop("attn_implementation")

-
-       config
-
-       hf_quantizer, config, device_map = get_hf_quantizer(
-           config, quantization_config, device_map, weights_only, user_agent
+       hf_quantizer, config, dtype, device_map = get_hf_quantizer(
+           config, quantization_config, dtype, device_map, weights_only, user_agent
        )

        if gguf_file:
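A typical call matching the signature above; the checkpoint name is just an example, and `dtype` is the argument documented in the diff (the older `torch_dtype` is kept only for backward compatibility):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    dtype=torch.bfloat16,   # load weights directly in bf16
    device_map="auto",      # let the dispatch logic place the weights
)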
@@ -4060,29 +3959,33 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        ]

        # Find the correct dtype based on current state
-       config, dtype = _get_dtype(
-           dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
+       config, dtype, dtype_orig = _get_dtype(
+           cls, dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
        )

        config.name_or_path = pretrained_model_name_or_path
-       model_init_context = cls.get_init_context(
+       model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
        config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
        with ContextManagers(model_init_context):
            # Let's make sure we don't run the init function of buffer modules
            model = cls(config, *model_args, **model_kwargs)

-       if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
-           hf_quantizer.preprocess_model(
-               model=model,
-               dtype=dtype,
-               device_map=device_map,
-               checkpoint_files=checkpoint_files,
-               use_kernels=use_kernels,
-           )
-
        # Obtain the weight conversion mapping for this model if any are registered
        weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)

+       # make sure we use the model's config since the __init__ call might have copied it
+       config = model.config
+
+       if hf_quantizer is not None:  # replace module with quantized modules (does not touch weights)
+           hf_quantizer.preprocess_model(
+               model=model,
+               device_map=device_map,
+               keep_in_fp32_modules=model._keep_in_fp32_modules,  # TODO prob no longer needed?
+               config=config,
+               checkpoint_files=checkpoint_files,
+               use_kernels=use_kernels,
+           )
+
        if _torch_distributed_available and device_mesh is not None:  # add hooks to nn.Modules: no weights
            model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)

@@ -4090,30 +3993,33 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        if device_map is not None:
            device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)

+       # restore default dtype
+       if dtype_orig is not None:
+           torch.set_default_dtype(dtype_orig)
+
        # Finalize model weight initialization
-
-
+       model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
+           model,
+           state_dict,
+           checkpoint_files,
+           pretrained_model_name_or_path,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            sharded_metadata=sharded_metadata,
            device_map=device_map,
            disk_offload_folder=offload_folder,
-           offload_buffers=offload_buffers,
            dtype=dtype,
            hf_quantizer=hf_quantizer,
            device_mesh=device_mesh,
            weights_only=weights_only,
            weight_mapping=weight_conversions,
-           use_safetensors=use_safetensors,
-           download_kwargs=download_kwargs,
        )
-
-
-       model.eval()  # Set model in evaluation mode to deactivate Dropout modules by default
+
+       model.eval()  # Set model in evaluation mode to deactivate DropOut modules by default
        model.set_use_kernels(use_kernels, kernel_config)

        # If it is a model with generation capabilities, attempt to load generation files (generation config,
        # custom generate function)
-       if model.can_generate() and hasattr(model, "adjust_generation_fn")
+       if model.can_generate() and hasattr(model, "adjust_generation_fn"):
            model.adjust_generation_fn(
                generation_config,
                from_auto_class,
@@ -4124,34 +4030,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
                **kwargs,
            )

-       #
-       if device_map is not None and
-           accelerate_dispatch(
-               model, hf_quantizer, device_map, offload_folder, load_info.disk_offload_index, offload_buffers
-           )
+       # for device_map="auto" : dispatch model with hooks on all devices if necessary
+       if device_map is not None and device_mesh is None:
+           accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)

        if hf_quantizer is not None:
            model.hf_quantizer = hf_quantizer
-           hf_quantizer.postprocess_model(
-               model
-           )  # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
+           hf_quantizer.postprocess_model(model, config=config)  # usually a no-op but sometimes needed

        if _adapter_model_path is not None:
-
-
-           load_info = model.load_adapter(
+           adapter_kwargs["key_mapping"] = weight_conversions  # TODO: Dynamic weight loader for adapters
+           model.load_adapter(
                _adapter_model_path,
                adapter_name=adapter_name,
-
+               token=token,
                adapter_kwargs=adapter_kwargs,
            )

        if output_loading_info:
            loading_info = {
-               "missing_keys":
-               "unexpected_keys":
-               "mismatched_keys":
-               "error_msgs":
+               "missing_keys": missing_keys,
+               "unexpected_keys": unexpected_keys,
+               "mismatched_keys": mismatched_keys,
+               "error_msgs": error_msgs,
            }
            return model, loading_info
        return model
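The `loading_info` dict assembled above is what callers receive when they ask for loading diagnostics; a short, hedged usage example (checkpoint name is a placeholder):

from transformers import AutoModelForCausalLM

model, loading_info = AutoModelForCausalLM.from_pretrained("gpt2", output_loading_info=True)
print(loading_info["missing_keys"], loading_info["unexpected_keys"])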
@@ -4160,65 +4061,74 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    def _load_pretrained_model(
        cls,
        model: "PreTrainedModel",
-       state_dict: dict
-       checkpoint_files: list[str]
-
-
-
-
+       state_dict: Optional[dict],
+       checkpoint_files: Optional[list[str]],
+       pretrained_model_name_or_path: Optional[str],
+       ignore_mismatched_sizes: bool = False,
+       sharded_metadata: Optional[dict] = None,
+       device_map: Optional[dict] = None,
+       disk_offload_folder: Optional[str] = None,
+       dtype: Optional[torch.dtype] = None,
+       hf_quantizer: Optional[HfQuantizer] = None,
+       device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
+       weights_only: bool = True,
+       weight_mapping: Optional[Sequence[WeightConverter | WeightRenaming]] = None,
+   ):
+       is_quantized = hf_quantizer is not None
+       is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
            QuantizationMethod.HQQ,
            QuantizationMethod.QUARK,
        }

        # Model's definition arriving here is final (TP hooks added, quantized layers replaces)
        expected_keys = list(model.state_dict().keys())
-
        if logger.level >= logging.WARNING:
            verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))

        # This offload index if for params explicitly on the "disk" in the device_map
        disk_offload_index = None
        # Prepare parameters offloading if needed
-       if
+       if device_map is not None and "disk" in device_map.values():
            disk_offload_index = accelerate_disk_offload(
-
-               load_config.disk_offload_folder,
+               disk_offload_folder,
                checkpoint_files,
-
-
-
-
+               device_map,
+               expected_keys,
+               sharded_metadata,
+               dtype,
+               weight_mapping,
            )

        # Warmup cuda to load the weights much faster on devices
-       if
-           expanded_device_map = expand_device_map(
-           caching_allocator_warmup(model, expanded_device_map,
+       if device_map is not None and not is_hqq_or_quark:
+           expanded_device_map = expand_device_map(device_map, expected_keys)
+           caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
+
+       tp_plan = getattr(model, "_tp_plan", None)
        error_msgs = []

        if is_deepspeed_zero3_enabled() and not is_quantized:
            if state_dict is None:
                merged_state_dict = {}
                for ckpt_file in checkpoint_files:
-                   merged_state_dict.update(
-                       load_state_dict(ckpt_file, map_location="cpu", weights_only=load_config.weights_only)
-                   )
+                   merged_state_dict.update(load_state_dict(ckpt_file, map_location="cpu", weights_only=weights_only))
                state_dict = merged_state_dict
-           error_msgs
+           error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
            # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
-           unexpected_keys, mismatched_keys,
+           missing_keys, unexpected_keys, mismatched_keys, misc = set(), set(), set(), set()
        else:
            all_pointer = set()
-
-
-           elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None:
+           # Checkpoints are safetensors
+           if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"):
                merged_state_dict = {}
                for file in checkpoint_files:
                    file_pointer = safe_open(file, framework="pt", device="cpu")
                    all_pointer.add(file_pointer)
                    for k in file_pointer.keys():
                        merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
+           # User passed an explicit state_dict
+           elif state_dict is not None:
+               merged_state_dict = state_dict
            # Checkpoints are .bin
            elif checkpoint_files is not None:
                merged_state_dict = {}
@@ -4227,14 +4137,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
            else:
                raise ValueError("Neither a state dict nor checkpoint files were found.")

-           missing_keys, unexpected_keys, mismatched_keys, disk_offload_index,
+           missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, misc = (
                convert_and_load_state_dict_in_model(
-                   model
-
-
-                   tp_plan
-
-
+                   model,
+                   merged_state_dict,
+                   weight_mapping,
+                   tp_plan,
+                   hf_quantizer,
+                   dtype,
+                   device_map,
+                   model.dtype_plan,
+                   device_mesh,
+                   disk_offload_index,
+                   disk_offload_folder,
                )
            )
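The "don't materialize yet" branch above relies on safetensors' lazy slices: only metadata is read up front and the actual bytes are pulled from disk when a slice is indexed. A hedged sketch with a placeholder shard name:

from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        lazy = f.get_slice(key)     # no tensor materialized yet
        shape = lazy.get_shape()    # shapes/dtypes come from the header only
        tensor = lazy[:]            # only now is this entry actually read from disk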
@@ -4242,58 +4157,65 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
            for k in all_pointer:
                k.__exit__(None, None, None)

-       return LoadStateDictInfo(
-           missing_keys=missing_keys,
-           unexpected_keys=unexpected_keys,
-           mismatched_keys=mismatched_keys,
-           disk_offload_index=disk_offload_index,
-           error_msgs=error_msgs,
-           conversion_errors=conversion_errors,
-       )
-
-   @staticmethod
-   def _finalize_load_state_dict(
-       model,
-       load_config: LoadStateDictConfig,
-       load_info: LoadStateDictInfo,
-   ) -> LoadStateDictInfo:
-       # TODO @ArthurZucker this will be in a separate function to allows people not to run this
-       # for more granularity
-
        # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
        model.mark_tied_weights_as_initialized()

-       # Move missing (and potentially mismatched) keys
-       #
-
-       model.
-           missing_and_mismatched, load_config.device_map, load_config.device_mesh, load_config.hf_quantizer
-       )
+       # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
+       # loading the weights as they are not in the loaded state dict)
+       miss_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
+       model._move_missing_keys_from_meta_to_cpu(miss_and_mismatched, dtype, hf_quantizer)

-       # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `
-       model._initialize_missing_keys(
+       # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialzed` flag)
+       model._initialize_missing_keys(is_quantized)

        # Tie the weights
-       model.tie_weights(missing_keys=
+       model.tie_weights(missing_keys=missing_keys, recompute_mapping=False)

        # Adjust missing and unexpected keys
-       missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(
-
-
+       missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys)
+
+       # Post-processing for tensor parallelism
+       if device_mesh is not None:
+           # When using TP, the device map is a single device for all parameters
+           tp_device = list(device_map.values())[0]
+           # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
+           # not part of the state_dict (persistent=False)
+           for buffer in model.buffers():  # TODO to avaoid this buffer could be added to the ckpt
+               if buffer.device != tp_device:
+                   buffer.data = buffer.to(tp_device)
+
+           # In this case, the top-most task module weights were not moved to device and parallelized as they
+           # were not part of the loaded weights: do it now
+           if missing_keys:
+               state_dict = model.state_dict()
+               for name in missing_keys:
+                   param = state_dict[name]
+                   # Shard the param
+                   shard_and_distribute_module(
+                       model,
+                       param.to(tp_device),
+                       param,
+                       name,
+                       None,
+                       False,
+                       device_mesh.get_local_rank(),
+                       device_mesh,
+                   )

        log_state_dict_report(
            model=model,
-
+           pretrained_model_name_or_path=pretrained_model_name_or_path,
            logger=logger,
-           error_msgs=
+           error_msgs=error_msgs,
            unexpected_keys=unexpected_keys,
            missing_keys=missing_keys,
-           mismatched_keys=
-           mismatched_shapes=
-
+           mismatched_keys=mismatched_keys,
+           mismatched_shapes=mismatched_keys,
+           misc=misc,
+           ignore_mismatched_sizes=ignore_mismatched_sizes,
        )

-       return
+       return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs

    def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
        module_keys = {".".join(key.split(".")[:-1]) for key in names}
@@ -4362,17 +4284,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

        # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
        # attention_mask or not. In this case, we should still show a warning because this is a rare case.
-       # NOTE: `sep_token_id` is not used in all models and it can be absent in the config
-       sep_token_id = getattr(self.config, "sep_token_id", None)
        if (
            (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
            or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
-           or (sep_token_id is not None and sep_token_id == self.config.pad_token_id)
+           or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
        ):
            warn_string += (
                f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
                f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
-               f"or the `sep_token_id` ({sep_token_id}), and your input is not padded."
+               f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
            )

        logger.warning_once(warn_string)
@@ -4457,7 +4377,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
        )
        self._use_kernels = False

-   def get_compiled_call(self, compile_config: CompileConfig
+   def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
        """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
        non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
        want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
@@ -4479,54 +4399,33 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    def is_backend_compatible(cls):
        return cls._supports_attention_backend

-   def
-       self,
-       missing_keys: list[str],
-       device_map: dict | None,
-       device_mesh: "torch.distributed.device_mesh.DeviceMesh | None",
-       hf_quantizer: HfQuantizer | None,
+   def _move_missing_keys_from_meta_to_cpu(
+       self, missing_keys: list[str], dtype: torch.dtype, hf_quantizer: Optional[HfQuantizer]
    ) -> None:
-       """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts)
-
-       missing parameters if `device_mesh` is provided, i.e. we are using TP.
-       All non-persistent buffers are also moved back to the correct device (they are not part of the state_dict, but are
-       not missing either).
+       """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts) back
+       from meta device to cpu.
        """
        is_quantized = hf_quantizer is not None
-       # This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
-       if is_deepspeed_zero3_enabled() and not is_quantized:
-           return

        # In this case we need to move everything back
        if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
+           # We only do it for the parameters, as the buffers are not initialized on the meta device by default
            for key, param in self.named_parameters():
-               value = torch.empty_like(param, device="cpu")
-               _load_parameter_into_model(self, key, value)
-           for key, buffer in self.named_buffers():
-               value = torch.empty_like(buffer, device="cpu")
+               value = torch.empty_like(param, dtype=dtype, device="cpu")
                _load_parameter_into_model(self, key, value)
            return

+       model_state_dict = self.state_dict()
        # The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway)
        # This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they
        # will be re-initialized for nothing (which can be quite long)
        for key in missing_keys - self.all_tied_weights_keys.keys():
-           param =
-
-
-
-
-
-               self, value, param, key, None, False, device_mesh.get_local_rank(), device_mesh
-           )
-           # Otherwise, just move it to device
-           else:
-               _load_parameter_into_model(self, key, value)
-       # We need to move back non-persistent buffers as well, as they are not part of loaded weights anyway
-       for key, buffer in self.named_non_persistent_buffers():
-           buffer_device = get_device(device_map, key, valid_torch_device=True)
-           value = torch.empty_like(buffer, device=buffer_device)
-           _load_parameter_into_model(self, key, value)
+           param = model_state_dict[key]
+           # Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
+           if param.device == torch.device("meta"):
+               value = torch.empty_like(param, dtype=dtype, device="cpu")
+               if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
+                   _load_parameter_into_model(self, key, value)

    def _initialize_missing_keys(self, is_quantized: bool) -> None:
        """
@@ -4554,6 +4453,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
    ) -> tuple[set[str], set[str]]:
        """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
        raising unneeded warnings/errors.
+       Also, set the `_is_hf_initialized` on tied weight keys, to avoid initializing them as they are going to
+       be tied anyway.
        """
        # Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model
        # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
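The meta-to-cpu move above replaces a storage-less parameter with an uninitialized cpu tensor of the same shape and the requested dtype. A minimal, self-contained sketch of the same idea in plain PyTorch:

import torch
from torch import nn

with torch.device("meta"):
    layer = nn.Linear(8, 8)              # weights exist only as shapes, no storage
# A key that was never loaded gets re-created as an empty cpu tensor of the right dtype...
materialized = torch.empty_like(layer.weight, dtype=torch.float32, device="cpu")
layer.weight = nn.Parameter(materialized)
# ...and is then filled by the model's `_init_weights`-style initialization afterwards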
@@ -4612,19 +4513,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

        raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")

-   def named_non_persistent_buffers(
-       self, recurse: bool = True, remove_duplicate: bool = True
-   ) -> Iterator[tuple[str, torch.Tensor]]:
-       """Similar to `named_buffers`, but only yield non-persistent ones. It is handy as it's not perfectly straightforward
-       to know if they are persistent or not"""
-       for name, tensor in self.named_buffers(recurse=recurse, remove_duplicate=remove_duplicate):
-           # We have to grab the parent here, as the attribute `_non_persistent_buffers_set` is on the immediate
-           # parent only
-           parent, buf_name = name.rsplit(".", 1) if "." in name else ("", name)
-           parent = self.get_submodule(parent)
-           if buf_name in parent._non_persistent_buffers_set:
-               yield name, tensor
-
    def train(self, mode: bool = True):
        out = super().train(mode)
        if self.use_kernels:
@@ -4667,7 +4555,7 @@ def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
        return model


-def is_accelerator_device(device: str
+def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
    """Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
    a proper `torch.device`.
    """
@@ -4677,41 +4565,7 @@ def is_accelerator_device(device: str | int | torch.device) -> bool:
    return torch.device(device).type not in ["meta", "cpu"]


-def
-   model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: HfQuantizer | None = None
-):
-   """
-   This utility function calculates the total bytes count needed to load the model on each device.
-   This is useful for caching_allocator_warmup as we want to know how much cache we need to pre-allocate.
-   """
-
-   total_byte_count = defaultdict(lambda: 0)
-   tied_param_names = model.all_tied_weights_keys.keys()
-   tp_plan = model._tp_plan if torch.distributed.is_available() and torch.distributed.is_initialized() else []
-
-   for param_name, device in accelerator_device_map.items():
-       # Skip if the parameter has already been accounted for (tied weights)
-       if param_name in tied_param_names:
-           continue
-
-       param = model.get_parameter_or_buffer(param_name)
-
-       if hf_quantizer is not None:
-           dtype_size = hf_quantizer.param_element_size(model, param_name, param)
-       else:
-           dtype_size = param.element_size()
-
-       param_byte_count = param.numel() * dtype_size
-
-       if len(tp_plan) > 0:
-           is_part_of_plan = _get_parameter_tp_plan(param_name, tp_plan, is_weight=True) is not None
-           param_byte_count //= torch.distributed.get_world_size() if is_part_of_plan else 1
-
-       total_byte_count[device] += param_byte_count
-   return total_byte_count
-
-
-def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: HfQuantizer | None):
+def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: Optional[HfQuantizer]):
    """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
    device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
    the model, which is actually the loading speed bottleneck.
@@ -4730,6 +4584,8 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
    - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
    However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
    """
+   factor = 2 if hf_quantizer is None else hf_quantizer.get_accelerator_warm_up_factor()
+
    # Remove disk, cpu and meta devices, and cast to proper torch.device
    accelerator_device_map = {
        param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
@@ -4737,7 +4593,40 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
    if not accelerator_device_map:
        return

-
+   tp_plan = getattr(model, "_tp_plan", []) or []
+   tp_plan_regex = (
+       re.compile("|".join([re.escape(plan) for plan in tp_plan]))
+       if _torch_distributed_available and torch.distributed.is_initialized()
+       else None
+   )
+   total_byte_count = defaultdict(lambda: 0)
+   tied_param_names = model.all_tied_weights_keys.keys()
+   for param_name, device in accelerator_device_map.items():
+       # Skip if the parameter has already been accounted for (tied weights)
+       if param_name in tied_param_names:
+           continue
+
+       # For example in the case of MXFP4 quantization, we need to update the param name to the original param name
+       # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
+       if hf_quantizer is not None:
+           param_name = hf_quantizer.get_param_name(param_name)
+
+       try:
+           param = model.get_parameter_or_buffer(param_name)
+       except AttributeError:
+           # TODO: for now let's skip if we can't find the parameters
+           if hf_quantizer is not None:
+               continue
+           raise AttributeError(f"Parameter {param_name} not found in model")
+
+       # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+       param_byte_count = param.numel() * param.element_size()
+
+       if tp_plan_regex is not None:
+           generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
+           param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
+
+       total_byte_count[device] += param_byte_count

    # This will kick off the caching allocator to avoid having to Malloc afterwards
    for device, byte_count in total_byte_count.items():
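The warm-up amounts to counting bytes per accelerator and triggering one large allocation each, so later per-tensor copies hit an already-grown caching allocator. A hedged, stand-alone sketch of that pattern (the helper name and the division factor are illustrative, not the library's API):

import torch
from collections import defaultdict

def warmup_allocator(expanded_device_map: dict[str, torch.device], model: torch.nn.Module, factor: int = 2) -> None:
    # Sum parameter bytes per device, then do a single large empty() allocation on each
    byte_count: dict[torch.device, int] = defaultdict(int)
    params = dict(model.named_parameters())
    for name, device in expanded_device_map.items():
        if name in params:
            p = params[name]
            byte_count[device] += p.numel() * p.element_size()
    for device, n_bytes in byte_count.items():
        _ = torch.empty(n_bytes // factor, dtype=torch.float16, device=device, requires_grad=False)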
@@ -4757,9 +4646,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
        unused_memory = torch_accelerator_module.memory_reserved(
            index
        ) - torch_accelerator_module.memory_allocated(index)
-       byte_count =
-       #
-       _ = torch.empty(byte_count //
+       byte_count = max(0, byte_count - unused_memory)
+       # Allocate memory
+       _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)


class AttentionInterface(GeneralInterface):