transformers 5.0.0rc2-py3-none-any.whl → 5.1.0-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in those registries.
- transformers/__init__.py +11 -37
- transformers/activations.py +2 -2
- transformers/audio_utils.py +32 -32
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +26 -126
- transformers/cli/chat.py +3 -3
- transformers/cli/serve.py +13 -10
- transformers/cli/transformers.py +2 -1
- transformers/configuration_utils.py +22 -92
- transformers/conversion_mapping.py +150 -26
- transformers/convert_slow_tokenizer.py +9 -12
- transformers/core_model_loading.py +217 -129
- transformers/data/processors/glue.py +0 -1
- transformers/data/processors/utils.py +0 -1
- transformers/data/processors/xnli.py +0 -1
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +10 -11
- transformers/distributed/configuration_utils.py +1 -2
- transformers/dynamic_module_utils.py +23 -23
- transformers/feature_extraction_sequence_utils.py +19 -23
- transformers/feature_extraction_utils.py +14 -14
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +2 -4
- transformers/generation/configuration_utils.py +54 -39
- transformers/generation/continuous_batching/__init__.py +0 -1
- transformers/generation/continuous_batching/cache.py +74 -44
- transformers/generation/continuous_batching/cache_manager.py +28 -28
- transformers/generation/continuous_batching/continuous_api.py +133 -414
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +77 -19
- transformers/generation/continuous_batching/scheduler.py +154 -104
- transformers/generation/logits_process.py +10 -133
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/streamers.py +0 -1
- transformers/generation/utils.py +91 -121
- transformers/generation/watermarking.py +2 -3
- transformers/hf_argparser.py +9 -13
- transformers/hyperparameter_search.py +1 -2
- transformers/image_processing_base.py +9 -9
- transformers/image_processing_utils.py +11 -15
- transformers/image_processing_utils_fast.py +70 -71
- transformers/image_transforms.py +73 -42
- transformers/image_utils.py +30 -37
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/awq.py +1 -3
- transformers/integrations/deepspeed.py +146 -4
- transformers/integrations/eetq.py +0 -1
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/fbgemm_fp8.py +1 -2
- transformers/integrations/finegrained_fp8.py +149 -13
- transformers/integrations/flash_attention.py +3 -8
- transformers/integrations/flex_attention.py +1 -1
- transformers/integrations/fp_quant.py +4 -6
- transformers/integrations/ggml.py +0 -1
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/integration_utils.py +2 -3
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +52 -40
- transformers/integrations/peft.py +488 -176
- transformers/integrations/quark.py +2 -4
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/integrations/torchao.py +4 -6
- transformers/loss/loss_lw_detr.py +356 -0
- transformers/loss/loss_utils.py +2 -0
- transformers/masking_utils.py +199 -59
- transformers/model_debugging_utils.py +4 -5
- transformers/modelcard.py +14 -192
- transformers/modeling_attn_mask_utils.py +19 -19
- transformers/modeling_flash_attention_utils.py +28 -29
- transformers/modeling_gguf_pytorch_utils.py +5 -5
- transformers/modeling_layers.py +21 -22
- transformers/modeling_outputs.py +242 -253
- transformers/modeling_rope_utils.py +32 -32
- transformers/modeling_utils.py +416 -438
- transformers/models/__init__.py +10 -0
- transformers/models/afmoe/configuration_afmoe.py +40 -33
- transformers/models/afmoe/modeling_afmoe.py +38 -41
- transformers/models/afmoe/modular_afmoe.py +23 -25
- transformers/models/aimv2/configuration_aimv2.py +2 -10
- transformers/models/aimv2/modeling_aimv2.py +46 -45
- transformers/models/aimv2/modular_aimv2.py +13 -19
- transformers/models/albert/configuration_albert.py +8 -2
- transformers/models/albert/modeling_albert.py +70 -72
- transformers/models/albert/tokenization_albert.py +1 -4
- transformers/models/align/configuration_align.py +8 -6
- transformers/models/align/modeling_align.py +83 -86
- transformers/models/align/processing_align.py +2 -30
- transformers/models/altclip/configuration_altclip.py +4 -7
- transformers/models/altclip/modeling_altclip.py +106 -103
- transformers/models/altclip/processing_altclip.py +2 -15
- transformers/models/apertus/__init__.py +0 -1
- transformers/models/apertus/configuration_apertus.py +23 -28
- transformers/models/apertus/modeling_apertus.py +35 -38
- transformers/models/apertus/modular_apertus.py +36 -40
- transformers/models/arcee/configuration_arcee.py +25 -30
- transformers/models/arcee/modeling_arcee.py +35 -38
- transformers/models/arcee/modular_arcee.py +20 -23
- transformers/models/aria/configuration_aria.py +31 -44
- transformers/models/aria/image_processing_aria.py +25 -27
- transformers/models/aria/modeling_aria.py +102 -102
- transformers/models/aria/modular_aria.py +111 -124
- transformers/models/aria/processing_aria.py +28 -35
- transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
- transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
- transformers/models/audioflamingo3/__init__.py +0 -1
- transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
- transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
- transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
- transformers/models/auto/auto_factory.py +12 -11
- transformers/models/auto/configuration_auto.py +48 -5
- transformers/models/auto/feature_extraction_auto.py +5 -7
- transformers/models/auto/image_processing_auto.py +30 -39
- transformers/models/auto/modeling_auto.py +33 -199
- transformers/models/auto/processing_auto.py +11 -19
- transformers/models/auto/tokenization_auto.py +38 -37
- transformers/models/auto/video_processing_auto.py +7 -8
- transformers/models/autoformer/configuration_autoformer.py +4 -7
- transformers/models/autoformer/modeling_autoformer.py +100 -101
- transformers/models/aya_vision/configuration_aya_vision.py +4 -1
- transformers/models/aya_vision/modeling_aya_vision.py +64 -99
- transformers/models/aya_vision/modular_aya_vision.py +46 -74
- transformers/models/aya_vision/processing_aya_vision.py +25 -53
- transformers/models/bamba/configuration_bamba.py +46 -39
- transformers/models/bamba/modeling_bamba.py +83 -119
- transformers/models/bamba/modular_bamba.py +70 -109
- transformers/models/bark/configuration_bark.py +6 -8
- transformers/models/bark/generation_configuration_bark.py +3 -5
- transformers/models/bark/modeling_bark.py +64 -65
- transformers/models/bark/processing_bark.py +19 -41
- transformers/models/bart/configuration_bart.py +9 -5
- transformers/models/bart/modeling_bart.py +124 -129
- transformers/models/barthez/tokenization_barthez.py +1 -4
- transformers/models/bartpho/tokenization_bartpho.py +6 -7
- transformers/models/beit/configuration_beit.py +2 -15
- transformers/models/beit/image_processing_beit.py +53 -56
- transformers/models/beit/image_processing_beit_fast.py +11 -12
- transformers/models/beit/modeling_beit.py +65 -62
- transformers/models/bert/configuration_bert.py +12 -2
- transformers/models/bert/modeling_bert.py +117 -152
- transformers/models/bert/tokenization_bert.py +2 -4
- transformers/models/bert/tokenization_bert_legacy.py +3 -5
- transformers/models/bert_generation/configuration_bert_generation.py +17 -2
- transformers/models/bert_generation/modeling_bert_generation.py +53 -55
- transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
- transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
- transformers/models/bertweet/tokenization_bertweet.py +1 -3
- transformers/models/big_bird/configuration_big_bird.py +12 -9
- transformers/models/big_bird/modeling_big_bird.py +107 -124
- transformers/models/big_bird/tokenization_big_bird.py +1 -4
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
- transformers/models/biogpt/configuration_biogpt.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +73 -79
- transformers/models/biogpt/modular_biogpt.py +60 -66
- transformers/models/biogpt/tokenization_biogpt.py +3 -5
- transformers/models/bit/configuration_bit.py +2 -5
- transformers/models/bit/image_processing_bit.py +21 -24
- transformers/models/bit/image_processing_bit_fast.py +0 -1
- transformers/models/bit/modeling_bit.py +15 -16
- transformers/models/bitnet/configuration_bitnet.py +23 -28
- transformers/models/bitnet/modeling_bitnet.py +34 -38
- transformers/models/bitnet/modular_bitnet.py +7 -10
- transformers/models/blenderbot/configuration_blenderbot.py +8 -5
- transformers/models/blenderbot/modeling_blenderbot.py +68 -99
- transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
- transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
- transformers/models/blip/configuration_blip.py +9 -10
- transformers/models/blip/image_processing_blip.py +17 -20
- transformers/models/blip/image_processing_blip_fast.py +0 -1
- transformers/models/blip/modeling_blip.py +115 -108
- transformers/models/blip/modeling_blip_text.py +63 -65
- transformers/models/blip/processing_blip.py +5 -36
- transformers/models/blip_2/configuration_blip_2.py +2 -2
- transformers/models/blip_2/modeling_blip_2.py +145 -121
- transformers/models/blip_2/processing_blip_2.py +8 -38
- transformers/models/bloom/configuration_bloom.py +5 -2
- transformers/models/bloom/modeling_bloom.py +60 -60
- transformers/models/blt/configuration_blt.py +94 -86
- transformers/models/blt/modeling_blt.py +93 -90
- transformers/models/blt/modular_blt.py +127 -69
- transformers/models/bridgetower/configuration_bridgetower.py +7 -2
- transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
- transformers/models/bridgetower/modeling_bridgetower.py +136 -124
- transformers/models/bridgetower/processing_bridgetower.py +2 -16
- transformers/models/bros/configuration_bros.py +24 -18
- transformers/models/bros/modeling_bros.py +78 -80
- transformers/models/bros/processing_bros.py +2 -12
- transformers/models/byt5/tokenization_byt5.py +4 -6
- transformers/models/camembert/configuration_camembert.py +8 -2
- transformers/models/camembert/modeling_camembert.py +97 -99
- transformers/models/camembert/modular_camembert.py +51 -54
- transformers/models/camembert/tokenization_camembert.py +1 -4
- transformers/models/canine/configuration_canine.py +4 -2
- transformers/models/canine/modeling_canine.py +73 -75
- transformers/models/canine/tokenization_canine.py +0 -1
- transformers/models/chameleon/configuration_chameleon.py +29 -34
- transformers/models/chameleon/image_processing_chameleon.py +21 -24
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
- transformers/models/chameleon/modeling_chameleon.py +135 -92
- transformers/models/chameleon/processing_chameleon.py +16 -41
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
- transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
- transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
- transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
- transformers/models/clap/configuration_clap.py +4 -9
- transformers/models/clap/feature_extraction_clap.py +9 -10
- transformers/models/clap/modeling_clap.py +109 -111
- transformers/models/clap/processing_clap.py +2 -15
- transformers/models/clip/configuration_clip.py +4 -2
- transformers/models/clip/image_processing_clip.py +21 -24
- transformers/models/clip/image_processing_clip_fast.py +9 -1
- transformers/models/clip/modeling_clip.py +70 -68
- transformers/models/clip/processing_clip.py +2 -14
- transformers/models/clip/tokenization_clip.py +2 -5
- transformers/models/clipseg/configuration_clipseg.py +4 -2
- transformers/models/clipseg/modeling_clipseg.py +113 -112
- transformers/models/clipseg/processing_clipseg.py +19 -42
- transformers/models/clvp/configuration_clvp.py +15 -5
- transformers/models/clvp/feature_extraction_clvp.py +7 -10
- transformers/models/clvp/modeling_clvp.py +138 -145
- transformers/models/clvp/number_normalizer.py +1 -2
- transformers/models/clvp/processing_clvp.py +3 -20
- transformers/models/clvp/tokenization_clvp.py +0 -1
- transformers/models/code_llama/tokenization_code_llama.py +3 -6
- transformers/models/codegen/configuration_codegen.py +4 -4
- transformers/models/codegen/modeling_codegen.py +50 -49
- transformers/models/codegen/tokenization_codegen.py +5 -6
- transformers/models/cohere/configuration_cohere.py +25 -30
- transformers/models/cohere/modeling_cohere.py +39 -42
- transformers/models/cohere/modular_cohere.py +27 -31
- transformers/models/cohere/tokenization_cohere.py +5 -6
- transformers/models/cohere2/configuration_cohere2.py +27 -32
- transformers/models/cohere2/modeling_cohere2.py +38 -41
- transformers/models/cohere2/modular_cohere2.py +48 -52
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
- transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
- transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
- transformers/models/colpali/configuration_colpali.py +0 -1
- transformers/models/colpali/modeling_colpali.py +14 -16
- transformers/models/colpali/modular_colpali.py +11 -51
- transformers/models/colpali/processing_colpali.py +14 -52
- transformers/models/colqwen2/modeling_colqwen2.py +27 -28
- transformers/models/colqwen2/modular_colqwen2.py +36 -74
- transformers/models/colqwen2/processing_colqwen2.py +16 -52
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
- transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
- transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
- transformers/models/convbert/configuration_convbert.py +11 -8
- transformers/models/convbert/modeling_convbert.py +85 -87
- transformers/models/convbert/tokenization_convbert.py +0 -1
- transformers/models/convnext/configuration_convnext.py +2 -5
- transformers/models/convnext/image_processing_convnext.py +18 -21
- transformers/models/convnext/image_processing_convnext_fast.py +7 -8
- transformers/models/convnext/modeling_convnext.py +12 -14
- transformers/models/convnextv2/configuration_convnextv2.py +2 -5
- transformers/models/convnextv2/modeling_convnextv2.py +12 -14
- transformers/models/cpm/tokenization_cpm.py +6 -7
- transformers/models/cpm/tokenization_cpm_fast.py +3 -5
- transformers/models/cpmant/configuration_cpmant.py +4 -1
- transformers/models/cpmant/modeling_cpmant.py +38 -40
- transformers/models/cpmant/tokenization_cpmant.py +1 -3
- transformers/models/csm/configuration_csm.py +58 -66
- transformers/models/csm/generation_csm.py +13 -14
- transformers/models/csm/modeling_csm.py +81 -84
- transformers/models/csm/modular_csm.py +56 -58
- transformers/models/csm/processing_csm.py +25 -68
- transformers/models/ctrl/configuration_ctrl.py +16 -1
- transformers/models/ctrl/modeling_ctrl.py +51 -66
- transformers/models/ctrl/tokenization_ctrl.py +0 -1
- transformers/models/cvt/configuration_cvt.py +0 -1
- transformers/models/cvt/modeling_cvt.py +13 -15
- transformers/models/cwm/__init__.py +0 -1
- transformers/models/cwm/configuration_cwm.py +8 -12
- transformers/models/cwm/modeling_cwm.py +36 -38
- transformers/models/cwm/modular_cwm.py +10 -12
- transformers/models/d_fine/configuration_d_fine.py +10 -57
- transformers/models/d_fine/modeling_d_fine.py +786 -927
- transformers/models/d_fine/modular_d_fine.py +339 -417
- transformers/models/dab_detr/configuration_dab_detr.py +22 -49
- transformers/models/dab_detr/modeling_dab_detr.py +79 -77
- transformers/models/dac/configuration_dac.py +0 -1
- transformers/models/dac/feature_extraction_dac.py +6 -9
- transformers/models/dac/modeling_dac.py +22 -24
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
- transformers/models/data2vec/configuration_data2vec_text.py +11 -3
- transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
- transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
- transformers/models/data2vec/modeling_data2vec_text.py +97 -99
- transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
- transformers/models/data2vec/modular_data2vec_audio.py +6 -1
- transformers/models/data2vec/modular_data2vec_text.py +51 -54
- transformers/models/dbrx/configuration_dbrx.py +29 -22
- transformers/models/dbrx/modeling_dbrx.py +45 -48
- transformers/models/dbrx/modular_dbrx.py +37 -39
- transformers/models/deberta/configuration_deberta.py +6 -1
- transformers/models/deberta/modeling_deberta.py +57 -60
- transformers/models/deberta/tokenization_deberta.py +2 -5
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
- transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
- transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
- transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
- transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
- transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
- transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
- transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
- transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
- transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
- transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
- transformers/models/deit/configuration_deit.py +0 -1
- transformers/models/deit/image_processing_deit.py +18 -21
- transformers/models/deit/image_processing_deit_fast.py +0 -1
- transformers/models/deit/modeling_deit.py +27 -25
- transformers/models/depth_anything/configuration_depth_anything.py +12 -43
- transformers/models/depth_anything/modeling_depth_anything.py +10 -11
- transformers/models/depth_pro/configuration_depth_pro.py +0 -1
- transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
- transformers/models/depth_pro/modeling_depth_pro.py +29 -27
- transformers/models/detr/configuration_detr.py +18 -50
- transformers/models/detr/image_processing_detr.py +64 -66
- transformers/models/detr/image_processing_detr_fast.py +33 -34
- transformers/models/detr/modeling_detr.py +748 -789
- transformers/models/dia/configuration_dia.py +9 -15
- transformers/models/dia/feature_extraction_dia.py +6 -9
- transformers/models/dia/generation_dia.py +48 -53
- transformers/models/dia/modeling_dia.py +68 -71
- transformers/models/dia/modular_dia.py +56 -58
- transformers/models/dia/processing_dia.py +39 -29
- transformers/models/dia/tokenization_dia.py +3 -6
- transformers/models/diffllama/configuration_diffllama.py +25 -30
- transformers/models/diffllama/modeling_diffllama.py +45 -53
- transformers/models/diffllama/modular_diffllama.py +18 -25
- transformers/models/dinat/configuration_dinat.py +2 -5
- transformers/models/dinat/modeling_dinat.py +47 -48
- transformers/models/dinov2/configuration_dinov2.py +2 -5
- transformers/models/dinov2/modeling_dinov2.py +20 -21
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
- transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
- transformers/models/distilbert/configuration_distilbert.py +8 -2
- transformers/models/distilbert/modeling_distilbert.py +47 -49
- transformers/models/distilbert/tokenization_distilbert.py +0 -1
- transformers/models/doge/__init__.py +0 -1
- transformers/models/doge/configuration_doge.py +42 -35
- transformers/models/doge/modeling_doge.py +46 -49
- transformers/models/doge/modular_doge.py +77 -68
- transformers/models/donut/configuration_donut_swin.py +0 -1
- transformers/models/donut/image_processing_donut.py +26 -29
- transformers/models/donut/image_processing_donut_fast.py +9 -14
- transformers/models/donut/modeling_donut_swin.py +44 -46
- transformers/models/donut/processing_donut.py +5 -26
- transformers/models/dots1/configuration_dots1.py +43 -36
- transformers/models/dots1/modeling_dots1.py +35 -38
- transformers/models/dots1/modular_dots1.py +0 -1
- transformers/models/dpr/configuration_dpr.py +19 -2
- transformers/models/dpr/modeling_dpr.py +37 -39
- transformers/models/dpr/tokenization_dpr.py +7 -9
- transformers/models/dpr/tokenization_dpr_fast.py +7 -9
- transformers/models/dpt/configuration_dpt.py +23 -66
- transformers/models/dpt/image_processing_dpt.py +65 -66
- transformers/models/dpt/image_processing_dpt_fast.py +18 -19
- transformers/models/dpt/modeling_dpt.py +38 -36
- transformers/models/dpt/modular_dpt.py +14 -15
- transformers/models/edgetam/configuration_edgetam.py +1 -2
- transformers/models/edgetam/modeling_edgetam.py +87 -89
- transformers/models/edgetam/modular_edgetam.py +7 -13
- transformers/models/edgetam_video/__init__.py +0 -1
- transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
- transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
- transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
- transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
- transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
- transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
- transformers/models/efficientnet/configuration_efficientnet.py +0 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
- transformers/models/efficientnet/modeling_efficientnet.py +12 -14
- transformers/models/electra/configuration_electra.py +13 -3
- transformers/models/electra/modeling_electra.py +107 -109
- transformers/models/emu3/configuration_emu3.py +17 -17
- transformers/models/emu3/image_processing_emu3.py +44 -39
- transformers/models/emu3/modeling_emu3.py +143 -109
- transformers/models/emu3/modular_emu3.py +109 -73
- transformers/models/emu3/processing_emu3.py +18 -43
- transformers/models/encodec/configuration_encodec.py +2 -4
- transformers/models/encodec/feature_extraction_encodec.py +10 -13
- transformers/models/encodec/modeling_encodec.py +25 -29
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
- transformers/models/eomt/configuration_eomt.py +12 -14
- transformers/models/eomt/image_processing_eomt.py +53 -55
- transformers/models/eomt/image_processing_eomt_fast.py +18 -19
- transformers/models/eomt/modeling_eomt.py +19 -21
- transformers/models/eomt/modular_eomt.py +28 -30
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -3
- transformers/models/ernie/modeling_ernie.py +127 -162
- transformers/models/ernie/modular_ernie.py +91 -103
- transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
- transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
- transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
- transformers/models/esm/configuration_esm.py +11 -15
- transformers/models/esm/modeling_esm.py +35 -37
- transformers/models/esm/modeling_esmfold.py +43 -50
- transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
- transformers/models/esm/openfold_utils/loss.py +1 -2
- transformers/models/esm/openfold_utils/protein.py +15 -16
- transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
- transformers/models/esm/tokenization_esm.py +2 -4
- transformers/models/evolla/configuration_evolla.py +50 -40
- transformers/models/evolla/modeling_evolla.py +69 -68
- transformers/models/evolla/modular_evolla.py +50 -48
- transformers/models/evolla/processing_evolla.py +23 -35
- transformers/models/exaone4/configuration_exaone4.py +27 -27
- transformers/models/exaone4/modeling_exaone4.py +36 -39
- transformers/models/exaone4/modular_exaone4.py +51 -50
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +31 -26
- transformers/models/falcon/modeling_falcon.py +76 -84
- transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
- transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
- transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
- transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
- transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
- transformers/models/flaubert/configuration_flaubert.py +10 -5
- transformers/models/flaubert/modeling_flaubert.py +125 -129
- transformers/models/flaubert/tokenization_flaubert.py +3 -5
- transformers/models/flava/configuration_flava.py +9 -9
- transformers/models/flava/image_processing_flava.py +66 -67
- transformers/models/flava/image_processing_flava_fast.py +46 -47
- transformers/models/flava/modeling_flava.py +144 -135
- transformers/models/flava/processing_flava.py +2 -12
- transformers/models/flex_olmo/__init__.py +0 -1
- transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
- transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
- transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
- transformers/models/florence2/configuration_florence2.py +4 -1
- transformers/models/florence2/modeling_florence2.py +96 -72
- transformers/models/florence2/modular_florence2.py +100 -107
- transformers/models/florence2/processing_florence2.py +18 -47
- transformers/models/fnet/configuration_fnet.py +6 -2
- transformers/models/fnet/modeling_fnet.py +69 -80
- transformers/models/fnet/tokenization_fnet.py +0 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -5
- transformers/models/focalnet/modeling_focalnet.py +49 -48
- transformers/models/fsmt/configuration_fsmt.py +12 -17
- transformers/models/fsmt/modeling_fsmt.py +47 -48
- transformers/models/fsmt/tokenization_fsmt.py +3 -5
- transformers/models/funnel/configuration_funnel.py +8 -1
- transformers/models/funnel/modeling_funnel.py +91 -93
- transformers/models/funnel/tokenization_funnel.py +2 -5
- transformers/models/fuyu/configuration_fuyu.py +28 -34
- transformers/models/fuyu/image_processing_fuyu.py +29 -31
- transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
- transformers/models/fuyu/modeling_fuyu.py +50 -52
- transformers/models/fuyu/processing_fuyu.py +9 -36
- transformers/models/gemma/configuration_gemma.py +25 -30
- transformers/models/gemma/modeling_gemma.py +36 -38
- transformers/models/gemma/modular_gemma.py +33 -36
- transformers/models/gemma/tokenization_gemma.py +3 -6
- transformers/models/gemma2/configuration_gemma2.py +30 -35
- transformers/models/gemma2/modeling_gemma2.py +38 -41
- transformers/models/gemma2/modular_gemma2.py +63 -67
- transformers/models/gemma3/configuration_gemma3.py +53 -48
- transformers/models/gemma3/image_processing_gemma3.py +29 -31
- transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
- transformers/models/gemma3/modeling_gemma3.py +123 -122
- transformers/models/gemma3/modular_gemma3.py +128 -125
- transformers/models/gemma3/processing_gemma3.py +5 -5
- transformers/models/gemma3n/configuration_gemma3n.py +42 -30
- transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
- transformers/models/gemma3n/modeling_gemma3n.py +166 -147
- transformers/models/gemma3n/modular_gemma3n.py +176 -148
- transformers/models/gemma3n/processing_gemma3n.py +12 -26
- transformers/models/git/configuration_git.py +5 -8
- transformers/models/git/modeling_git.py +115 -127
- transformers/models/git/processing_git.py +2 -14
- transformers/models/glm/configuration_glm.py +26 -30
- transformers/models/glm/modeling_glm.py +36 -39
- transformers/models/glm/modular_glm.py +4 -7
- transformers/models/glm4/configuration_glm4.py +26 -30
- transformers/models/glm4/modeling_glm4.py +39 -41
- transformers/models/glm4/modular_glm4.py +8 -10
- transformers/models/glm46v/configuration_glm46v.py +4 -1
- transformers/models/glm46v/image_processing_glm46v.py +40 -38
- transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
- transformers/models/glm46v/modeling_glm46v.py +138 -93
- transformers/models/glm46v/modular_glm46v.py +5 -3
- transformers/models/glm46v/processing_glm46v.py +7 -41
- transformers/models/glm46v/video_processing_glm46v.py +9 -11
- transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
- transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
- transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
- transformers/models/glm4_moe_lite/__init__.py +28 -0
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
- transformers/models/glm4v/configuration_glm4v.py +25 -24
- transformers/models/glm4v/image_processing_glm4v.py +39 -38
- transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
- transformers/models/glm4v/modeling_glm4v.py +249 -210
- transformers/models/glm4v/modular_glm4v.py +211 -230
- transformers/models/glm4v/processing_glm4v.py +7 -41
- transformers/models/glm4v/video_processing_glm4v.py +9 -11
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
- transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
- transformers/models/glm_image/__init__.py +31 -0
- transformers/models/glm_image/configuration_glm_image.py +358 -0
- transformers/models/glm_image/image_processing_glm_image.py +503 -0
- transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
- transformers/models/glm_image/modeling_glm_image.py +1691 -0
- transformers/models/glm_image/modular_glm_image.py +1640 -0
- transformers/models/glm_image/processing_glm_image.py +265 -0
- transformers/models/glm_ocr/__init__.py +28 -0
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/__init__.py +0 -1
- transformers/models/glmasr/configuration_glmasr.py +0 -1
- transformers/models/glmasr/modeling_glmasr.py +51 -46
- transformers/models/glmasr/modular_glmasr.py +39 -29
- transformers/models/glmasr/processing_glmasr.py +7 -8
- transformers/models/glpn/configuration_glpn.py +0 -1
- transformers/models/glpn/image_processing_glpn.py +11 -12
- transformers/models/glpn/image_processing_glpn_fast.py +11 -12
- transformers/models/glpn/modeling_glpn.py +14 -14
- transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
- transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
- transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
- transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
- transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
- transformers/models/gpt2/configuration_gpt2.py +13 -2
- transformers/models/gpt2/modeling_gpt2.py +111 -113
- transformers/models/gpt2/tokenization_gpt2.py +6 -9
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
- transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
- transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
- transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
- transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
- transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
- transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
- transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
- transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
- transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
- transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
- transformers/models/gptj/configuration_gptj.py +4 -5
- transformers/models/gptj/modeling_gptj.py +85 -88
- transformers/models/granite/configuration_granite.py +28 -33
- transformers/models/granite/modeling_granite.py +43 -45
- transformers/models/granite/modular_granite.py +29 -31
- transformers/models/granite_speech/configuration_granite_speech.py +0 -1
- transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
- transformers/models/granite_speech/modeling_granite_speech.py +84 -60
- transformers/models/granite_speech/processing_granite_speech.py +11 -4
- transformers/models/granitemoe/configuration_granitemoe.py +31 -36
- transformers/models/granitemoe/modeling_granitemoe.py +39 -41
- transformers/models/granitemoe/modular_granitemoe.py +21 -23
- transformers/models/granitemoehybrid/__init__.py +0 -1
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
- transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
- transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
- transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
- transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
- transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
- transformers/models/groupvit/configuration_groupvit.py +4 -2
- transformers/models/groupvit/modeling_groupvit.py +98 -92
- transformers/models/helium/configuration_helium.py +25 -29
- transformers/models/helium/modeling_helium.py +37 -40
- transformers/models/helium/modular_helium.py +3 -7
- transformers/models/herbert/tokenization_herbert.py +4 -6
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
- transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
- transformers/models/hiera/configuration_hiera.py +2 -5
- transformers/models/hiera/modeling_hiera.py +71 -70
- transformers/models/hubert/configuration_hubert.py +4 -2
- transformers/models/hubert/modeling_hubert.py +42 -41
- transformers/models/hubert/modular_hubert.py +8 -11
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
- transformers/models/ibert/configuration_ibert.py +4 -2
- transformers/models/ibert/modeling_ibert.py +60 -62
- transformers/models/ibert/quant_modules.py +0 -1
- transformers/models/idefics/configuration_idefics.py +5 -8
- transformers/models/idefics/image_processing_idefics.py +13 -15
- transformers/models/idefics/modeling_idefics.py +63 -65
- transformers/models/idefics/perceiver.py +1 -3
- transformers/models/idefics/processing_idefics.py +32 -48
- transformers/models/idefics/vision.py +27 -28
- transformers/models/idefics2/configuration_idefics2.py +1 -3
- transformers/models/idefics2/image_processing_idefics2.py +31 -32
- transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
- transformers/models/idefics2/modeling_idefics2.py +126 -106
- transformers/models/idefics2/processing_idefics2.py +10 -68
- transformers/models/idefics3/configuration_idefics3.py +1 -4
- transformers/models/idefics3/image_processing_idefics3.py +42 -43
- transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
- transformers/models/idefics3/modeling_idefics3.py +113 -92
- transformers/models/idefics3/processing_idefics3.py +15 -69
- transformers/models/ijepa/configuration_ijepa.py +0 -1
- transformers/models/ijepa/modeling_ijepa.py +13 -14
- transformers/models/ijepa/modular_ijepa.py +5 -7
- transformers/models/imagegpt/configuration_imagegpt.py +9 -2
- transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
- transformers/models/imagegpt/modeling_imagegpt.py +65 -62
- transformers/models/informer/configuration_informer.py +6 -9
- transformers/models/informer/modeling_informer.py +87 -89
- transformers/models/informer/modular_informer.py +13 -16
- transformers/models/instructblip/configuration_instructblip.py +2 -2
- transformers/models/instructblip/modeling_instructblip.py +104 -79
- transformers/models/instructblip/processing_instructblip.py +10 -36
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
- transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
- transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
- transformers/models/internvl/configuration_internvl.py +5 -1
- transformers/models/internvl/modeling_internvl.py +76 -98
- transformers/models/internvl/modular_internvl.py +45 -59
- transformers/models/internvl/processing_internvl.py +12 -45
- transformers/models/internvl/video_processing_internvl.py +10 -11
- transformers/models/jais2/configuration_jais2.py +25 -29
- transformers/models/jais2/modeling_jais2.py +36 -38
- transformers/models/jais2/modular_jais2.py +20 -22
- transformers/models/jamba/configuration_jamba.py +5 -8
- transformers/models/jamba/modeling_jamba.py +47 -50
- transformers/models/jamba/modular_jamba.py +40 -41
- transformers/models/janus/configuration_janus.py +0 -1
- transformers/models/janus/image_processing_janus.py +37 -39
- transformers/models/janus/image_processing_janus_fast.py +20 -21
- transformers/models/janus/modeling_janus.py +103 -188
- transformers/models/janus/modular_janus.py +122 -83
- transformers/models/janus/processing_janus.py +17 -43
- transformers/models/jetmoe/configuration_jetmoe.py +26 -27
- transformers/models/jetmoe/modeling_jetmoe.py +42 -45
- transformers/models/jetmoe/modular_jetmoe.py +33 -36
- transformers/models/kosmos2/configuration_kosmos2.py +10 -9
- transformers/models/kosmos2/modeling_kosmos2.py +199 -178
- transformers/models/kosmos2/processing_kosmos2.py +40 -55
- transformers/models/kosmos2_5/__init__.py +0 -1
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
- transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
- transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
- transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
- transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
- transformers/models/lasr/configuration_lasr.py +3 -7
- transformers/models/lasr/feature_extraction_lasr.py +10 -12
- transformers/models/lasr/modeling_lasr.py +21 -24
- transformers/models/lasr/modular_lasr.py +11 -13
- transformers/models/lasr/processing_lasr.py +12 -6
- transformers/models/lasr/tokenization_lasr.py +2 -4
- transformers/models/layoutlm/configuration_layoutlm.py +14 -2
- transformers/models/layoutlm/modeling_layoutlm.py +70 -72
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
- transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
- transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
- transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
- transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
- transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
- transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
- transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
- transformers/models/led/configuration_led.py +8 -12
- transformers/models/led/modeling_led.py +113 -267
- transformers/models/levit/configuration_levit.py +0 -1
- transformers/models/levit/image_processing_levit.py +19 -21
- transformers/models/levit/image_processing_levit_fast.py +4 -5
- transformers/models/levit/modeling_levit.py +17 -19
- transformers/models/lfm2/configuration_lfm2.py +27 -30
- transformers/models/lfm2/modeling_lfm2.py +46 -48
- transformers/models/lfm2/modular_lfm2.py +32 -32
- transformers/models/lfm2_moe/__init__.py +0 -1
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
- transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
- transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
- transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
- transformers/models/lightglue/image_processing_lightglue.py +16 -15
- transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
- transformers/models/lightglue/modeling_lightglue.py +31 -33
- transformers/models/lightglue/modular_lightglue.py +31 -31
- transformers/models/lighton_ocr/__init__.py +28 -0
- transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
- transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
- transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
- transformers/models/lilt/configuration_lilt.py +6 -2
- transformers/models/lilt/modeling_lilt.py +53 -55
- transformers/models/llama/configuration_llama.py +26 -31
- transformers/models/llama/modeling_llama.py +35 -38
- transformers/models/llama/tokenization_llama.py +2 -4
- transformers/models/llama4/configuration_llama4.py +87 -69
- transformers/models/llama4/image_processing_llama4_fast.py +11 -12
- transformers/models/llama4/modeling_llama4.py +116 -115
- transformers/models/llama4/processing_llama4.py +33 -57
- transformers/models/llava/configuration_llava.py +10 -1
- transformers/models/llava/image_processing_llava.py +25 -28
- transformers/models/llava/image_processing_llava_fast.py +9 -10
- transformers/models/llava/modeling_llava.py +73 -102
- transformers/models/llava/processing_llava.py +18 -51
- transformers/models/llava_next/configuration_llava_next.py +2 -2
- transformers/models/llava_next/image_processing_llava_next.py +43 -45
- transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
- transformers/models/llava_next/modeling_llava_next.py +103 -104
- transformers/models/llava_next/processing_llava_next.py +18 -47
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
- transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
- transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
- transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
- transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
- transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
- transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
- transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
- transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
- transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
- transformers/models/longcat_flash/__init__.py +0 -1
- transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
- transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
- transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
- transformers/models/longformer/configuration_longformer.py +5 -5
- transformers/models/longformer/modeling_longformer.py +99 -101
- transformers/models/longt5/configuration_longt5.py +9 -7
- transformers/models/longt5/modeling_longt5.py +45 -45
- transformers/models/luke/configuration_luke.py +8 -2
- transformers/models/luke/modeling_luke.py +179 -181
- transformers/models/luke/tokenization_luke.py +99 -105
- transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
- transformers/models/lw_detr/configuration_lw_detr.py +362 -0
- transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
- transformers/models/lw_detr/modular_lw_detr.py +1609 -0
- transformers/models/lxmert/configuration_lxmert.py +16 -1
- transformers/models/lxmert/modeling_lxmert.py +63 -74
- transformers/models/m2m_100/configuration_m2m_100.py +7 -9
- transformers/models/m2m_100/modeling_m2m_100.py +72 -74
- transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
- transformers/models/mamba/configuration_mamba.py +5 -3
- transformers/models/mamba/modeling_mamba.py +61 -70
- transformers/models/mamba2/configuration_mamba2.py +5 -8
- transformers/models/mamba2/modeling_mamba2.py +66 -79
- transformers/models/marian/configuration_marian.py +10 -5
- transformers/models/marian/modeling_marian.py +88 -90
- transformers/models/marian/tokenization_marian.py +6 -6
- transformers/models/markuplm/configuration_markuplm.py +4 -7
- transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
- transformers/models/markuplm/modeling_markuplm.py +63 -65
- transformers/models/markuplm/processing_markuplm.py +31 -38
- transformers/models/markuplm/tokenization_markuplm.py +67 -77
- transformers/models/mask2former/configuration_mask2former.py +14 -52
- transformers/models/mask2former/image_processing_mask2former.py +84 -85
- transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
- transformers/models/mask2former/modeling_mask2former.py +108 -104
- transformers/models/mask2former/modular_mask2former.py +6 -8
- transformers/models/maskformer/configuration_maskformer.py +17 -51
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
- transformers/models/maskformer/image_processing_maskformer.py +84 -85
- transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
- transformers/models/maskformer/modeling_maskformer.py +71 -67
- transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
- transformers/models/mbart/configuration_mbart.py +9 -5
- transformers/models/mbart/modeling_mbart.py +120 -119
- transformers/models/mbart/tokenization_mbart.py +2 -4
- transformers/models/mbart50/tokenization_mbart50.py +3 -5
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
- transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
- transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
- transformers/models/mgp_str/configuration_mgp_str.py +0 -1
- transformers/models/mgp_str/modeling_mgp_str.py +18 -18
- transformers/models/mgp_str/processing_mgp_str.py +3 -20
- transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
- transformers/models/mimi/configuration_mimi.py +42 -40
- transformers/models/mimi/modeling_mimi.py +116 -115
- transformers/models/minimax/__init__.py +0 -1
- transformers/models/minimax/configuration_minimax.py +40 -47
- transformers/models/minimax/modeling_minimax.py +46 -49
- transformers/models/minimax/modular_minimax.py +59 -65
- transformers/models/minimax_m2/__init__.py +28 -0
- transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
- transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
- transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
- transformers/models/ministral/configuration_ministral.py +25 -29
- transformers/models/ministral/modeling_ministral.py +35 -37
- transformers/models/ministral/modular_ministral.py +32 -37
- transformers/models/ministral3/configuration_ministral3.py +23 -26
- transformers/models/ministral3/modeling_ministral3.py +35 -37
- transformers/models/ministral3/modular_ministral3.py +7 -8
- transformers/models/mistral/configuration_mistral.py +24 -29
- transformers/models/mistral/modeling_mistral.py +35 -37
- transformers/models/mistral/modular_mistral.py +14 -15
- transformers/models/mistral3/configuration_mistral3.py +4 -1
- transformers/models/mistral3/modeling_mistral3.py +79 -82
- transformers/models/mistral3/modular_mistral3.py +66 -67
- transformers/models/mixtral/configuration_mixtral.py +32 -38
- transformers/models/mixtral/modeling_mixtral.py +39 -42
- transformers/models/mixtral/modular_mixtral.py +26 -29
- transformers/models/mlcd/configuration_mlcd.py +0 -1
- transformers/models/mlcd/modeling_mlcd.py +17 -17
- transformers/models/mlcd/modular_mlcd.py +16 -16
- transformers/models/mllama/configuration_mllama.py +10 -15
- transformers/models/mllama/image_processing_mllama.py +23 -25
- transformers/models/mllama/image_processing_mllama_fast.py +11 -11
- transformers/models/mllama/modeling_mllama.py +100 -103
- transformers/models/mllama/processing_mllama.py +6 -55
- transformers/models/mluke/tokenization_mluke.py +97 -103
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
- transformers/models/mobilebert/configuration_mobilebert.py +4 -2
- transformers/models/mobilebert/modeling_mobilebert.py +78 -88
- transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
- transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
- transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
- transformers/models/mobilevit/configuration_mobilevit.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
- transformers/models/mobilevit/modeling_mobilevit.py +21 -21
- transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
- transformers/models/modernbert/configuration_modernbert.py +76 -51
- transformers/models/modernbert/modeling_modernbert.py +188 -943
- transformers/models/modernbert/modular_modernbert.py +255 -978
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
- transformers/models/moonshine/configuration_moonshine.py +34 -31
- transformers/models/moonshine/modeling_moonshine.py +70 -72
- transformers/models/moonshine/modular_moonshine.py +91 -86
- transformers/models/moshi/configuration_moshi.py +46 -23
- transformers/models/moshi/modeling_moshi.py +134 -142
- transformers/models/mpnet/configuration_mpnet.py +6 -2
- transformers/models/mpnet/modeling_mpnet.py +55 -57
- transformers/models/mpnet/tokenization_mpnet.py +1 -4
- transformers/models/mpt/configuration_mpt.py +17 -9
- transformers/models/mpt/modeling_mpt.py +58 -60
- transformers/models/mra/configuration_mra.py +8 -2
- transformers/models/mra/modeling_mra.py +54 -56
- transformers/models/mt5/configuration_mt5.py +9 -6
- transformers/models/mt5/modeling_mt5.py +80 -85
- transformers/models/musicgen/configuration_musicgen.py +12 -8
- transformers/models/musicgen/modeling_musicgen.py +114 -116
- transformers/models/musicgen/processing_musicgen.py +3 -21
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
- transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
- transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
- transformers/models/mvp/configuration_mvp.py +8 -5
- transformers/models/mvp/modeling_mvp.py +121 -123
- transformers/models/myt5/tokenization_myt5.py +8 -10
- transformers/models/nanochat/configuration_nanochat.py +5 -8
- transformers/models/nanochat/modeling_nanochat.py +36 -39
- transformers/models/nanochat/modular_nanochat.py +16 -18
- transformers/models/nemotron/configuration_nemotron.py +25 -30
- transformers/models/nemotron/modeling_nemotron.py +53 -66
- transformers/models/nllb/tokenization_nllb.py +14 -14
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
- transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
- transformers/models/nougat/image_processing_nougat.py +29 -32
- transformers/models/nougat/image_processing_nougat_fast.py +12 -13
- transformers/models/nougat/processing_nougat.py +37 -39
- transformers/models/nougat/tokenization_nougat.py +5 -7
- transformers/models/nystromformer/configuration_nystromformer.py +8 -2
- transformers/models/nystromformer/modeling_nystromformer.py +61 -63
- transformers/models/olmo/configuration_olmo.py +23 -28
- transformers/models/olmo/modeling_olmo.py +35 -38
- transformers/models/olmo/modular_olmo.py +8 -12
- transformers/models/olmo2/configuration_olmo2.py +27 -32
- transformers/models/olmo2/modeling_olmo2.py +36 -39
- transformers/models/olmo2/modular_olmo2.py +36 -38
- transformers/models/olmo3/__init__.py +0 -1
- transformers/models/olmo3/configuration_olmo3.py +30 -34
- transformers/models/olmo3/modeling_olmo3.py +35 -38
- transformers/models/olmo3/modular_olmo3.py +44 -47
- transformers/models/olmoe/configuration_olmoe.py +29 -33
- transformers/models/olmoe/modeling_olmoe.py +41 -43
- transformers/models/olmoe/modular_olmoe.py +15 -16
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
- transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
- transformers/models/oneformer/configuration_oneformer.py +11 -51
- transformers/models/oneformer/image_processing_oneformer.py +83 -84
- transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
- transformers/models/oneformer/modeling_oneformer.py +137 -133
- transformers/models/oneformer/processing_oneformer.py +28 -43
- transformers/models/openai/configuration_openai.py +16 -1
- transformers/models/openai/modeling_openai.py +50 -51
- transformers/models/openai/tokenization_openai.py +2 -5
- transformers/models/opt/configuration_opt.py +6 -7
- transformers/models/opt/modeling_opt.py +79 -80
- transformers/models/ovis2/__init__.py +0 -1
- transformers/models/ovis2/configuration_ovis2.py +4 -1
- transformers/models/ovis2/image_processing_ovis2.py +22 -24
- transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
- transformers/models/ovis2/modeling_ovis2.py +99 -142
- transformers/models/ovis2/modular_ovis2.py +82 -45
- transformers/models/ovis2/processing_ovis2.py +12 -40
- transformers/models/owlv2/configuration_owlv2.py +4 -2
- transformers/models/owlv2/image_processing_owlv2.py +20 -21
- transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
- transformers/models/owlv2/modeling_owlv2.py +122 -114
- transformers/models/owlv2/modular_owlv2.py +11 -12
- transformers/models/owlv2/processing_owlv2.py +20 -49
- transformers/models/owlvit/configuration_owlvit.py +4 -2
- transformers/models/owlvit/image_processing_owlvit.py +21 -22
- transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
- transformers/models/owlvit/modeling_owlvit.py +121 -113
- transformers/models/owlvit/processing_owlvit.py +20 -48
- transformers/models/paddleocr_vl/__init__.py +0 -1
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
- transformers/models/paligemma/configuration_paligemma.py +4 -1
- transformers/models/paligemma/modeling_paligemma.py +81 -79
- transformers/models/paligemma/processing_paligemma.py +13 -66
- transformers/models/parakeet/configuration_parakeet.py +3 -8
- transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
- transformers/models/parakeet/modeling_parakeet.py +21 -25
- transformers/models/parakeet/modular_parakeet.py +19 -21
- transformers/models/parakeet/processing_parakeet.py +12 -5
- transformers/models/parakeet/tokenization_parakeet.py +2 -4
- transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
- transformers/models/patchtst/configuration_patchtst.py +6 -9
- transformers/models/patchtst/modeling_patchtst.py +75 -77
- transformers/models/pe_audio/__init__.py +0 -1
- transformers/models/pe_audio/configuration_pe_audio.py +14 -16
- transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
- transformers/models/pe_audio/modeling_pe_audio.py +30 -31
- transformers/models/pe_audio/modular_pe_audio.py +17 -18
- transformers/models/pe_audio/processing_pe_audio.py +0 -1
- transformers/models/pe_audio_video/__init__.py +0 -1
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
- transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
- transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
- transformers/models/pe_video/__init__.py +0 -1
- transformers/models/pe_video/configuration_pe_video.py +14 -16
- transformers/models/pe_video/modeling_pe_video.py +57 -46
- transformers/models/pe_video/modular_pe_video.py +47 -35
- transformers/models/pe_video/video_processing_pe_video.py +2 -4
- transformers/models/pegasus/configuration_pegasus.py +8 -6
- transformers/models/pegasus/modeling_pegasus.py +67 -69
- transformers/models/pegasus/tokenization_pegasus.py +1 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
- transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
- transformers/models/perceiver/configuration_perceiver.py +0 -1
- transformers/models/perceiver/image_processing_perceiver.py +22 -25
- transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
- transformers/models/perceiver/modeling_perceiver.py +152 -145
- transformers/models/perceiver/tokenization_perceiver.py +3 -6
- transformers/models/perception_lm/configuration_perception_lm.py +0 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
- transformers/models/perception_lm/modeling_perception_lm.py +64 -67
- transformers/models/perception_lm/modular_perception_lm.py +58 -58
- transformers/models/perception_lm/processing_perception_lm.py +13 -47
- transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
- transformers/models/persimmon/configuration_persimmon.py +23 -28
- transformers/models/persimmon/modeling_persimmon.py +44 -47
- transformers/models/phi/configuration_phi.py +27 -28
- transformers/models/phi/modeling_phi.py +39 -41
- transformers/models/phi/modular_phi.py +26 -26
- transformers/models/phi3/configuration_phi3.py +32 -37
- transformers/models/phi3/modeling_phi3.py +37 -40
- transformers/models/phi3/modular_phi3.py +16 -20
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
- transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
- transformers/models/phimoe/configuration_phimoe.py +31 -36
- transformers/models/phimoe/modeling_phimoe.py +50 -77
- transformers/models/phimoe/modular_phimoe.py +12 -8
- transformers/models/phobert/tokenization_phobert.py +4 -6
- transformers/models/pix2struct/configuration_pix2struct.py +12 -10
- transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
- transformers/models/pix2struct/modeling_pix2struct.py +56 -52
- transformers/models/pix2struct/processing_pix2struct.py +5 -26
- transformers/models/pixio/__init__.py +0 -1
- transformers/models/pixio/configuration_pixio.py +2 -5
- transformers/models/pixio/modeling_pixio.py +16 -17
- transformers/models/pixio/modular_pixio.py +7 -8
- transformers/models/pixtral/configuration_pixtral.py +11 -14
- transformers/models/pixtral/image_processing_pixtral.py +26 -28
- transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
- transformers/models/pixtral/modeling_pixtral.py +31 -37
- transformers/models/pixtral/processing_pixtral.py +18 -52
- transformers/models/plbart/configuration_plbart.py +8 -6
- transformers/models/plbart/modeling_plbart.py +109 -109
- transformers/models/plbart/modular_plbart.py +31 -33
- transformers/models/plbart/tokenization_plbart.py +4 -5
- transformers/models/poolformer/configuration_poolformer.py +0 -1
- transformers/models/poolformer/image_processing_poolformer.py +21 -24
- transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
- transformers/models/poolformer/modeling_poolformer.py +10 -12
- transformers/models/pop2piano/configuration_pop2piano.py +7 -7
- transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
- transformers/models/pop2piano/modeling_pop2piano.py +24 -24
- transformers/models/pop2piano/processing_pop2piano.py +25 -33
- transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
- transformers/models/prophetnet/configuration_prophetnet.py +37 -38
- transformers/models/prophetnet/modeling_prophetnet.py +121 -153
- transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
- transformers/models/pvt/configuration_pvt.py +0 -1
- transformers/models/pvt/image_processing_pvt.py +24 -27
- transformers/models/pvt/image_processing_pvt_fast.py +1 -2
- transformers/models/pvt/modeling_pvt.py +19 -21
- transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
- transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
- transformers/models/qwen2/configuration_qwen2.py +32 -25
- transformers/models/qwen2/modeling_qwen2.py +35 -37
- transformers/models/qwen2/modular_qwen2.py +14 -15
- transformers/models/qwen2/tokenization_qwen2.py +2 -9
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
- transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
- transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
- transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
- transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
- transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
- transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
- transformers/models/qwen3/configuration_qwen3.py +34 -27
- transformers/models/qwen3/modeling_qwen3.py +35 -38
- transformers/models/qwen3/modular_qwen3.py +7 -9
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
- transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
- transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
- transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
- transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
- transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
- transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
- transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
- transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
- transformers/models/rag/configuration_rag.py +6 -7
- transformers/models/rag/modeling_rag.py +119 -121
- transformers/models/rag/retrieval_rag.py +3 -5
- transformers/models/rag/tokenization_rag.py +0 -50
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
- transformers/models/reformer/configuration_reformer.py +7 -8
- transformers/models/reformer/modeling_reformer.py +67 -68
- transformers/models/reformer/tokenization_reformer.py +3 -6
- transformers/models/regnet/configuration_regnet.py +0 -1
- transformers/models/regnet/modeling_regnet.py +7 -9
- transformers/models/rembert/configuration_rembert.py +8 -2
- transformers/models/rembert/modeling_rembert.py +108 -132
- transformers/models/rembert/tokenization_rembert.py +1 -4
- transformers/models/resnet/configuration_resnet.py +2 -5
- transformers/models/resnet/modeling_resnet.py +14 -15
- transformers/models/roberta/configuration_roberta.py +11 -3
- transformers/models/roberta/modeling_roberta.py +97 -99
- transformers/models/roberta/modular_roberta.py +55 -58
- transformers/models/roberta/tokenization_roberta.py +2 -5
- transformers/models/roberta/tokenization_roberta_old.py +2 -4
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
- transformers/models/roc_bert/configuration_roc_bert.py +8 -2
- transformers/models/roc_bert/modeling_roc_bert.py +125 -162
- transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
- transformers/models/roformer/configuration_roformer.py +13 -3
- transformers/models/roformer/modeling_roformer.py +79 -95
- transformers/models/roformer/tokenization_roformer.py +3 -6
- transformers/models/roformer/tokenization_utils.py +0 -1
- transformers/models/rt_detr/configuration_rt_detr.py +8 -50
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
- transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
- transformers/models/rt_detr/modeling_rt_detr.py +643 -804
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
- transformers/models/rt_detr/modular_rt_detr.py +1522 -20
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
- transformers/models/rwkv/configuration_rwkv.py +2 -4
- transformers/models/rwkv/modeling_rwkv.py +29 -54
- transformers/models/sam/configuration_sam.py +2 -1
- transformers/models/sam/image_processing_sam.py +59 -60
- transformers/models/sam/image_processing_sam_fast.py +25 -26
- transformers/models/sam/modeling_sam.py +46 -43
- transformers/models/sam/processing_sam.py +39 -27
- transformers/models/sam2/configuration_sam2.py +1 -2
- transformers/models/sam2/image_processing_sam2_fast.py +14 -15
- transformers/models/sam2/modeling_sam2.py +96 -94
- transformers/models/sam2/modular_sam2.py +85 -94
- transformers/models/sam2/processing_sam2.py +31 -47
- transformers/models/sam2_video/configuration_sam2_video.py +0 -1
- transformers/models/sam2_video/modeling_sam2_video.py +114 -116
- transformers/models/sam2_video/modular_sam2_video.py +72 -89
- transformers/models/sam2_video/processing_sam2_video.py +49 -66
- transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
- transformers/models/sam3/configuration_sam3.py +0 -1
- transformers/models/sam3/image_processing_sam3_fast.py +17 -20
- transformers/models/sam3/modeling_sam3.py +94 -100
- transformers/models/sam3/modular_sam3.py +3 -8
- transformers/models/sam3/processing_sam3.py +37 -52
- transformers/models/sam3_tracker/__init__.py +0 -1
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
- transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
- transformers/models/sam3_tracker_video/__init__.py +0 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
- transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
- transformers/models/sam3_video/configuration_sam3_video.py +0 -1
- transformers/models/sam3_video/modeling_sam3_video.py +56 -45
- transformers/models/sam3_video/processing_sam3_video.py +25 -45
- transformers/models/sam_hq/__init__.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +2 -1
- transformers/models/sam_hq/modeling_sam_hq.py +52 -50
- transformers/models/sam_hq/modular_sam_hq.py +23 -25
- transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
- transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
- transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
- transformers/models/seed_oss/configuration_seed_oss.py +30 -34
- transformers/models/seed_oss/modeling_seed_oss.py +34 -36
- transformers/models/seed_oss/modular_seed_oss.py +6 -7
- transformers/models/segformer/configuration_segformer.py +0 -10
- transformers/models/segformer/image_processing_segformer.py +39 -42
- transformers/models/segformer/image_processing_segformer_fast.py +11 -12
- transformers/models/segformer/modeling_segformer.py +28 -28
- transformers/models/segformer/modular_segformer.py +8 -9
- transformers/models/seggpt/configuration_seggpt.py +0 -1
- transformers/models/seggpt/image_processing_seggpt.py +38 -41
- transformers/models/seggpt/modeling_seggpt.py +48 -38
- transformers/models/sew/configuration_sew.py +4 -2
- transformers/models/sew/modeling_sew.py +42 -40
- transformers/models/sew/modular_sew.py +12 -13
- transformers/models/sew_d/configuration_sew_d.py +4 -2
- transformers/models/sew_d/modeling_sew_d.py +32 -31
- transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
- transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
- transformers/models/siglip/configuration_siglip.py +4 -2
- transformers/models/siglip/image_processing_siglip.py +17 -20
- transformers/models/siglip/image_processing_siglip_fast.py +0 -1
- transformers/models/siglip/modeling_siglip.py +65 -110
- transformers/models/siglip/processing_siglip.py +2 -14
- transformers/models/siglip/tokenization_siglip.py +6 -7
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2.py +15 -16
- transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
- transformers/models/siglip2/modeling_siglip2.py +89 -130
- transformers/models/siglip2/modular_siglip2.py +95 -48
- transformers/models/siglip2/processing_siglip2.py +2 -14
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +29 -32
- transformers/models/smollm3/modeling_smollm3.py +35 -38
- transformers/models/smollm3/modular_smollm3.py +36 -38
- transformers/models/smolvlm/configuration_smolvlm.py +2 -4
- transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
- transformers/models/smolvlm/modeling_smolvlm.py +124 -96
- transformers/models/smolvlm/modular_smolvlm.py +50 -39
- transformers/models/smolvlm/processing_smolvlm.py +15 -76
- transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
- transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
- transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
- transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
- transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
- transformers/models/speecht5/configuration_speecht5.py +7 -9
- transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
- transformers/models/speecht5/modeling_speecht5.py +172 -174
- transformers/models/speecht5/number_normalizer.py +0 -1
- transformers/models/speecht5/processing_speecht5.py +3 -37
- transformers/models/speecht5/tokenization_speecht5.py +4 -5
- transformers/models/splinter/configuration_splinter.py +6 -7
- transformers/models/splinter/modeling_splinter.py +62 -59
- transformers/models/splinter/tokenization_splinter.py +2 -4
- transformers/models/squeezebert/configuration_squeezebert.py +14 -2
- transformers/models/squeezebert/modeling_squeezebert.py +60 -62
- transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
- transformers/models/stablelm/configuration_stablelm.py +28 -29
- transformers/models/stablelm/modeling_stablelm.py +44 -47
- transformers/models/starcoder2/configuration_starcoder2.py +30 -27
- transformers/models/starcoder2/modeling_starcoder2.py +38 -41
- transformers/models/starcoder2/modular_starcoder2.py +17 -19
- transformers/models/superglue/configuration_superglue.py +7 -3
- transformers/models/superglue/image_processing_superglue.py +15 -15
- transformers/models/superglue/image_processing_superglue_fast.py +8 -8
- transformers/models/superglue/modeling_superglue.py +41 -37
- transformers/models/superpoint/image_processing_superpoint.py +15 -15
- transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
- transformers/models/superpoint/modeling_superpoint.py +17 -16
- transformers/models/swiftformer/configuration_swiftformer.py +0 -1
- transformers/models/swiftformer/modeling_swiftformer.py +12 -14
- transformers/models/swin/configuration_swin.py +2 -5
- transformers/models/swin/modeling_swin.py +69 -78
- transformers/models/swin2sr/configuration_swin2sr.py +0 -1
- transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
- transformers/models/swin2sr/modeling_swin2sr.py +30 -30
- transformers/models/swinv2/configuration_swinv2.py +2 -5
- transformers/models/swinv2/modeling_swinv2.py +65 -74
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
- transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
- transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
- transformers/models/t5/configuration_t5.py +9 -9
- transformers/models/t5/modeling_t5.py +80 -85
- transformers/models/t5/tokenization_t5.py +1 -3
- transformers/models/t5gemma/configuration_t5gemma.py +43 -59
- transformers/models/t5gemma/modeling_t5gemma.py +105 -108
- transformers/models/t5gemma/modular_t5gemma.py +128 -142
- transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
- transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
- transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
- transformers/models/table_transformer/configuration_table_transformer.py +18 -50
- transformers/models/table_transformer/modeling_table_transformer.py +73 -101
- transformers/models/tapas/configuration_tapas.py +12 -2
- transformers/models/tapas/modeling_tapas.py +65 -67
- transformers/models/tapas/tokenization_tapas.py +116 -153
- transformers/models/textnet/configuration_textnet.py +4 -7
- transformers/models/textnet/image_processing_textnet.py +22 -25
- transformers/models/textnet/image_processing_textnet_fast.py +8 -9
- transformers/models/textnet/modeling_textnet.py +28 -28
- transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
- transformers/models/timesfm/configuration_timesfm.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +22 -25
- transformers/models/timesfm/modular_timesfm.py +21 -24
- transformers/models/timesformer/configuration_timesformer.py +0 -1
- transformers/models/timesformer/modeling_timesformer.py +13 -16
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
- transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
- transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
- transformers/models/trocr/configuration_trocr.py +11 -8
- transformers/models/trocr/modeling_trocr.py +42 -42
- transformers/models/trocr/processing_trocr.py +5 -25
- transformers/models/tvp/configuration_tvp.py +10 -36
- transformers/models/tvp/image_processing_tvp.py +50 -52
- transformers/models/tvp/image_processing_tvp_fast.py +15 -15
- transformers/models/tvp/modeling_tvp.py +26 -28
- transformers/models/tvp/processing_tvp.py +2 -14
- transformers/models/udop/configuration_udop.py +16 -8
- transformers/models/udop/modeling_udop.py +73 -72
- transformers/models/udop/processing_udop.py +7 -26
- transformers/models/udop/tokenization_udop.py +80 -93
- transformers/models/umt5/configuration_umt5.py +8 -7
- transformers/models/umt5/modeling_umt5.py +87 -84
- transformers/models/unispeech/configuration_unispeech.py +4 -2
- transformers/models/unispeech/modeling_unispeech.py +54 -53
- transformers/models/unispeech/modular_unispeech.py +20 -22
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
- transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
- transformers/models/univnet/feature_extraction_univnet.py +14 -14
- transformers/models/univnet/modeling_univnet.py +7 -8
- transformers/models/upernet/configuration_upernet.py +8 -36
- transformers/models/upernet/modeling_upernet.py +11 -14
- transformers/models/vaultgemma/__init__.py +0 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
- transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
- transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
- transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
- transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
- transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
- transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
- transformers/models/video_llava/configuration_video_llava.py +4 -1
- transformers/models/video_llava/image_processing_video_llava.py +35 -38
- transformers/models/video_llava/modeling_video_llava.py +139 -143
- transformers/models/video_llava/processing_video_llava.py +38 -78
- transformers/models/video_llava/video_processing_video_llava.py +0 -1
- transformers/models/videomae/configuration_videomae.py +0 -1
- transformers/models/videomae/image_processing_videomae.py +31 -34
- transformers/models/videomae/modeling_videomae.py +17 -20
- transformers/models/videomae/video_processing_videomae.py +0 -1
- transformers/models/vilt/configuration_vilt.py +4 -2
- transformers/models/vilt/image_processing_vilt.py +29 -30
- transformers/models/vilt/image_processing_vilt_fast.py +15 -16
- transformers/models/vilt/modeling_vilt.py +103 -90
- transformers/models/vilt/processing_vilt.py +2 -14
- transformers/models/vipllava/configuration_vipllava.py +4 -1
- transformers/models/vipllava/modeling_vipllava.py +92 -67
- transformers/models/vipllava/modular_vipllava.py +78 -54
- transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
- transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
- transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
- transformers/models/visual_bert/configuration_visual_bert.py +6 -2
- transformers/models/visual_bert/modeling_visual_bert.py +90 -92
- transformers/models/vit/configuration_vit.py +2 -3
- transformers/models/vit/image_processing_vit.py +19 -22
- transformers/models/vit/image_processing_vit_fast.py +0 -1
- transformers/models/vit/modeling_vit.py +20 -20
- transformers/models/vit_mae/configuration_vit_mae.py +0 -1
- transformers/models/vit_mae/modeling_vit_mae.py +32 -30
- transformers/models/vit_msn/configuration_vit_msn.py +0 -1
- transformers/models/vit_msn/modeling_vit_msn.py +21 -19
- transformers/models/vitdet/configuration_vitdet.py +2 -5
- transformers/models/vitdet/modeling_vitdet.py +14 -17
- transformers/models/vitmatte/configuration_vitmatte.py +7 -39
- transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
- transformers/models/vitmatte/modeling_vitmatte.py +10 -12
- transformers/models/vitpose/configuration_vitpose.py +7 -47
- transformers/models/vitpose/image_processing_vitpose.py +24 -25
- transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
- transformers/models/vitpose/modeling_vitpose.py +15 -15
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
- transformers/models/vits/configuration_vits.py +4 -1
- transformers/models/vits/modeling_vits.py +43 -42
- transformers/models/vits/tokenization_vits.py +3 -4
- transformers/models/vivit/configuration_vivit.py +0 -1
- transformers/models/vivit/image_processing_vivit.py +36 -39
- transformers/models/vivit/modeling_vivit.py +9 -11
- transformers/models/vjepa2/__init__.py +0 -1
- transformers/models/vjepa2/configuration_vjepa2.py +0 -1
- transformers/models/vjepa2/modeling_vjepa2.py +39 -41
- transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
- transformers/models/voxtral/__init__.py +0 -1
- transformers/models/voxtral/configuration_voxtral.py +0 -2
- transformers/models/voxtral/modeling_voxtral.py +41 -48
- transformers/models/voxtral/modular_voxtral.py +35 -38
- transformers/models/voxtral/processing_voxtral.py +25 -48
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
- transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
- transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
- transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
- transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
- transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
- transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
- transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
- transformers/models/wavlm/configuration_wavlm.py +4 -2
- transformers/models/wavlm/modeling_wavlm.py +49 -49
- transformers/models/wavlm/modular_wavlm.py +4 -5
- transformers/models/whisper/configuration_whisper.py +6 -5
- transformers/models/whisper/english_normalizer.py +3 -4
- transformers/models/whisper/feature_extraction_whisper.py +9 -24
- transformers/models/whisper/generation_whisper.py +26 -49
- transformers/models/whisper/modeling_whisper.py +71 -73
- transformers/models/whisper/processing_whisper.py +3 -20
- transformers/models/whisper/tokenization_whisper.py +9 -30
- transformers/models/x_clip/configuration_x_clip.py +4 -2
- transformers/models/x_clip/modeling_x_clip.py +94 -96
- transformers/models/x_clip/processing_x_clip.py +2 -14
- transformers/models/xcodec/configuration_xcodec.py +4 -6
- transformers/models/xcodec/modeling_xcodec.py +15 -17
- transformers/models/xglm/configuration_xglm.py +9 -8
- transformers/models/xglm/modeling_xglm.py +49 -55
- transformers/models/xglm/tokenization_xglm.py +1 -4
- transformers/models/xlm/configuration_xlm.py +10 -8
- transformers/models/xlm/modeling_xlm.py +127 -131
- transformers/models/xlm/tokenization_xlm.py +3 -5
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
- transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
- transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
- transformers/models/xlnet/configuration_xlnet.py +3 -12
- transformers/models/xlnet/modeling_xlnet.py +149 -162
- transformers/models/xlnet/tokenization_xlnet.py +1 -4
- transformers/models/xlstm/configuration_xlstm.py +8 -12
- transformers/models/xlstm/modeling_xlstm.py +61 -96
- transformers/models/xmod/configuration_xmod.py +11 -3
- transformers/models/xmod/modeling_xmod.py +111 -116
- transformers/models/yolos/configuration_yolos.py +0 -1
- transformers/models/yolos/image_processing_yolos.py +60 -62
- transformers/models/yolos/image_processing_yolos_fast.py +42 -45
- transformers/models/yolos/modeling_yolos.py +19 -21
- transformers/models/yolos/modular_yolos.py +17 -19
- transformers/models/yoso/configuration_yoso.py +8 -2
- transformers/models/yoso/modeling_yoso.py +60 -62
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -8
- transformers/models/zamba/modeling_zamba.py +93 -125
- transformers/models/zamba2/configuration_zamba2.py +44 -50
- transformers/models/zamba2/modeling_zamba2.py +137 -165
- transformers/models/zamba2/modular_zamba2.py +79 -74
- transformers/models/zoedepth/configuration_zoedepth.py +17 -41
- transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
- transformers/models/zoedepth/modeling_zoedepth.py +19 -19
- transformers/pipelines/__init__.py +47 -106
- transformers/pipelines/any_to_any.py +15 -23
- transformers/pipelines/audio_utils.py +1 -2
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +13 -17
- transformers/pipelines/image_text_to_text.py +1 -2
- transformers/pipelines/question_answering.py +4 -43
- transformers/pipelines/text_classification.py +1 -14
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/pipelines/token_classification.py +1 -22
- transformers/pipelines/video_classification.py +1 -9
- transformers/pipelines/zero_shot_audio_classification.py +0 -1
- transformers/pipelines/zero_shot_classification.py +0 -6
- transformers/pipelines/zero_shot_image_classification.py +0 -7
- transformers/processing_utils.py +128 -137
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/base.py +10 -0
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_quark.py +0 -1
- transformers/quantizers/quantizer_torchao.py +3 -19
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +6 -65
- transformers/tokenization_mistral_common.py +563 -903
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +228 -341
- transformers/tokenization_utils_sentencepiece.py +5 -6
- transformers/tokenization_utils_tokenizers.py +36 -7
- transformers/trainer.py +30 -41
- transformers/trainer_jit_checkpoint.py +1 -2
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +414 -420
- transformers/utils/__init__.py +1 -4
- transformers/utils/attention_visualizer.py +1 -1
- transformers/utils/auto_docstring.py +567 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/dummy_pt_objects.py +0 -42
- transformers/utils/generic.py +70 -34
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +135 -107
- transformers/utils/quantization_config.py +8 -31
- transformers/video_processing_utils.py +24 -25
- transformers/video_utils.py +21 -23
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
- transformers-5.1.0.dist-info/RECORD +2092 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -229
- transformers-5.0.0rc2.dist-info/RECORD +0 -2042
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/modeling_utils.py
CHANGED
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 #
@@ -25,13 +24,14 @@ import sys
 import warnings
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import Callable, Iterator
+from collections.abc import Callable, Iterator
 from contextlib import contextmanager
+from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial, wraps
 from itertools import cycle
 from threading import Thread
-from typing import Optional, TypeVar,
+from typing import Optional, TypeVar, get_type_hints
 from zipfile import is_zipfile
 
 import torch
@@ -78,9 +78,8 @@ from .integrations.tensor_parallel import (
     ALL_PARALLEL_STYLES,
     _get_parameter_tp_plan,
     distribute_model,
+    gather_state_dict_for_save,
     initialize_tensor_parallelism,
-    repack_weights,
-    replace_state_dict_local_with_dtensor,
     shard_and_distribute_module,
     verify_tp_plan,
 )
@@ -107,25 +106,26 @@ from .utils import (
     copy_func,
     has_file,
     is_accelerate_available,
+    is_bitsandbytes_available,
+    is_env_variable_true,
     is_flash_attn_2_available,
     is_flash_attn_3_available,
     is_grouped_mm_available,
     is_kernels_available,
     is_torch_flex_attn_available,
-    is_torch_greater_or_equal,
     is_torch_mlu_available,
     is_torch_npu_available,
     is_torch_xpu_available,
     logging,
 )
-from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
+from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder, is_flash_attention_requested
 from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
     is_huggingface_hub_greater_or_equal,
     is_sagemaker_mp_enabled,
     is_tracing,
 )
-from .utils.loading_report import log_state_dict_report
+from .utils.loading_report import LoadStateDictInfo, log_state_dict_report
 from .utils.quantization_config import QuantizationMethod
 
 
@@ -135,9 +135,6 @@ if is_accelerate_available():
 
 
 _torch_distributed_available = torch.distributed.is_available()
-_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
-if _is_dtensor_available:
-    from torch.distributed.tensor import DTensor
 
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
@@ -163,6 +160,33 @@ FLASH_ATTN_KERNEL_FALLBACK = {
 }
 
 
+@dataclass(frozen=True)
+class LoadStateDictConfig:
+    """
+    Config for loading weights. This allows bundling arguments that are just
+    passed around.
+    """
+
+    pretrained_model_name_or_path: str | None = None
+    download_kwargs: DownloadKwargs | None = field(default_factory=DownloadKwargs)
+    use_safetensors: bool | None = None
+    ignore_mismatched_sizes: bool = False
+    sharded_metadata: dict | None = None
+    device_map: dict | None = None
+    disk_offload_folder: str | None = None
+    offload_buffers: bool = False
+    dtype: torch.dtype | None = None
+    dtype_plan: dict = field(default_factory=dict)
+    hf_quantizer: HfQuantizer | None = None
+    device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None
+    weights_only: bool = True
+    weight_mapping: list[WeightConverter | WeightRenaming] | None = None
+
+    @property
+    def is_quantized(self) -> bool:
+        return self.hf_quantizer is not None
+
+
 def is_local_dist_rank_0():
     return (
         torch.distributed.is_available()
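The new LoadStateDictConfig bundles the many weight-loading arguments that were previously threaded through call sites one by one. A minimal, self-contained sketch of the same frozen-dataclass pattern (field set simplified; the real class above also references HfQuantizer, DownloadKwargs, and the weight-mapping types):

from dataclasses import dataclass, field

@dataclass(frozen=True)
class LoadConfig:
    # Simplified stand-ins for the fields shown in the diff above
    ignore_mismatched_sizes: bool = False
    dtype_plan: dict = field(default_factory=dict)
    hf_quantizer: object | None = None

    @property
    def is_quantized(self) -> bool:
        return self.hf_quantizer is not None

cfg = LoadConfig()
print(cfg.is_quantized)                 # False
# cfg.ignore_mismatched_sizes = True    # frozen=True makes this raise FrozenInstanceError

frozen=True keeps the bundle hashable and immutable, so a config cannot be mutated mid-load by a helper deep in the call stack.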
@@ -224,8 +248,7 @@ def get_torch_context_manager_or_global_device():
     is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
     """
     device_in_context = torch.tensor([]).device
-
-    default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu")
+    default_device = torch.get_default_device()
     # This case means no context manager was used -> we still check if the default that was potentially set is not cpu
     if device_in_context == default_device:
         if default_device != torch.device("cpu"):
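Dropping the is_torch_greater_or_equal("2.3") fallback means the helper now calls torch.get_default_device() unconditionally, which exists from torch 2.3 onward. A small illustration of the two signals it compares (assuming a recent torch):

import torch

device_in_context = torch.tensor([]).device  # factory calls reflect an enclosing `with torch.device(...)`
default_device = torch.get_default_device()  # reflects torch.set_default_device(...)
print(device_in_context, default_device)     # both `cpu` unless a device was set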
@@ -253,25 +276,22 @@ str_to_torch_dtype = {
     "U8": torch.uint8,
     "I8": torch.int8,
     "I16": torch.int16,
+    "U16": torch.uint16,
     "F16": torch.float16,
     "BF16": torch.bfloat16,
     "I32": torch.int32,
+    "U32": torch.uint32,
     "F32": torch.float32,
     "F64": torch.float64,
     "I64": torch.int64,
+    "U64": torch.uint64,
     "F8_E4M3": torch.float8_e4m3fn,
     "F8_E5M2": torch.float8_e5m2,
 }
 
 
-if is_torch_greater_or_equal("2.3.0"):
-    str_to_torch_dtype["U16"] = torch.uint16
-    str_to_torch_dtype["U32"] = torch.uint32
-    str_to_torch_dtype["U64"] = torch.uint64
-
-
 def load_state_dict(
-    checkpoint_file:
+    checkpoint_file: str | os.PathLike, map_location: str | torch.device = "cpu", weights_only: bool = True
 ) -> dict[str, torch.Tensor]:
     """
     Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
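With the version guard gone, the wide unsigned entries (U16/U32/U64) live directly in the table, which presumes torch >= 2.3, where torch.uint16/32/64 first appeared. The table maps safetensors dtype strings to torch dtypes, e.g.:

import torch  # torch >= 2.3 assumed, for the wide unsigned dtypes

str_to_torch_dtype = {"U8": torch.uint8, "U16": torch.uint16, "F32": torch.float32}
raw_dtype = "U16"                      # as stored in a safetensors header
print(str_to_torch_dtype[raw_dtype])   # torch.uint16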
@@ -461,7 +481,7 @@ def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor
         setattr(parent, param_type, tensor)
 
 
-def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
+def _add_variant(weights_name: str, variant: str | None = None) -> str:
     if variant is not None:
         path, name = weights_name.rsplit(".", 1)
         weights_name = f"{path}.{variant}.{name}"
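The signature change is cosmetic (a PEP 604 union instead of Optional); behavior is unchanged. For reference, the helper splices the variant in before the file extension. A standalone copy, assuming the usual trailing `return weights_name` just below the hunk:

def _add_variant(weights_name: str, variant: str | None = None) -> str:
    if variant is not None:
        path, name = weights_name.rsplit(".", 1)
        weights_name = f"{path}.{variant}.{name}"
    return weights_name

print(_add_variant("model.safetensors", "fp16"))  # model.fp16.safetensors
print(_add_variant("model.safetensors"))          # model.safetensors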
@@ -469,19 +489,20 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
 
 
 def _get_resolved_checkpoint_files(
-    pretrained_model_name_or_path:
-    variant:
-    gguf_file:
-    use_safetensors:
-
-    user_agent: dict,
+    pretrained_model_name_or_path: str | os.PathLike | None,
+    variant: str | None,
+    gguf_file: str | None,
+    use_safetensors: bool | None,
+    user_agent: dict | None,
     is_remote_code: bool,  # Because we can't determine this inside this function, we need it to be passed in
-    transformers_explicit_filename:
-
+    transformers_explicit_filename: str | None = None,
+    download_kwargs: DownloadKwargs | None = None,
+) -> tuple[list[str] | None, dict | None]:
     """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
     checkpoints are sharded.
     This function will download the data if necessary.
     """
+    download_kwargs = download_kwargs or DownloadKwargs()
     cache_dir = download_kwargs.get("cache_dir")
     force_download = download_kwargs.get("force_download", False)
     proxies = download_kwargs.get("proxies")
@@ -494,17 +515,19 @@ def _get_resolved_checkpoint_files(
         if not transformers_explicit_filename.endswith(".safetensors") and not transformers_explicit_filename.endswith(
             ".safetensors.index.json"
         ):
-            raise ValueError(
-                "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
-                "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
-                f"{transformers_explicit_filename}"
-            )
+            if transformers_explicit_filename != "adapter_model.bin":
+                raise ValueError(
+                    "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
+                    "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
+                    f"{transformers_explicit_filename}"
+                )
 
     is_sharded = False
 
     if pretrained_model_name_or_path is not None and gguf_file is None:
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         is_local = os.path.isdir(pretrained_model_name_or_path)
+        # If the file is a local folder (but not in the HF_HOME cache, even if it's technically local)
         if is_local:
             if transformers_explicit_filename is not None:
                 # If the filename is explicitly defined, load this by default.
@@ -563,25 +586,38 @@ def _get_resolved_checkpoint_files(
             else:
                 filename = _add_variant(WEIGHTS_NAME, variant)
 
+            # Prepare set of kwargs for hub functions
+            has_file_kwargs = {
+                "revision": revision,
+                "proxies": proxies,
+                "token": token,
+                "cache_dir": cache_dir,
+                "local_files_only": local_files_only,
+            }
+            cached_file_kwargs = {
+                "force_download": force_download,
+                "user_agent": user_agent,
+                "subfolder": subfolder,
+                "_raise_exceptions_for_gated_repo": False,
+                "_raise_exceptions_for_missing_entries": False,
+                "_commit_hash": commit_hash,
+                **has_file_kwargs,
+            }
+            can_auto_convert = (
+                not is_offline_mode()  # for obvious reasons
+                # If we are in a CI environment or in a pytest run, we prevent the conversion
+                and not is_env_variable_true("DISABLE_SAFETENSORS_CONVERSION")
+                and not is_remote_code  # converter bot does not work on remote code
+                and subfolder == ""  # converter bot does not work on subfolders
+            )
+
             try:
                 # Load from URL or cache if already cached
-                cached_file_kwargs = {
-                    "cache_dir": cache_dir,
-                    "force_download": force_download,
-                    "proxies": proxies,
-                    "local_files_only": local_files_only,
-                    "token": token,
-                    "user_agent": user_agent,
-                    "revision": revision,
-                    "subfolder": subfolder,
-                    "_raise_exceptions_for_gated_repo": False,
-                    "_raise_exceptions_for_missing_entries": False,
-                    "_commit_hash": commit_hash,
-                }
-                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
-
                 # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                 # result when internet is up, the repo and revision exist, but the file does not.
+                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+                # Try safetensors files first if not already found
                 if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
                     # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                     resolved_archive_file = cached_file(
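can_auto_convert gates the safetensors auto-conversion bot behind offline mode, the DISABLE_SAFETENSORS_CONVERSION environment variable, remote code, and subfolders. The real is_env_variable_true lives in transformers.utils; a hedged stand-in with assumed truthy-parsing semantics:

import os

def is_env_variable_true(name: str) -> bool:
    # Assumed semantics: common truthy spellings count as "true"
    return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"}

os.environ["DISABLE_SAFETENSORS_CONVERSION"] = "1"
print(is_env_variable_true("DISABLE_SAFETENSORS_CONVERSION"))  # True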
@@ -592,7 +628,7 @@ def _get_resolved_checkpoint_files(
                 if resolved_archive_file is not None:
                     is_sharded = True
                 elif use_safetensors:
-                    if revision == "main" and
+                    if revision == "main" and can_auto_convert:
                         resolved_archive_file, revision, is_sharded = auto_conversion(
                             pretrained_model_name_or_path, **cached_file_kwargs
                         )
@@ -609,6 +645,8 @@ def _get_resolved_checkpoint_files(
                     resolved_archive_file = cached_file(
                         pretrained_model_name_or_path, filename, **cached_file_kwargs
                     )
+
+                # Then try `.bin` files
                 if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
                     # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                     resolved_archive_file = cached_file(
@@ -618,67 +656,38 @@ def _get_resolved_checkpoint_files(
                 )
                 if resolved_archive_file is not None:
                     is_sharded = True
-                        **has_file_kwargs,
-                    }
-                    if (
-                        not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
-                        and not is_remote_code
-                    ):
-                        Thread(
-                            target=auto_conversion,
-                            args=(pretrained_model_name_or_path,),
-                            kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
-                            name="Thread-auto_conversion",
-                        ).start()
+
+            # If we have a match, but it's `.bin` format, try to launch safetensors conversion for next time
+            if resolved_archive_file is not None:
+                safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
+                if (
+                    filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
+                    and not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
+                    and can_auto_convert
+                ):
+                    Thread(
+                        target=auto_conversion,
+                        args=(pretrained_model_name_or_path,),
+                        kwargs={"ignore_errors_during_conversion": False, **cached_file_kwargs},
+                        name="Thread-auto_conversion",
+                    ).start()
+
+            # If no match, raise appropriate errors
+            else:
+                # Otherwise, no PyTorch file was found
+                if variant is not None and has_file(
+                    pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
+                ):
+                    raise OSError(
+                        f"{pretrained_model_name_or_path} does not appear to have a file named"
+                        f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
+                        f" {variant}. Use `variant=None` to load this model from those weights."
+                    )
                 else:
-                    "token": token,
-                    "cache_dir": cache_dir,
-                    "local_files_only": local_files_only,
-                }
-                if variant is not None and has_file(
-                    pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
-                ):
-                    raise OSError(
-                        f"{pretrained_model_name_or_path} does not appear to have a file named"
-                        f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
-                        f" {variant}. Use `variant=None` to load this model from those weights."
-                    )
-                else:
-                    raise OSError(
-                        f"{pretrained_model_name_or_path} does not appear to have a file named"
-                        f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
-                    )
+                    raise OSError(
+                        f"{pretrained_model_name_or_path} does not appear to have a file named"
+                        f" {_add_variant(WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_NAME, variant)}."
+                    )
 
     except OSError:
         # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
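The hunks above reorder checkpoint resolution: safetensors weights are tried first, `.bin` second, and the background safetensors conversion is now gated on the new `can_auto_convert` flag. A minimal sketch of the same lookup-then-convert pattern, with `try_safetensors`, `try_bin` and `convert` as hypothetical stand-ins for `cached_file` and `auto_conversion`:

```python
from threading import Thread

def resolve_checkpoint(repo_id, try_safetensors, try_bin, convert, can_auto_convert=True):
    # Safetensors first, mirroring the new lookup order.
    resolved = try_safetensors(repo_id)
    if resolved is not None:
        return resolved
    # Fall back to .bin, and kick off a best-effort background conversion
    # so the next load finds safetensors (fire-and-forget thread).
    resolved = try_bin(repo_id)
    if resolved is not None:
        if can_auto_convert:
            Thread(target=convert, args=(repo_id,), name="Thread-auto_conversion").start()
        return resolved
    raise OSError(f"{repo_id} has no usable weight files.")
```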
@@ -745,13 +754,13 @@ def _get_resolved_checkpoint_files(
 
 
 def _get_dtype(
-    dtype:
-    checkpoint_files:
+    dtype: str | torch.dtype | dict | None,
+    checkpoint_files: list[str] | None,
     config: PreTrainedConfig,
-    sharded_metadata:
-    state_dict:
+    sharded_metadata: dict | None,
+    state_dict: dict | None,
     weights_only: bool,
-    hf_quantizer:
+    hf_quantizer: HfQuantizer | None = None,
 ) -> tuple[PreTrainedConfig, torch.dtype]:
     """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
     inferred dtype. We do the following:
@@ -760,7 +769,6 @@ def _get_dtype(
     2. Else, use the dtype provided as a dict or str
     """
     is_sharded = sharded_metadata is not None
-    asked_dtype = dtype
 
     if dtype is not None:
         if isinstance(dtype, str):
@@ -807,6 +815,13 @@ def _get_dtype(
         if isinstance(dtype, dict):
             main_dtype = dtype.get("", torch.get_default_dtype())
             main_dtype = getattr(torch, main_dtype) if isinstance(main_dtype, str) else main_dtype
+
+            logger.warning_once(
+                "Using different dtypes per module is deprecated and will be removed in future versions. "
+                "Setting different dtypes per backbone model might cause device errors downstream, therefore "
+                f"setting the dtype={main_dtype} for all modules."
+            )
+
         else:
             main_dtype = dtype
 
@@ -814,17 +829,7 @@ def _get_dtype(
     config.dtype = main_dtype
     for sub_config_key in config.sub_configs:
         if (sub_config := getattr(config, sub_config_key)) is not None:
-            if asked_dtype == "auto":
-                sub_dtype = getattr(sub_config, "dtype", main_dtype)
-                sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
-            # The dtype was provided as a dict, try to see if we match the subconfig name
-            elif isinstance(dtype, dict):
-                sub_dtype = dtype.get(sub_config_key, main_dtype)
-                sub_dtype = getattr(torch, sub_dtype) if isinstance(sub_dtype, str) else sub_dtype
-            else:
-                sub_dtype = main_dtype
-            sub_config.dtype = sub_dtype
+            sub_config.dtype = main_dtype
 
     return config, main_dtype
 
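The two `_get_dtype` hunks above deprecate per-module dtype dicts: only the top-level `""` entry is kept, and every sub-config now receives the same `main_dtype`. A small sketch of the surviving behaviour (plain Python plus torch, not the library function itself):

```python
import torch

def resolve_main_dtype(dtype):
    # Only the top-level "" entry of a dtype dict survives; per-module
    # entries are now ignored with a deprecation warning.
    if isinstance(dtype, dict):
        main = dtype.get("", torch.get_default_dtype())
        return getattr(torch, main) if isinstance(main, str) else main
    return dtype

assert resolve_main_dtype({"": "bfloat16", "vision_config": "float32"}) == torch.bfloat16
```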
@@ -877,13 +882,8 @@ class ModuleUtilsMixin:
         return encoder_extended_attention_mask
 
     @staticmethod
-    def create_extended_attention_mask_for_decoder(input_shape, attention_mask
-        warnings.warn(
-            "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
-        )
-        else:
-            device = attention_mask.device
+    def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
+        device = attention_mask.device
         batch_size, seq_length = input_shape
         seq_ids = torch.arange(seq_length, device=device)
         causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
@@ -907,8 +907,7 @@ class ModuleUtilsMixin:
         self,
         attention_mask: Tensor,
         input_shape: tuple[int, ...],
-        dtype: Optional[torch.dtype] = None,
+        dtype: torch.dtype | None = None,
     ) -> Tensor:
         """
         Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -925,12 +924,6 @@ class ModuleUtilsMixin:
         if dtype is None:
             dtype = self.dtype
 
-        if not (attention_mask.dim() == 2 and self.config.is_decoder):
-            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
-            if device is not None:
-                warnings.warn(
-                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
-                )
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
@@ -939,9 +932,9 @@ class ModuleUtilsMixin:
         # Provided a padding mask of dimensions [batch_size, seq_length]
         # - if the model is a decoder, apply a causal mask in addition to the padding mask
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if self.config
+        if getattr(self.config, "is_decoder", None):
             extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
-                input_shape, attention_mask
+                input_shape, attention_mask
             )
         else:
             extended_attention_mask = attention_mask[:, None, None, :]
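For reference, the causal pattern built by `create_extended_attention_mask_for_decoder` can be checked in isolation; this reuses the exact expression from the hunk above:

```python
import torch

batch_size, seq_length = 1, 4
seq_ids = torch.arange(seq_length)
# Query position i may attend to key positions j <= i.
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
assert causal_mask[0].equal(torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool)))
```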
@@ -1112,83 +1105,67 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         - **can_record_outputs** (dict):
     """
 
-    main_input_name = "input_ids"
-    model_tags = None
-
-    _checkpoint_conversion_mapping = {}  # used for BC support in VLMs, not meant to be used by new models
-
+    # General model properties
+    config_class: type[PreTrainedConfig] | None = None
     _auto_class = None
-    _keep_in_fp32_modules = None
-    # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
-    # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
-    _keep_in_fp32_modules_strict = None
-
-    dtype_plan: Optional[dict[str, torch.dtype]] = None
-
-    # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
-    # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
-    _keys_to_ignore_on_load_missing = None
-    # a list of `re` patterns of `state_dict` keys that should be removed from the list of
-    # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
-    # warnings.
-    _keys_to_ignore_on_load_unexpected = None
-    # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
-    # trained, but which are either deterministic or tied variables)
-    _keys_to_ignore_on_save = None
-    # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
-    _tied_weights_keys = None
-
-    supports_gradient_checkpointing = False
-    _is_stateful = False
-
-    # Flash Attention support
-    _supports_flash_attn = False
-
-    # SDPA support
-    _supports_sdpa = False
-
-    # Flex Attention support
-    _supports_flex_attn = False
-
-    _can_compile_fullgraph = False
-
-    # A tensor parallel plan to be applied to the model when TP is enabled. For
-    # top-level models, this attribute is currently defined in respective model
-    # code. For base models, this attribute comes from
-    # `config.base_model_tp_plan` during `__init__`.
-    # It should identify the layers exactly: if you want to TP model.language_model.layers.fc1
-    # by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"}
-    # for example.
-    _tp_plan = None
-
-    # tensor parallel degree to which model is sharded to.
-    _tp_size = None
-
-    # A pipeline parallel plan specifying the layers which may not be present
-    # on all ranks when PP is enabled. For top-level models, this attribute is
-    # currently defined in respective model code. For base models, this
-    # attribute comes from `config.base_model_pp_plan` during `post_init`.
-    #
-    # The variable names for the inputs and outputs of the specified layers can
-    # be indexed using the `PipelineParallel` enum as follows:
-    # - `_pp_plan["layers"][PipelineParallel.inputs]`
-    # - `_pp_plan["layers"][PipelineParallel.outputs]`
-    _pp_plan = None
+    base_model_prefix: str = ""
+    _is_stateful: bool = False
+    model_tags: list[str] | None = None
 
+    # Input-related properties
+    main_input_name: str = "input_ids"
+    # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
+    # Possible values are: text, image, video, audio and time
+    input_modalities: str | list[str] = "text"
+
+    # Device-map related properties
+    _no_split_modules: set[str] | list[str] | None = None
+    _skip_keys_device_placement: str | list[str] | None = None
+
+    # Specific dtype upcasting
+    # `_keep_in_fp32_modules` will upcast to fp32 only if the requested dtype is fp16
+    # `_keep_in_fp32_modules_strict` will upcast to fp32 independently if the requested dtype is fp16 or bf16
+    _keep_in_fp32_modules: set[str] | list[str] | None = None
+    _keep_in_fp32_modules_strict: set[str] | list[str] | None = None
+
+    # Loading-specific properties
+    # A dictionary `{"target": "source"}` of checkpoint keys that are potentially tied to one another
+    _tied_weights_keys: dict[str, str] = None
+    # Used for BC support in VLMs, not meant to be used by new models
+    _checkpoint_conversion_mapping: dict[str, str] = {}
+    # A list of `re` patterns describing keys to ignore if they are missing from checkpoints to avoid warnings
+    _keys_to_ignore_on_load_missing: list[str] | None = None
+    # A list of `re` patterns describing keys to ignore if they are unexpected in the checkpoints to avoid warnings
+    _keys_to_ignore_on_load_unexpected: list[str] | None = None
+    # A list of keys to ignore when saving the model
+    _keys_to_ignore_on_save: list[str] | None = None
+
+    # Attention interfaces support properties
+    _supports_sdpa: bool = False
+    _supports_flash_attn: bool = False
+    _supports_flex_attn: bool = False
+
+    # Tensor-parallelism-related properties
+    # A tensor parallel plan of the form `{"model.layer.mlp.param": "colwise"}` to be applied to the model when TP is enabled.
+    # For top-level models, this attribute is currently defined in respective model code. For base models, this attribute comes
+    # from `config.base_model_tp_plan` during `post_init`.
+    _tp_plan: dict[str, str] = None
+    # Tensor parallel degree to which the model is sharded
+    _tp_size = None
+    # A pipeline parallel plan specifying the layers which may not be present on all ranks when PP is enabled. For top-level
+    # models, this attribute is currently defined in respective model code. For base models, it comes from
+    # `config.base_model_pp_plan` during `post_init`.
+    _pp_plan: dict[str, PipelineParallel] | None = None
+
+    # Advanced functionalities support
+    supports_gradient_checkpointing: bool = False
+    _can_compile_fullgraph: bool = False
     # This flag signals that the model can be used as an efficient backend in TGI and vLLM
     # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
     # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
-    _supports_attention_backend = False
-
-    # Attributes used mainly in multimodal LLMs, though all models contain a valid field for these
-    # Possible values are: text, image, video, audio and time
-    input_modalities: Union[str, list[str]] = "text"  # most models are text
+    _supports_attention_backend: bool = False
+    # A mapping describing what outputs can be captured by the `check_model_inputs` decorator during the forward pass
+    _can_record_outputs: dict | None = None
 
     @property
     @torch._dynamo.allow_in_graph
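`_tied_weights_keys` is now documented as a `{"target": "source"}` dict rather than a list. A toy illustration of what expanding such a mapping into actual parameter sharing looks like (a sketch under that assumption, not the library's `tie_weights` implementation):

```python
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)

def tie(model: nn.Module, mapping: dict) -> None:
    # Point each target parameter at its source tensor.
    for target, source in mapping.items():
        module_name, attr = target.rsplit(".", 1)
        setattr(model.get_submodule(module_name), attr, model.get_parameter(source))

m = Toy()
tie(m, {"lm_head.weight": "embed.weight"})  # new-style {"target": "source"} form
assert m.lm_head.weight is m.embed.weight
```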
@@ -1273,6 +1250,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
         )
         self.config = config
+        self.name_or_path = config.name_or_path
 
         # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
         # setting it recursively)
@@ -1298,38 +1276,33 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             loss_type = None
         self.loss_type = loss_type
 
-        self.name_or_path = config.name_or_path
-        self.warnings_issued = {}
-        # Overwrite the class attribute to make it an instance attribute, so models like
-        # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
-        # when a different component (e.g. language_model) is used.
-        self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
-        self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
-        self.dtype_plan = {}
-
-        if isinstance(self._keep_in_fp32_modules, list):
-            self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
-        if isinstance(self._keep_in_fp32_modules_strict, list):
-            self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
-
-        self._no_split_modules = self._no_split_modules or []
         _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs  # added for executorch support only
 
     def post_init(self):
         """
         A method executed at the end of each Transformer model initialization, to execute code that needs the model's
         modules properly initialized (such as weight initialization).
+        It is also used to gather the static properties (parallelism plans, tied_weights_keys, _keep_in_fp32_modules, etc.)
+        in the case of composite models, so that the top-level model knows about those properties from its children.
         """
         # Attach the different parallel plans and tied weight keys to the top-most model, so that everything is
         # easily available
         self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {}
-        # Current submodel should register its tied weights
-        self.all_tied_weights_keys = self.get_expanded_tied_weights_keys(all_submodels=False)
         # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
         if self.base_model is self:
             self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
             self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
             self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
+        # Current submodel should register its tied weights
+        self.all_tied_weights_keys = self.get_expanded_tied_weights_keys(all_submodels=False)
+        # Current submodel should register its `_keep_in_fp32_modules`
+        self._keep_in_fp32_modules = set(self._keep_in_fp32_modules or [])
+        self._keep_in_fp32_modules_strict = set(self._keep_in_fp32_modules_strict or [])
+        # Current submodel must register its `_no_split_modules` as well
+        self._no_split_modules = set(self._no_split_modules or [])
+
+        # Iterate over children only: as the final model is created, this is enough to gather the properties from all submodels.
+        # This works because `__init__` and `post_init` are called on all submodules depth-first in the graph
         for name, module in self.named_children():
             # Parallel plans
             if plan := getattr(module, "_ep_plan", None):
@@ -1341,6 +1314,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             # Always attach the keys of the children (if the children's config says to NOT tie, then it's empty)
             if tied_keys := getattr(module, "all_tied_weights_keys", None):
                 self.all_tied_weights_keys.update({f"{name}.{k}": f"{name}.{v}" for k, v in tied_keys.copy().items()})
+            # Record keep_in_fp32 modules from the children as well
+            if keep_fp32 := getattr(module, "_keep_in_fp32_modules", None):
+                self._keep_in_fp32_modules.update(keep_fp32)
+            if keep_fp32_strict := getattr(module, "_keep_in_fp32_modules_strict", None):
+                self._keep_in_fp32_modules_strict.update(keep_fp32_strict)
+            # Record `_no_split_modules` from the children
+            if no_split := getattr(module, "_no_split_modules", None):
+                self._no_split_modules.update(no_split)
 
         # Maybe initialize the weights and tie the keys
         self.init_weights()
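These `post_init` hunks make composite models gather set-valued properties (`_keep_in_fp32_modules`, `_no_split_modules`) and prefixed tied-weight keys from their direct children, replacing the removed `_get_no_split_modules` traversal further below. A condensed sketch of that aggregation pattern, with illustrative names only:

```python
import torch.nn as nn

def gather_from_children(model: nn.Module):
    # Union set-valued flags; re-prefix dict-valued tied keys with the child's name.
    no_split = set(getattr(model, "_no_split_modules", None) or [])
    tied = {}
    for name, child in model.named_children():
        no_split |= set(getattr(child, "_no_split_modules", None) or [])
        for tgt, src in (getattr(child, "all_tied_weights_keys", None) or {}).items():
            tied[f"{name}.{tgt}"] = f"{name}.{src}"
    return no_split, tied
```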
@@ -1417,7 +1398,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         # Remove the attribute now that it has been consumed, so it's not saved in the config.
         delattr(self.config, "gradient_checkpointing")
 
-    def add_model_tags(self, tags:
+    def add_model_tags(self, tags: list[str] | str) -> None:
         r"""
         Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
         not overwrite existing tags in the model.
@@ -1784,7 +1765,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         return True
 
     def _check_and_adjust_attn_implementation(
-        self, attn_implementation:
+        self, attn_implementation: str | None, is_init_check: bool = False
     ) -> str:
         """
         Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if
@@ -1859,12 +1840,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         )
 
         # preload flash attention here to allow compile with fullgraph
-        if
+        if is_flash_attention_requested(requested_attention_implementation=applicable_attn_implementation):
             lazy_import_flash_attention(applicable_attn_implementation)
 
         return applicable_attn_implementation
 
-    def _check_and_adjust_experts_implementation(self, experts_implementation:
+    def _check_and_adjust_experts_implementation(self, experts_implementation: str | None) -> str:
         """
         Check that the `experts_implementation` exists and is supported by the models.
 
@@ -1877,7 +1858,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         applicable_experts_implementation = self.get_correct_experts_implementation(experts_implementation)
         return applicable_experts_implementation
 
-    def get_correct_attn_implementation(self, requested_attention:
+    def get_correct_attn_implementation(self, requested_attention: str | None, is_init_check: bool = False) -> str:
         applicable_attention = "sdpa" if requested_attention is None else requested_attention
         if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
             message = (
@@ -1911,7 +1892,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         return applicable_attention
 
-    def get_correct_experts_implementation(self, requested_experts:
+    def get_correct_experts_implementation(self, requested_experts: str | None) -> str:
         applicable_experts = "grouped_mm" if requested_experts is None else requested_experts
         if applicable_experts not in ["eager", "grouped_mm", "batched_mm"]:
             message = (
@@ -1936,15 +1917,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
         opening the file, but avoids maintaining yet another property flag.
         """
+        class_module = sys.modules[cls.__module__]
+        # This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
+        if not hasattr(class_module, "__file__"):
+            return False
+        class_file = class_module.__file__
+        with open(class_file, "r", encoding="utf-8") as f:
             code = f.read()
         # heuristic -> if we find those patterns, the model uses the correct interface
         if re.search(r"class \w+Attention\(nn.Module\)", code):
-            return (
-                "eager_attention_forward" in code
-                and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code
-            )
+            return "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS.get_interface(" in code
         else:
             # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
             return True
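The detection heuristic above now greps the modeling file for the string `ALL_ATTENTION_FUNCTIONS.get_interface(` instead of the old subscript pattern. A minimal registry with that accessor shape, just to show the call pattern being searched for (a sketch; not the transformers object, whose exact signature may differ):

```python
class AttentionRegistry(dict):
    def get_interface(self, name):
        # Look up a registered attention implementation by name.
        return self[name]

# Shadowing the real name purely for illustration.
ALL_ATTENTION_FUNCTIONS = AttentionRegistry(eager=lambda *args, **kwargs: None)
attention_fn = ALL_ATTENTION_FUNCTIONS.get_interface("eager")
```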
@@ -1954,13 +1936,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         """Detect whether the class supports setting its experts implementation dynamically. It is an ugly check based on
         opening the file, but avoids maintaining yet another property flag.
         """
+        class_module = sys.modules[cls.__module__]
+        # This can happen for a custom model in a jupyter notebook or repl for example - simply do not allow to set it then
+        if not hasattr(class_module, "__file__"):
+            return False
+        class_file = class_module.__file__
+        with open(class_file, "r", encoding="utf-8") as f:
             code = f.read()
         # heuristic -> if the use_experts_implementation decorator is used, then we can set it
         return "@use_experts_implementation" in code
 
-    def set_attn_implementation(self, attn_implementation:
+    def set_attn_implementation(self, attn_implementation: str | dict):
         """
         Set the requested `attn_implementation` for this model.
 
@@ -2059,7 +2045,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             if hasattr(subconfig, "_attn_was_changed"):
                 del subconfig._attn_was_changed
 
-    def set_experts_implementation(self, experts_implementation:
+    def set_experts_implementation(self, experts_implementation: str | dict):
         """
         Set the requested `experts_implementation` for this model.
 
|
|
|
2162
2148
|
if hasattr(self, "_require_grads_hook"):
|
|
2163
2149
|
del self._require_grads_hook
|
|
2164
2150
|
|
|
2165
|
-
def get_encoder(self, modality:
|
|
2151
|
+
def get_encoder(self, modality: str | None = None):
|
|
2166
2152
|
"""
|
|
2167
2153
|
Best-effort lookup of the *encoder* module. If provided with `modality` argument,
|
|
2168
2154
|
it looks for a modality-specific encoder in multimodal models (e.g. "image_encoder")
|
|
@@ -2194,7 +2180,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
2194
2180
|
# If this is a base transformer model (no encoder/model attributes), return self
|
|
2195
2181
|
return self
|
|
2196
2182
|
|
|
2197
|
-
def set_encoder(self, encoder, modality:
|
|
2183
|
+
def set_encoder(self, encoder, modality: str | None = None):
|
|
2198
2184
|
"""
|
|
2199
2185
|
Symmetric setter. Mirrors the lookup logic used in `get_encoder`.
|
|
2200
2186
|
"""
|
|
@@ -2421,7 +2407,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
2421
2407
|
|
|
2422
2408
|
tied_mapping = self._tied_weights_keys
|
|
2423
2409
|
# If the config does not specify any tying, return empty dict
|
|
2424
|
-
|
|
2410
|
+
# NOTE: not all modules have `tie_word_embeddings` attr, for example vision-only
|
|
2411
|
+
# modules do not have any word embeddings!
|
|
2412
|
+
tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False)
|
|
2413
|
+
if not tie_word_embeddings:
|
|
2425
2414
|
return {}
|
|
2426
2415
|
# If None, return empty dict
|
|
2427
2416
|
elif tied_mapping is None:
|
|
@@ -2467,7 +2456,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
2467
2456
|
|
|
2468
2457
|
return expanded_tied_weights
|
|
2469
2458
|
|
|
2470
|
-
def tie_weights(self, missing_keys:
|
|
2459
|
+
def tie_weights(self, missing_keys: set[str] | None = None, recompute_mapping: bool = True):
|
|
2471
2460
|
"""
|
|
2472
2461
|
Tie the model weights. If `recompute_mapping=False` (default when called internally), it will rely on the
|
|
2473
2462
|
`model.all_tied_weights_keys` attribute, containing the `{target: source}` mapping for the tied params.
|
|
@@ -2559,39 +2548,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
2559
2548
|
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
|
|
2560
2549
|
output_embeddings.out_features = input_embeddings.num_embeddings
|
|
2561
2550
|
|
|
2562
|
-
def _get_no_split_modules(self, device_map: str):
|
|
2563
|
-
"""
|
|
2564
|
-
Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
|
|
2565
|
-
get the underlying `_no_split_modules`.
|
|
2566
|
-
|
|
2567
|
-
Args:
|
|
2568
|
-
device_map (`str`):
|
|
2569
|
-
The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
|
|
2570
|
-
|
|
2571
|
-
Returns:
|
|
2572
|
-
`list[str]`: List of modules that should not be split
|
|
2573
|
-
"""
|
|
2574
|
-
_no_split_modules = set()
|
|
2575
|
-
modules_to_check = [self]
|
|
2576
|
-
while len(modules_to_check) > 0:
|
|
2577
|
-
module = modules_to_check.pop(-1)
|
|
2578
|
-
# if the module does not appear in _no_split_modules, we also check the children
|
|
2579
|
-
if module.__class__.__name__ not in _no_split_modules:
|
|
2580
|
-
if isinstance(module, PreTrainedModel):
|
|
2581
|
-
if module._no_split_modules is None:
|
|
2582
|
-
raise ValueError(
|
|
2583
|
-
f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
|
|
2584
|
-
"class needs to implement the `_no_split_modules` attribute."
|
|
2585
|
-
)
|
|
2586
|
-
else:
|
|
2587
|
-
_no_split_modules = _no_split_modules | set(module._no_split_modules)
|
|
2588
|
-
modules_to_check += list(module.children())
|
|
2589
|
-
return list(_no_split_modules)
|
|
2590
|
-
|
|
2591
2551
|
def resize_token_embeddings(
|
|
2592
2552
|
self,
|
|
2593
|
-
new_num_tokens:
|
|
2594
|
-
pad_to_multiple_of:
|
|
2553
|
+
new_num_tokens: int | None = None,
|
|
2554
|
+
pad_to_multiple_of: int | None = None,
|
|
2595
2555
|
mean_resizing: bool = True,
|
|
2596
2556
|
) -> nn.Embedding:
|
|
2597
2557
|
"""
|
|
@@ -2671,10 +2631,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
2671
2631
|
new_num_tokens = new_embeddings.weight.shape[0]
|
|
2672
2632
|
|
|
2673
2633
|
# if word embeddings are not tied, make sure that lm head is resized as well
|
|
2674
|
-
if (
|
|
2675
|
-
self.get_output_embeddings() is not None
|
|
2676
|
-
and not self.config.get_text_config(decoder=True).tie_word_embeddings
|
|
2677
|
-
):
|
|
2634
|
+
if self.get_output_embeddings() is not None:
|
|
2678
2635
|
old_lm_head = self.get_output_embeddings()
|
|
2679
2636
|
if isinstance(old_lm_head, torch.nn.Embedding):
|
|
2680
2637
|
new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
|
|
@@ -2692,8 +2649,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
2692
2649
|
def _get_resized_embeddings(
|
|
2693
2650
|
self,
|
|
2694
2651
|
old_embeddings: nn.Embedding,
|
|
2695
|
-
new_num_tokens:
|
|
2696
|
-
pad_to_multiple_of:
|
|
2652
|
+
new_num_tokens: int | None = None,
|
|
2653
|
+
pad_to_multiple_of: int | None = None,
|
|
2697
2654
|
mean_resizing: bool = True,
|
|
2698
2655
|
) -> nn.Embedding:
|
|
2699
2656
|
"""
|
|
@@ -2850,7 +2807,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
2850
2807
|
def _get_resized_lm_head(
|
|
2851
2808
|
self,
|
|
2852
2809
|
old_lm_head: nn.Linear,
|
|
2853
|
-
new_num_tokens:
|
|
2810
|
+
new_num_tokens: int | None = None,
|
|
2854
2811
|
transposed: bool = False,
|
|
2855
2812
|
mean_resizing: bool = True,
|
|
2856
2813
|
) -> nn.Linear:
|
|
@@ -3047,7 +3004,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3047
3004
|
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
|
|
3048
3005
|
)
|
|
3049
3006
|
|
|
3050
|
-
def get_position_embeddings(self) ->
|
|
3007
|
+
def get_position_embeddings(self) -> nn.Embedding | tuple[nn.Embedding]:
|
|
3051
3008
|
raise NotImplementedError(
|
|
3052
3009
|
f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
|
|
3053
3010
|
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
|
|
@@ -3055,15 +3012,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3055
3012
|
|
|
3056
3013
|
def init_weights(self):
|
|
3057
3014
|
"""
|
|
3058
|
-
|
|
3015
|
+
Initialize and tie the weights if needed. If using a custom `PreTrainedModel`, you need to implement any
|
|
3059
3016
|
initialization logic in `_init_weights`.
|
|
3060
3017
|
"""
|
|
3061
3018
|
# If we are initializing on meta device, there is no point in trying to run inits
|
|
3062
3019
|
if get_torch_context_manager_or_global_device() != torch.device("meta"):
|
|
3063
3020
|
# Initialize weights
|
|
3064
3021
|
self.initialize_weights()
|
|
3065
|
-
|
|
3066
|
-
|
|
3022
|
+
# Tie weights needs to be called here, but it can use the pre-computed `all_tied_weights_keys`
|
|
3023
|
+
self.tie_weights(recompute_mapping=False)
|
|
3067
3024
|
|
|
3068
3025
|
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
|
3069
3026
|
"""
|
|
@@ -3080,7 +3037,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3080
3037
|
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
|
3081
3038
|
|
|
3082
3039
|
if gradient_checkpointing_kwargs is None:
|
|
3083
|
-
gradient_checkpointing_kwargs = {"use_reentrant":
|
|
3040
|
+
gradient_checkpointing_kwargs = {"use_reentrant": False}
|
|
3084
3041
|
|
|
3085
3042
|
gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
|
|
3086
3043
|
|
|
@@ -3158,13 +3115,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3158
3115
|
|
|
3159
3116
|
def save_pretrained(
|
|
3160
3117
|
self,
|
|
3161
|
-
save_directory:
|
|
3118
|
+
save_directory: str | os.PathLike,
|
|
3162
3119
|
is_main_process: bool = True,
|
|
3163
|
-
state_dict:
|
|
3120
|
+
state_dict: dict | None = None,
|
|
3164
3121
|
push_to_hub: bool = False,
|
|
3165
|
-
max_shard_size:
|
|
3166
|
-
variant:
|
|
3167
|
-
token:
|
|
3122
|
+
max_shard_size: int | str = "50GB",
|
|
3123
|
+
variant: str | None = None,
|
|
3124
|
+
token: str | bool | None = None,
|
|
3168
3125
|
save_peft_format: bool = True,
|
|
3169
3126
|
save_original_format: bool = True,
|
|
3170
3127
|
**kwargs,
|
|
@@ -3231,12 +3188,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3231
3188
|
" the logger on the traceback to understand the reason why the quantized model is not serializable."
|
|
3232
3189
|
)
|
|
3233
3190
|
|
|
3234
|
-
if "save_config" in kwargs:
|
|
3235
|
-
warnings.warn(
|
|
3236
|
-
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
|
|
3237
|
-
)
|
|
3238
|
-
is_main_process = kwargs.pop("save_config")
|
|
3239
|
-
|
|
3240
3191
|
# we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
|
|
3241
3192
|
if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
|
|
3242
3193
|
raise ImportError(
|
|
@@ -3339,16 +3290,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3339
3290
|
if ignore_key in state_dict:
|
|
3340
3291
|
del state_dict[ignore_key]
|
|
3341
3292
|
|
|
3342
|
-
# If model was sharded
|
|
3343
|
-
# therefore we replace them with DTensors that are equivalently sharded
|
|
3293
|
+
# If model was sharded with TP, gather full tensors for saving
|
|
3344
3294
|
if self._tp_size is not None:
|
|
3345
|
-
state_dict =
|
|
3295
|
+
state_dict = gather_state_dict_for_save(state_dict, self._tp_plan, self._device_mesh, self._tp_size)
|
|
3346
3296
|
|
|
3347
3297
|
# Remove tied weights as safetensors do not handle them
|
|
3348
3298
|
state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save)
|
|
3349
3299
|
|
|
3350
3300
|
# Revert all renaming and/or weight operations
|
|
3351
|
-
if save_original_format:
|
|
3301
|
+
if save_original_format and not _hf_peft_config_loaded:
|
|
3352
3302
|
state_dict = revert_weight_conversion(model_to_save, state_dict)
|
|
3353
3303
|
|
|
3354
3304
|
# Shard the model if it is too big.
|
|
@@ -3400,13 +3350,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3400
3350
|
# Get the tensor, and remove it from state_dict to avoid keeping the ref
|
|
3401
3351
|
tensor = state_dict.pop(tensor_name)
|
|
3402
3352
|
|
|
3403
|
-
# In case of TP, get the full parameter back
|
|
3404
|
-
if _is_dtensor_available and isinstance(tensor, DTensor):
|
|
3405
|
-
tensor = tensor.full_tensor()
|
|
3406
|
-
# to get the correctly ordered tensor we need to repack if packed
|
|
3407
|
-
if _get_parameter_tp_plan(tensor_name, self._tp_plan) == "local_packed_rowwise":
|
|
3408
|
-
tensor = repack_weights(tensor, -1, self._tp_size, 2)
|
|
3409
|
-
|
|
3410
3353
|
# If the param was offloaded, we need to load it back from disk to resave it. It's a strange pattern,
|
|
3411
3354
|
# but it would otherwise not be contained in the saved shard if we were to simply move the file
|
|
3412
3355
|
# or something
|
|
@@ -3564,10 +3507,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3564
3507
|
" desired `dtype` by passing the correct `dtype` argument."
|
|
3565
3508
|
)
|
|
3566
3509
|
|
|
3567
|
-
if getattr(self, "is_loaded_in_8bit", False):
|
|
3510
|
+
if getattr(self, "is_loaded_in_8bit", False) and not is_bitsandbytes_available("0.48"):
|
|
3568
3511
|
raise ValueError(
|
|
3569
|
-
"
|
|
3570
|
-
" model has already been set to the correct devices and casted to the correct `dtype`."
|
|
3512
|
+
"You need to install `pip install bitsandbytes>=0.48.0` if you want to move a 8-bit model across devices using to()."
|
|
3571
3513
|
)
|
|
3572
3514
|
elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
|
|
3573
3515
|
if dtype_present_in_args:
|
|
@@ -3600,7 +3542,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3600
3542
|
@classmethod
|
|
3601
3543
|
def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool):
|
|
3602
3544
|
# Need to instantiate with correct dtype
|
|
3603
|
-
init_contexts = [local_torch_dtype(dtype, cls.__name__)]
|
|
3545
|
+
init_contexts = [local_torch_dtype(dtype, cls.__name__), init.no_tie_weights()]
|
|
3604
3546
|
if is_deepspeed_zero3_enabled():
|
|
3605
3547
|
import deepspeed
|
|
3606
3548
|
|
|
@@ -3621,7 +3563,31 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
3621
3563
|
|
|
3622
3564
|
return init_contexts
|
|
3623
3565
|
|
|
3624
|
-
def
|
|
3566
|
+
def _get_dtype_plan(self, dtype: torch.dtype) -> dict:
|
|
3567
|
+
"""Create the dtype_plan describing modules/parameters that should use the `keep_in_fp32` flag."""
|
|
3568
|
+
dtype_plan = {}
|
|
3569
|
+
|
|
3570
|
+
# The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
|
|
3571
|
+
# in case of force loading a model that should stay in bf16 in fp16
|
|
3572
|
+
# See https://github.com/huggingface/transformers/issues/20287 for details.
|
|
3573
|
+
if self._keep_in_fp32_modules is not None and dtype == torch.float16:
|
|
3574
|
+
dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
|
|
3575
|
+
|
|
3576
|
+
# The _keep_in_fp32_modules_strict was introduced to always force upcast to fp32, for both fp16 and bf16
|
|
3577
|
+
if self._keep_in_fp32_modules_strict is not None and dtype in (torch.float16, torch.bfloat16):
|
|
3578
|
+
dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
|
|
3579
|
+
|
|
3580
|
+
return dtype_plan
|
|
3581
|
+
|
|
3582
|
+
def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None):
|
|
3583
|
+
"""
|
|
3584
|
+
Set whether or not to use the `kernels` library to kernelize some layers of the model.
|
|
3585
|
+
Args:
|
|
3586
|
+
use_kernels (`bool`):
|
|
3587
|
+
Whether or not to use the `kernels` library to kernelize some layers of the model.
|
|
3588
|
+
kernel_config (`KernelConfig`, *optional*):
|
|
3589
|
+
The kernel configuration to use to kernelize the model. If `None`, the default kernel mapping will be used.
|
|
3590
|
+
"""
|
|
3625
3591
|
if use_kernels:
|
|
3626
3592
|
if not is_kernels_available():
|
|
3627
3593
|
raise ValueError(
|
|
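The new `_get_dtype_plan` only activates the fp32-upcast lists for lossy target dtypes: the plain flag guards fp16 only, while the strict flag guards both fp16 and bf16. The rule can be restated as a standalone function (same logic, illustrative names):

```python
import torch

def dtype_plan(keep_fp32, keep_fp32_strict, dtype):
    plan = {}
    if keep_fp32 and dtype == torch.float16:  # plain flag: fp16 only
        plan.update(dict.fromkeys(keep_fp32, torch.float32))
    if keep_fp32_strict and dtype in (torch.float16, torch.bfloat16):  # strict: fp16 and bf16
        plan.update(dict.fromkeys(keep_fp32_strict, torch.float32))
    return plan

assert dtype_plan({"vision_tower"}, set(), torch.bfloat16) == {}
assert dtype_plan(set(), {"router"}, torch.bfloat16) == {"router": torch.float32}
```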
@@ -3655,16 +3621,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     @classmethod
     def from_pretrained(
         cls: type[SpecificPreTrainedModelType],
-        pretrained_model_name_or_path:
+        pretrained_model_name_or_path: str | os.PathLike | None,
         *model_args,
-        config:
-        cache_dir:
+        config: PreTrainedConfig | str | os.PathLike | None = None,
+        cache_dir: str | os.PathLike | None = None,
         ignore_mismatched_sizes: bool = False,
         force_download: bool = False,
         local_files_only: bool = False,
-        token:
+        token: str | bool | None = None,
         revision: str = "main",
-        use_safetensors:
+        use_safetensors: bool | None = None,
         weights_only: bool = True,
         **kwargs,
     ) -> SpecificPreTrainedModelType:
@@ -4063,6 +4029,10 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             use_kernels=use_kernels,
         )
 
+        # Create the dtype_plan to potentially use the `keep_in_fp32` flags (this needs to be called on the already
+        # instantiated model, as the flags can be modified by instances sometimes)
+        dtype_plan = model._get_dtype_plan(dtype)
+
         # Obtain the weight conversion mapping for this model if any are registered
         weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
 
@@ -4074,29 +4044,30 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
 
         # Finalize model weight initialization
-            state_dict,
-            checkpoint_files,
-            pretrained_model_name_or_path,
+        load_config = LoadStateDictConfig(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
             ignore_mismatched_sizes=ignore_mismatched_sizes,
             sharded_metadata=sharded_metadata,
             device_map=device_map,
             disk_offload_folder=offload_folder,
             offload_buffers=offload_buffers,
             dtype=dtype,
+            dtype_plan=dtype_plan,
             hf_quantizer=hf_quantizer,
             device_mesh=device_mesh,
             weights_only=weights_only,
             weight_mapping=weight_conversions,
+            use_safetensors=use_safetensors,
+            download_kwargs=download_kwargs,
         )
+        loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
+        loading_info = cls._finalize_model_loading(model, load_config, loading_info)
         model.eval()  # Set model in evaluation mode to deactivate Dropout modules by default
         model.set_use_kernels(use_kernels, kernel_config)
 
         # If it is a model with generation capabilities, attempt to load generation files (generation config,
         # custom generate function)
-        if model.can_generate() and hasattr(model, "adjust_generation_fn"):
+        if model.can_generate() and hasattr(model, "adjust_generation_fn") and not gguf_file:
             model.adjust_generation_fn(
                 generation_config,
                 from_auto_class,
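`from_pretrained` now bundles its loading arguments into a single `LoadStateDictConfig` object handed down to `_load_pretrained_model`. The class definition itself is not part of this diff; a dataclass sketch with the field set inferred from the call site above (assumed, not the library definition):

```python
from __future__ import annotations
from dataclasses import dataclass
from typing import Any

@dataclass
class LoadStateDictConfigSketch:
    pretrained_model_name_or_path: str | None = None
    ignore_mismatched_sizes: bool = False
    sharded_metadata: dict | None = None
    device_map: dict | None = None
    dtype: Any = None
    dtype_plan: dict | None = None
    hf_quantizer: Any = None
    weights_only: bool = True

    @property
    def is_quantized(self) -> bool:
        # `_load_pretrained_model` reads `load_config.is_quantized` below.
        return self.hf_quantizer is not None
```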
@@ -4109,7 +4080,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
 
         # If the device_map has more than 1 device: dispatch model with hooks on all devices
         if device_map is not None and len(set(device_map.values())) > 1:
-            accelerate_dispatch(model, hf_quantizer, device_map, offload_folder,
+            accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, disk_offload_index, offload_buffers)
 
         if hf_quantizer is not None:
             model.hf_quantizer = hf_quantizer
@@ -4118,44 +4089,29 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         )  # usually a no-op but sometimes needed, e.g to remove the quant config when dequantizing
 
         if _adapter_model_path is not None:
+            if token is not None:
+                adapter_kwargs["token"] = token
+            loading_info = model.load_adapter(
                 _adapter_model_path,
                 adapter_name=adapter_name,
+                load_config=load_config,
                 adapter_kwargs=adapter_kwargs,
             )
 
         if output_loading_info:
-                "missing_keys": missing_keys,
-                "unexpected_keys": unexpected_keys,
-                "mismatched_keys": mismatched_keys,
-                "error_msgs": error_msgs,
-            }
-            return model, loading_info
+            return model, loading_info.to_dict()
         return model
 
-    @
+    @staticmethod
     def _load_pretrained_model(
-        cls,
         model: "PreTrainedModel",
-        state_dict:
-        checkpoint_files:
-        offload_buffers: bool = False,
-        dtype: Optional[torch.dtype] = None,
-        hf_quantizer: Optional[HfQuantizer] = None,
-        device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
-        weights_only: bool = True,
-        weight_mapping: Optional[Sequence[WeightConverter | WeightRenaming]] = None,
-    ):
-        is_quantized = hf_quantizer is not None
-        is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
+        state_dict: dict | None,
+        checkpoint_files: list[str] | None,
+        load_config: LoadStateDictConfig,
+    ) -> tuple[LoadStateDictInfo, dict]:
+        """Perform the actual loading of some checkpoints into a `model`, by reading them from disk and dispatching them accordingly."""
+        is_quantized = load_config.is_quantized
+        is_hqq_or_quark = is_quantized and load_config.hf_quantizer.quantization_config.quant_method in {
             QuantizationMethod.HQQ,
             QuantizationMethod.QUARK,
         }
@@ -4169,21 +4125,21 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         # This offload index is for params explicitly on the "disk" in the device_map
         disk_offload_index = None
         # Prepare parameters offloading if needed
-        if device_map is not None and "disk" in device_map.values():
+        if load_config.device_map is not None and "disk" in load_config.device_map.values():
             disk_offload_index = accelerate_disk_offload(
                 model,
-                disk_offload_folder,
+                load_config.disk_offload_folder,
                 checkpoint_files,
-                device_map,
-                sharded_metadata,
-                dtype,
-                weight_mapping,
+                load_config.device_map,
+                load_config.sharded_metadata,
+                load_config.dtype,
+                load_config.weight_mapping,
             )
 
         # Warmup cuda to load the weights much faster on devices
-        if device_map is not None and not is_hqq_or_quark:
-            expanded_device_map = expand_device_map(device_map, expected_keys)
-            caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
+        if load_config.device_map is not None and not is_hqq_or_quark:
+            expanded_device_map = expand_device_map(load_config.device_map, expected_keys)
+            caching_allocator_warmup(model, expanded_device_map, load_config.hf_quantizer)
 
         error_msgs = []
 
@@ -4191,24 +4147,30 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
             if state_dict is None:
                 merged_state_dict = {}
                 for ckpt_file in checkpoint_files:
-                    merged_state_dict.update(
+                    merged_state_dict.update(
+                        load_state_dict(ckpt_file, map_location="cpu", weights_only=load_config.weights_only)
+                    )
                 state_dict = merged_state_dict
-            error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict)
+            error_msgs, missing_keys = _load_state_dict_into_zero3_model(model, state_dict, load_config)
             # This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
+            loading_info = LoadStateDictInfo(
+                missing_keys=missing_keys,
+                error_msgs=error_msgs,
+                unexpected_keys=set(),
+                mismatched_keys=set(),
+                conversion_errors={},
+            )
         else:
             all_pointer = set()
+            if state_dict is not None:
+                merged_state_dict = state_dict
+            elif checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") and state_dict is None:
                 merged_state_dict = {}
                 for file in checkpoint_files:
                     file_pointer = safe_open(file, framework="pt", device="cpu")
                     all_pointer.add(file_pointer)
                     for k in file_pointer.keys():
                         merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
-            # User passed an explicit state_dict
-            elif state_dict is not None:
-                merged_state_dict = state_dict
             # Checkpoints are .bin
             elif checkpoint_files is not None:
                 merged_state_dict = {}
@@ -4217,58 +4179,58 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|
|
4217
4179
|
else:
|
|
4218
4180
|
raise ValueError("Neither a state dict nor checkpoint files were found.")
|
|
4219
4181
|
|
|
4220
|
-
|
|
4221
|
-
|
|
4222
|
-
|
|
4223
|
-
|
|
4224
|
-
|
|
4225
|
-
|
|
4226
|
-
hf_quantizer=hf_quantizer,
|
|
4227
|
-
dtype=dtype,
|
|
4228
|
-
device_map=device_map,
|
|
4229
|
-
-                dtype_plan=model.dtype_plan,
-                device_mesh=device_mesh,
-                disk_offload_index=disk_offload_index,
-                disk_offload_folder=disk_offload_folder,
-                offload_buffers=offload_buffers,
-            )
+            loading_info, disk_offload_index = convert_and_load_state_dict_in_model(
+                model=model,
+                state_dict=merged_state_dict,
+                load_config=load_config,
+                tp_plan=model._tp_plan,
+                disk_offload_index=disk_offload_index,
             )

         # finally close all opened file pointers
         for k in all_pointer:
             k.__exit__(None, None, None)

-
-        model.mark_tied_weights_as_initialized()
-
-        # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
-        # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
-        missing_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
-        model._move_missing_keys_from_meta_to_device(missing_and_mismatched, device_map, device_mesh, hf_quantizer)
+        return loading_info, disk_offload_index

-
-
-
-
-        model
+    @staticmethod
+    def _finalize_model_loading(
+        model, load_config: LoadStateDictConfig, loading_info: LoadStateDictInfo
+    ) -> LoadStateDictInfo:
+        """Perform all post processing operations after having loaded some checkpoints into a model, such as moving
+        missing keys from meta device to their expected device, reinitializing missing weights according to proper
+        distributions, tying the weights and logging the loading report."""
+        try:
+            # Marks tied weights as `_is_hf_initialized` to avoid initializing them (it's very important for efficiency)
+            model.mark_tied_weights_as_initialized()
+
+            # Move missing (and potentially mismatched) keys and non-persistent buffers back to their expected device from
+            # meta device (because they were not moved when loading the weights as they were not in the loaded state dict)
+            model._move_missing_keys_from_meta_to_device(
+                loading_info.missing_and_mismatched(),
+                load_config.device_map,
+                load_config.device_mesh,
+                load_config.hf_quantizer,
+            )

-
-
+            # Correctly initialize the missing (and potentially mismatched) keys (all parameters without the `_is_hf_initialized` flag)
+            model._initialize_missing_keys(load_config.is_quantized)

-
-            model=
-
-
-
-
-
-
-
-
-
+            # Tie the weights
+            model.tie_weights(missing_keys=loading_info.missing_keys, recompute_mapping=False)
+
+            # Adjust missing and unexpected keys
+            model._adjust_missing_and_unexpected_keys(loading_info)
+        finally:
+            log_state_dict_report(
+                model=model,
+                pretrained_model_name_or_path=load_config.pretrained_model_name_or_path,
+                ignore_mismatched_sizes=load_config.ignore_mismatched_sizes,
+                loading_info=loading_info,
+                logger=logger,
+            )

-        return
+        return loading_info

     def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
         module_keys = {".".join(key.split(".")[:-1]) for key in names}
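The hunk above replaces the old inline post-processing with a two-phase flow: `convert_and_load_state_dict_in_model` now returns a `loading_info` object alongside the disk-offload index, and the new static `_finalize_model_loading` consumes it. As a reading aid, here is a minimal sketch of the info object the diff implies; only the members actually referenced above are modeled, and the exact field types are assumptions rather than the library's definition:

```python
from dataclasses import dataclass, field


@dataclass
class LoadStateDictInfoSketch:
    # Only the attributes visible in the diff are modeled; the real
    # LoadStateDictInfo in transformers may carry more state.
    missing_keys: set[str] = field(default_factory=set)
    unexpected_keys: set[str] = field(default_factory=set)
    # Assumed element shape: (key, checkpoint_shape, model_shape), matching
    # the old inline expression `{k[0] for k in mismatched_keys}`.
    mismatched_keys: set[tuple] = field(default_factory=set)

    def missing_and_mismatched(self) -> set[str]:
        # Mirrors the removed line:
        # missing_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
        return self.missing_keys | {k[0] for k in self.mismatched_keys}
```

The `missing_and_mismatched()` helper absorbs the removed local variable, which is why the old inline set union disappears from the hunk.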
@@ -4337,15 +4299,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
         # attention_mask or not. In this case, we should still show a warning because this is a rare case.
+        # NOTE: `sep_token_id` is not used in all models and it can be absent in the config
+        sep_token_id = getattr(self.config, "sep_token_id", None)
         if (
             (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
             or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
-            or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
+            or (sep_token_id is not None and sep_token_id == self.config.pad_token_id)
         ):
             warn_string += (
                 f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
                 f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
-                f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
+                f"or the `sep_token_id` ({sep_token_id}), and your input is not padded."
             )

         logger.warning_once(warn_string)
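The change itself is a defensive attribute read: not every config defines `sep_token_id`, so the check now goes through `getattr` with a `None` default instead of risking an `AttributeError`. A self-contained illustration of the pattern, using a hypothetical config class:

```python
class TinyConfig:  # hypothetical stand-in for a model config without sep_token_id
    pad_token_id = 0
    bos_token_id = 0
    eos_token_id = 2


cfg = TinyConfig()
sep_token_id = getattr(cfg, "sep_token_id", None)  # no AttributeError, just None
pad_collides = any(
    tok is not None and tok == cfg.pad_token_id
    for tok in (cfg.bos_token_id, cfg.eos_token_id, sep_token_id)
)
print(pad_collides)  # True here, because bos_token_id == pad_token_id
```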
@@ -4430,7 +4394,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         )
         self._use_kernels = False

-    def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
+    def get_compiled_call(self, compile_config: CompileConfig | None) -> Callable:
         """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
         non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
         want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
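For context, `get_compiled_call` exists so generation can mix an eager prefill (whose shapes change with every prompt) with a compiled decode step (whose shapes are stable). A hedged usage sketch; the decoding loop is simplified and not the exact generation internals:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

# Compiled variant of __call__; the eager model(...) call stays available.
compiled_call = model.get_compiled_call(CompileConfig(fullgraph=True))

inputs = tokenizer("Hello", return_tensors="pt")
out = model(**inputs)  # prefill stays eager: prompt shapes differ per request
# Each one-token decode step has stable shapes, so the compiled call avoids
# graph recompilation (sketched here without KV-cache plumbing):
next_token = out.logits[:, -1:].argmax(-1)
out = compiled_call(input_ids=next_token)
```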
@@ -4522,11 +4486,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         else:
             self.initialize_weights()

-    def _adjust_missing_and_unexpected_keys(
-        self, missing_keys: set[str], unexpected_keys: set[str]
-    ) -> tuple[set[str], set[str]]:
+    def _adjust_missing_and_unexpected_keys(self, loading_info: LoadStateDictInfo) -> None:
         """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
-        raising unneeded warnings/errors.
+        raising unneeded warnings/errors. This is performed in-place.
         """
         # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
         # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
@@ -4544,13 +4506,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

         # Clean-up missing keys
         if ignore_missing_regex is not None:
-            missing_keys = {key for key in missing_keys if ignore_missing_regex.search(key) is None}
+            loading_info.missing_keys = {
+                key for key in loading_info.missing_keys if ignore_missing_regex.search(key) is None
+            }

         # Clean-up unexpected keys
         if ignore_unexpected_regex is not None:
-            unexpected_keys = {
-                key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None
-            }
+            loading_info.unexpected_keys = {
+                key for key in loading_info.unexpected_keys if ignore_unexpected_regex.search(key) is None
+            }

     def mark_tied_weights_as_initialized(self):
         """Adds the `_is_hf_initialized` flag on parameters that will be tied, in order to avoid initializing them
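With `loading_info` now mutated in place, the method no longer returns the filtered sets, hence the `-> None` annotation above. The filtering itself is ordinary regex exclusion; a standalone sketch with a hypothetical ignore rule in the spirit of the `rotary_emb.inv_freq` comment:

```python
import re

# Hypothetical exception rule: suppress missing-key warnings for rotary
# embedding buffers that moved between checkpoint layouts.
ignore_missing_regex = re.compile(r"rotary_emb\.inv_freq")

missing_keys = {"model.layers.0.rotary_emb.inv_freq", "lm_head.weight"}
missing_keys = {k for k in missing_keys if ignore_missing_regex.search(k) is None}
print(missing_keys)  # {'lm_head.weight'}
```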
@@ -4640,7 +4604,7 @@ def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
     return model


-def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
+def is_accelerator_device(device: str | int | torch.device) -> bool:
     """Check if the device is an accelerator. We need this function, as device_map can be "disk" as well, which is not
     a proper `torch.device`.
     """
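The signature merely moves from `Union[...]` to PEP 604 syntax, but the docstring points at the real subtlety: `device_map` values may be the string "disk", which `torch.device` cannot parse. A plausible reading of such a check, as a sketch rather than the library's exact body:

```python
import torch


def is_accelerator_device_sketch(device) -> bool:
    # "disk" is a valid device_map entry but not a real torch.device, so it
    # must be handled before torch.device() ever sees it; "cpu" is a real
    # device but not an accelerator.
    if isinstance(device, str) and device in ("cpu", "disk"):
        return False
    if isinstance(device, int):
        # bare indices refer to the current accelerator backend
        return True
    return torch.device(device).type != "cpu"


print(is_accelerator_device_sketch("disk"))  # False
print(is_accelerator_device_sketch(0))       # True
```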
@@ -4651,7 +4615,7 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:


 def get_total_byte_count(
-    model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: Optional[HfQuantizer] = None
+    model: PreTrainedModel, accelerator_device_map: dict, hf_quantizer: HfQuantizer | None = None
 ):
     """
     This utility function calculates the total bytes count needed to load the model on each device.
@@ -4684,7 +4648,7 @@ def get_total_byte_count(
     return total_byte_count


-def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: Optional[HfQuantizer]):
+def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: HfQuantizer | None):
     """This function warms up the caching allocator based on the size of the model tensors that will reside on each
     device. It allows one large call to malloc, instead of repeated calls later when loading
     the model, which is the actual loading speed bottleneck.
@@ -4732,7 +4696,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
             ) - torch_accelerator_module.memory_allocated(index)
             byte_count = int(max(0, byte_count - unused_memory))
         # We divide by 2 here as we allocate in fp16
-        _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
+        _ = torch.empty(int(byte_count // 2), dtype=torch.float16, device=device, requires_grad=False)


 class AttentionInterface(GeneralInterface):
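The change only wraps the element count in `int(...)` for safety; the `// 2` itself follows from fp16 being two bytes per element, so `byte_count // 2` fp16 elements reserve roughly `byte_count` bytes. A quick check of that arithmetic (on CPU here; the real warmup targets the accelerator device):

```python
import torch

byte_count = 1_000_000  # hypothetical per-device budget in bytes
warmup = torch.empty(int(byte_count // 2), dtype=torch.float16, device="cpu", requires_grad=False)
# 2 bytes per fp16 element, so the allocation matches the byte budget:
assert warmup.numel() * warmup.element_size() == byte_count
```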
@@ -4755,6 +4719,20 @@ class AttentionInterface(GeneralInterface):
         "paged|eager": eager_paged_attention_forward,
     }

+    def get_interface(self, attn_implementation: str, default: Callable) -> Callable:
+        """Return the requested `attn_implementation`. Also strictly check its validity, and raise if invalid."""
+        if attn_implementation is None:
+            logger.warning_once(
+                "You tried to access the `AttentionInterface` with a `config._attn_implementation` set to `None`. This "
+                "is expected if you use an Attention Module as a standalone Module. If this is not the case, something went "
+                "wrong with the dispatch of `config._attn_implementation`"
+            )
+        elif attn_implementation != "eager" and attn_implementation not in self:
+            raise KeyError(
+                f"`{attn_implementation}` is not a valid attention implementation registered in the `AttentionInterface`"
+            )
+        return super().get(attn_implementation, default)
+

 # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
 ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
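The new `get_interface` turns a silent fallback into a strict lookup: an unknown implementation name now raises instead of quietly resolving to the default. Registration itself is unchanged; a short sketch where the custom kernel body is a placeholder:

```python
from transformers import AttentionInterface
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def my_attention_forward(module, query, key, value, attention_mask, **kwargs):
    ...  # placeholder for a custom attention kernel


AttentionInterface.register("my_attention", my_attention_forward)

fn = ALL_ATTENTION_FUNCTIONS.get_interface("my_attention", default=None)
# A typo'd name now fails loudly instead of silently using the default:
# ALL_ATTENTION_FUNCTIONS.get_interface("my_atention", default=None)  -> KeyError
```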