transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/core_model_loading.py

@@ -16,22 +16,24 @@
 
 from __future__ import annotations
 
+import math
 import os
 import re
 from abc import abstractmethod
 from collections import defaultdict
-from collections.abc import MutableMapping, MutableSet
+from collections.abc import Callable, MutableMapping, MutableSet
 from concurrent.futures import Future, ThreadPoolExecutor
 from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
+from itertools import chain
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 
-from .integrations.accelerate import offload_weight
+from .integrations.accelerate import get_device, offload_weight
 from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
-from .utils import is_torch_greater_or_equal, logging
+from .utils import is_env_variable_true, is_torch_greater_or_equal, logging
 
 
 _torch_distributed_available = torch.distributed.is_available()
@@ -278,6 +280,166 @@ class PermuteForRope(ConversionOps):
|
|
|
278
280
|
return output
|
|
279
281
|
|
|
280
282
|
|
|
283
|
+
class ErnieFuseAndSplitTextVisionExperts(ConversionOps):
|
|
284
|
+
r"""
|
|
285
|
+
Special operation that splits a module list over all keys and fuses over the number of original modules.
|
|
286
|
+
|
|
287
|
+
Example with 2 original modules "Gate" and "Up" with 2 target keys "Text" and "Vision":
|
|
288
|
+
|
|
289
|
+
ModuleList 1 ModuleList 2
|
|
290
|
+
[ Gate ] [ Up ]
|
|
291
|
+
| | | |
|
|
292
|
+
[Gate_Text] [Gate_Vision] [Up_Text] [Up_Vision]
|
|
293
|
+
\ \ / /
|
|
294
|
+
\ \ / /
|
|
295
|
+
\ / \ /
|
|
296
|
+
\ / \ /
|
|
297
|
+
[GateUp_Text] [GateUp_Vision]
|
|
298
|
+
|
|
299
|
+
The splits are equal and are defined by the amount of target keys.
|
|
300
|
+
The final fusions are defined by the amount of original module lists.
|
|
301
|
+
"""
|
|
302
|
+
|
|
303
|
+
def __init__(self, stack_dim: int = 0, concat_dim: int = 1):
|
|
304
|
+
self.stack_dim = stack_dim
|
|
305
|
+
self.concat_dim = concat_dim
|
|
306
|
+
|
|
307
|
+
def split_list_into_chunks(self, tensor_list: list[torch.Tensor], chunks: int = 2):
|
|
308
|
+
split_size = math.ceil(len(tensor_list) / chunks) # best effort split size
|
|
309
|
+
return [tensor_list[i * split_size : (i + 1) * split_size] for i in range(chunks)]
|
|
310
|
+
|
|
311
|
+
@torch.no_grad()
|
|
312
|
+
def convert(
|
|
313
|
+
self,
|
|
314
|
+
input_dict: dict[str, list[torch.Tensor]],
|
|
315
|
+
source_patterns: list[str],
|
|
316
|
+
target_patterns: list[str],
|
|
317
|
+
config,
|
|
318
|
+
**kwargs,
|
|
319
|
+
) -> dict[str, list[torch.Tensor]]:
|
|
320
|
+
valid_keys = input_dict.keys()
|
|
321
|
+
split_and_fused = defaultdict(list)
|
|
322
|
+
for key in source_patterns:
|
|
323
|
+
if key not in valid_keys:
|
|
324
|
+
raise ValueError(
|
|
325
|
+
f"Expected pattern {key} in collected tensors but only found tensors for: {valid_keys}"
|
|
326
|
+
)
|
|
327
|
+
|
|
328
|
+
tensors = input_dict.get(key, [])
|
|
329
|
+
split_tensor_lists = self.split_list_into_chunks(tensors, chunks=len(target_patterns))
|
|
330
|
+
stacked_tensors = (torch.stack(tensor_group, dim=self.stack_dim) for tensor_group in split_tensor_lists)
|
|
331
|
+
for idx, tensor_group in enumerate(stacked_tensors):
|
|
332
|
+
split_and_fused[target_patterns[idx]].append(tensor_group)
|
|
333
|
+
|
|
334
|
+
for k, v in split_and_fused.items():
|
|
335
|
+
split_and_fused[k] = torch.cat(v, dim=self.concat_dim)
|
|
336
|
+
|
|
337
|
+
return split_and_fused
|
|
338
|
+
|
|
339
|
+
@property
|
|
340
|
+
def reverse_op(self) -> ConversionOps:
|
|
341
|
+
return ErnieSplitAndDecoupleTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
class ErnieSplitAndDecoupleTextVisionExperts(ConversionOps):
|
|
345
|
+
r"""
|
|
346
|
+
Special operation that splits a fused module list over all original modules and
|
|
347
|
+
then decouples them into a mixed module list each over all keys.
|
|
348
|
+
|
|
349
|
+
Example with 2 original modules "Gate" and "Up" with 2 target keys "Text" and "Vision":
|
|
350
|
+
|
|
351
|
+
[GateUp_Text] [GateUp_Vision]
|
|
352
|
+
/ \ / \
|
|
353
|
+
/ \ / \
|
|
354
|
+
/ / \ \
|
|
355
|
+
/ / \ \
|
|
356
|
+
[Gate_Text] [Gate_Vision] [Up_Text] [Up_Vision]
|
|
357
|
+
| | | |
|
|
358
|
+
[ Gate ] [ Up ]
|
|
359
|
+
ModuleList 1 ModuleList 2
|
|
360
|
+
|
|
361
|
+
The splits are equal and are defined by the amount of original module lists.
|
|
362
|
+
The final decoupled module lists are defined by the amount of keys.
|
|
363
|
+
"""
|
|
364
|
+
|
|
365
|
+
def __init__(self, stack_dim: int = 0, concat_dim: int = 1):
|
|
366
|
+
self.stack_dim = stack_dim
|
|
367
|
+
self.concat_dim = concat_dim
|
|
368
|
+
|
|
369
|
+
@torch.no_grad()
|
|
370
|
+
def convert(
|
|
371
|
+
self,
|
|
372
|
+
input_dict: dict[str, list[torch.Tensor]],
|
|
373
|
+
source_patterns: list[str],
|
|
374
|
+
target_patterns: list[str],
|
|
375
|
+
config,
|
|
376
|
+
**kwargs,
|
|
377
|
+
) -> dict[str, list[torch.Tensor]]:
|
|
378
|
+
fused_modules = len(target_patterns)
|
|
379
|
+
valid_keys = input_dict.keys()
|
|
380
|
+
split_tensors = []
|
|
381
|
+
for key in source_patterns:
|
|
382
|
+
if key not in valid_keys:
|
|
383
|
+
raise ValueError(
|
|
384
|
+
f"Expected pattern {key} in collected tensors but only found tensors for: {valid_keys}"
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
# Assuming that we get single sized lists here to index with 0
|
|
388
|
+
split_tensors.append(input_dict[key][0].chunk(fused_modules, dim=self.concat_dim))
|
|
389
|
+
|
|
390
|
+
decoupled = {}
|
|
391
|
+
for idx, key in enumerate(target_patterns):
|
|
392
|
+
tensor_groups = [
|
|
393
|
+
list(torch.unbind(tensor_group[idx], dim=self.stack_dim)) for tensor_group in split_tensors
|
|
394
|
+
]
|
|
395
|
+
tensor_list = list(chain.from_iterable(tensor_groups))
|
|
396
|
+
targets = [key.replace("*", f"{i}") for i in range(len(tensor_list))]
|
|
397
|
+
decoupled |= dict(zip(targets, tensor_list))
|
|
398
|
+
|
|
399
|
+
return decoupled
|
|
400
|
+
|
|
401
|
+
@property
|
|
402
|
+
def reverse_op(self) -> ConversionOps:
|
|
403
|
+
return ErnieFuseAndSplitTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim)
|
|
404
|
+
|
|
405
|
+
|
|
+class Transpose(ConversionOps):
+    """
+    Transposes the given tensor along dim0 and dim1.
+    """
+
+    def __init__(self, dim0: int = 0, dim1: int = 1):
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    @torch.no_grad()
+    def convert(
+        self,
+        input_dict: dict[str, list[torch.Tensor]],
+        source_patterns: list[str],
+        target_patterns: list[str],
+        config,
+        **kwargs,
+    ) -> dict[str, list[torch.Tensor]]:
+        if len(input_dict) != len(target_patterns):
+            raise ValueError(
+                f"Transpose conversion can only happen on each key ({len(input_dict)}) "
+                f"and should match exactly one target ({len(target_patterns)})."
+            )
+
+        output: dict[str, list[torch.Tensor]] = {}
+        for key, target_pattern in zip(input_dict.keys(), target_patterns):
+            tensor = input_dict.get(key, [])
+            if len(tensor) != 1:
+                raise ValueError(f"Transpose conversion requires exactly one tensor, found {len(tensor)}.")
+            output[target_pattern] = torch.transpose(tensor[0], dim0=self.dim0, dim1=self.dim1).contiguous()
+        return output
+
+    @property
+    def reverse_op(self) -> ConversionOps:
+        return Transpose(dim0=self.dim1, dim1=self.dim0)
+
+
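Note: swapping `dim0` and `dim1` describes the same permutation, so `Transpose` is effectively its own inverse and the op followed by its `reverse_op` restores the original tensor:

    import torch

    w = torch.randn(3, 5)
    assert torch.equal(torch.transpose(torch.transpose(w, 0, 1), 1, 0), w)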
 @dataclass(slots=True)
 class WeightTransform:
     source_patterns: Union[str, list[str]] = field(init=True)
@@ -302,8 +464,11 @@ class WeightTransform:
         for i, pattern in enumerate(self.target_patterns):
             # Some mapping contains `^` to notify start of string when matching -> remove it during reverse mapping
             pattern = pattern.removeprefix("^")
-            #
-            pattern =
+            # Some mapping contains `$` to notify end of string when matching -> remove it during reverse mapping
+            pattern = pattern.removesuffix("$")
+            # Remove negative lookahead/behind if any. This is ugly but needed for reverse mapping of
+            # Qwen2.5, Sam3, Ernie4.5 VL MoE!
+            pattern = re.sub(r"\(\?.+\)", "", pattern)
             # Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
             if r"(.+)" in pattern:
                 pattern = pattern.replace(r"(.+)", r"\1")
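Note: a sketch of what this cleanup does to a hypothetical source pattern. The `\(\?.+\)` substitution is greedy, so it is only safe while a pattern carries a single lookahead/lookbehind group:

    import re

    pattern = r"^model\.layers\.(?!.*vision).*\.mlp\.$"  # hypothetical pattern
    pattern = pattern.removeprefix("^").removesuffix("$")
    pattern = re.sub(r"\(\?.+\)", "", pattern)  # drops the negative lookahead
    print(pattern)  # model\.layers\..*\.mlp\.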
@@ -327,10 +492,6 @@ class WeightTransform:
         self.collected_tensors[source_pattern].append(future)
         self.layer_targets[target_key].add(source_key)

-    def reset(self) -> None:
-        """Clean-up the collected tensors to make sure we don't keep references to past tensors in memory."""
-        self.collected_tensors = defaultdict(list)
-
     def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
         """
         Return a tuple (renamed_key, source_pattern_producing_the_match).
@@ -342,19 +503,19 @@ class WeightTransform:
         match_object = self.compiled_sources.search(source_key)
         if match_object is None:
             return source_key, None
+
         # Find the source that produced the match (it's the first group that matched, as the search stops after first branch match)
         matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None)
         source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
         # If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
         replacement = self.target_patterns[0]
-        #
+        # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
         if r"\1" in replacement:
             # The index of the internal group we need to replace is the index of the matched named group as it comes
             # inside that matched named group
             replaced_group_idx = self.compiled_sources.groupindex[matching_group_name] + 1
             replacement = replacement.replace(r"\1", match_object.group(replaced_group_idx))
         renamed_key = source_key.replace(match_object.group(0), replacement)
-
         return renamed_key, source_pattern_that_matched

     def reverse_transform(self) -> WeightTransform:
@@ -375,6 +536,32 @@ class WeightTransform:

         return reverse_transform

+    def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
+        """
+        Materialize all the tensors that were saved in `self.collected_tensors`. This function removes them from the
+        internal attribute to avoid keeping them in memory during the different `self.convert` operations, and returns
+        a new dictionary (otherwise we use more memory than needed during loading).
+
+        We basically have 3 cases here:
+        - async loading (default): the tensors are Future instances that we need to wait for
+        - sync loading: the tensors are Callable, we need to call the Callable to actually load them from disk
+        - saving: the tensors are already torch.Tensor instances (the existing model weights)
+        """
+        collected_tensors = {}
+        for key in set(self.collected_tensors.keys()):
+            # Remove from internal attribute
+            tensors = self.collected_tensors.pop(key)
+            # Async loading
+            if isinstance(tensors[0], Future):
+                tensors = [future.result() for future in tensors]
+            # Sync loading
+            elif callable(tensors[0]):
+                tensors = [func() for func in tensors]
+            # Add them to the new dictionary
+            collected_tensors[key] = tensors
+
+        return collected_tensors
+
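Note: the three cases the `materialize_tensors` docstring lists can be sketched standalone (standard library plus torch only; the branching mirrors the method above):

    import torch
    from concurrent.futures import Future, ThreadPoolExecutor

    pool = ThreadPoolExecutor(max_workers=1)
    async_case = [pool.submit(torch.ones, 2)]  # Future -> wait on .result()
    sync_case = [lambda: torch.ones(2)]        # Callable -> call it to load
    saved_case = [torch.ones(2)]               # already a Tensor -> kept as-is

    for tensors in (async_case, sync_case, saved_case):
        if isinstance(tensors[0], Future):
            tensors = [future.result() for future in tensors]
        elif callable(tensors[0]):
            tensors = [func() for func in tensors]
        assert isinstance(tensors[0], torch.Tensor)
    pool.shutdown()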

 @dataclass(slots=True)
 class WeightRenaming(WeightTransform):
@@ -387,21 +574,21 @@ class WeightRenaming(WeightTransform):
         config=None,
         hf_quantizer=None,
         missing_keys: Optional[MutableSet[str]] = None,
-
+        conversion_errors: Optional[MutableMapping[str, str]] = None,
     ):
-        # Collect the
-
-
-            futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
-        )
+        # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
+        # attribute during the whole process
+        collected_tensors = self.materialize_tensors()

         # Perform renaming op (for a simple WeightRenaming, `self.source_patterns` and `self.target_patterns` can
         # only be of length 1, and are actually the full key names - we also have only 1 single related tensor)
         target_key = self.target_patterns[0]
-        collected_tensors = {target_key:
+        collected_tensors = {target_key: collected_tensors[self.source_patterns[0]]}

         if hf_quantizer is not None and self.quantization_operation is not None:
-            with
+            with log_conversion_errors(
+                layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
+            ):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -412,7 +599,14 @@ class WeightRenaming(WeightTransform):
                     missing_keys=missing_keys,
                 )

-        return collected_tensors,
+        return collected_tensors, conversion_errors
+
+
+# List of classes that are known to be able to use m:n
+_INTERNAL_MANY_TO_MANY_CONVERSIONS = (
+    ErnieFuseAndSplitTextVisionExperts,
+    ErnieSplitAndDecoupleTextVisionExperts,
+)


 @dataclass(slots=True)
@@ -422,9 +616,11 @@ class WeightConverter(WeightTransform):
     def __post_init__(self):
         WeightTransform.__post_init__(self)
         if bool(len(self.source_patterns) - 1) + bool(len(self.target_patterns) - 1) >= 2:
-
-
-
+            # We allow many-to-many only if we use an internal operation that can handle it
+            if not any(isinstance(op, _INTERNAL_MANY_TO_MANY_CONVERSIONS) for op in self.operations):
+                raise ValueError(
+                    f"source keys={self.source_patterns}, target_patterns={self.target_patterns} but you can only have one to many, one to one or many to one."
+                )
         if not self.operations:
             raise ValueError("WeightConverter requires at least one operation.")

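Note: spelled out, the terse cardinality check `bool(len(sources) - 1) + bool(len(targets) - 1) >= 2` is true exactly when both sides have more than one pattern, i.e. for a many-to-many transform:

    def is_many_to_many(n_sources: int, n_targets: int) -> bool:
        return bool(n_sources - 1) + bool(n_targets - 1) >= 2

    assert not is_many_to_many(1, 3)  # one-to-many (e.g. chunking qkv): allowed
    assert not is_many_to_many(3, 1)  # many-to-one (e.g. fusing experts): allowed
    assert is_many_to_many(2, 2)      # many-to-many: only via the internal Ernie ops above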
@@ -435,17 +631,14 @@ class WeightConverter(WeightTransform):
         config=None,
         hf_quantizer=None,
         missing_keys: Optional[MutableSet[str]] = None,
-
+        conversion_errors: Optional[MutableMapping[str, str]] = None,
     ):
-        # Collect
-
-
-            futures if isinstance(futures[0], torch.Tensor) else [future.result() for future in futures]
-        )
+        # Collect the tensors here - we use a new dictionary to avoid keeping them in memory in the internal
+        # attribute during the whole process
+        collected_tensors = self.materialize_tensors()

-        collected_tensors = self.collected_tensors
         for op in self.operations:
-            with
+            with log_conversion_errors(layer_name, conversion_errors, (len(collected_tensors), layer_name), op):
                 collected_tensors = op.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -462,11 +655,19 @@ class WeightConverter(WeightTransform):
         full_name = layer_name
         if ".*." in layer_name:
             full_name = layer_name.replace(".*.", ".0.")
-
-
-
+
+        try:
+            prefix, _, suffix = next(full_name.partition(k) for k in collected_tensors.keys() if k in full_name)
+            # Rename the tensors
+            collected_tensors = {prefix + k + suffix: v for k, v in collected_tensors.items()}
+        # some quantizers need to already rename in `convert` as they cannot only rely on prefix and suffix
+        except StopIteration:
+            pass
+
         if hf_quantizer is not None and self.quantization_operation is not None:
-            with
+            with log_conversion_errors(
+                layer_name, conversion_errors, (len(collected_tensors), layer_name), self.quantization_operation
+            ):
                 collected_tensors = self.quantization_operation.convert(
                     collected_tensors,
                     source_patterns=self.source_patterns,
@@ -476,7 +677,7 @@ class WeightConverter(WeightTransform):
                     model=model,
                     missing_keys=missing_keys,
                 )
-        return collected_tensors,
+        return collected_tensors, conversion_errors


 # For I/O bound operations (i.e. here reading files), it is better to have fewer threads, e.g. 4 is a good default.
@@ -485,25 +686,46 @@ class WeightConverter(WeightTransform):
 GLOBAL_WORKERS = min(4, os.cpu_count() or 4)


-def _materialize_copy(tensor, device=None, dtype=None):
+def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Tensor:
+    # This slicing is what actually loads the tensor from the safetensors slice object
     tensor = tensor[...]
     if dtype is not None or device is not None:
         tensor = tensor.to(device=device, dtype=dtype)
     return tensor


-def spawn_materialize(
+def spawn_materialize(
+    thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, device=None, dtype=None
+) -> Future | Callable:
+    """Materialize a tensor from file asynchronously if `thread_pool` is provided, or return a Callable that will
+    load the tensor synchronously when called."""
+
     def _job():
         return _materialize_copy(tensor, device, dtype)

-
+    if thread_pool is not None:
+        return thread_pool.submit(_job)
+    else:
+        # Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
+        # memory during conversion
+        return _job
+

+def spawn_tp_materialize(
+    thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, sharding_method, tensor_idx, device=None, dtype=None
+) -> Future | Callable:
+    """Materialize and shard a tensor (according to the TP-plan) from file asynchronously if `thread_pool` is provided, or
+    return a Callable that will load the tensor synchronously when called."""

-def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
     def _job():
-        return sharding_method.shard_tensor(tensor,
+        return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype)

-
+    if thread_pool is not None:
+        return thread_pool.submit(_job)
+    else:
+        # Return the Callable here, not the Tensor itself, so we actually delay loading to avoid saturating cpu
+        # memory during conversion
+        return _job

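Note: a usage sketch of the two modes, assuming `spawn_materialize` is in scope; a plain tensor stands in for a safetensors slice, since both support `tensor[...]`:

    import torch
    from concurrent.futures import Future, ThreadPoolExecutor

    t = torch.arange(6).reshape(2, 3)

    pool = ThreadPoolExecutor(max_workers=2)
    fut = spawn_materialize(pool, t, device="cpu", dtype=torch.float32)  # async: a Future
    assert isinstance(fut, Future) and fut.result().dtype == torch.float32

    job = spawn_materialize(None, t)  # sync: a Callable, loading deferred until called
    assert callable(job) and torch.equal(job(), t)
    pool.shutdown()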
 def dot_natural_key(s: str):
@@ -516,13 +738,14 @@ def dot_natural_key(s: str):


 @contextmanager
-def
+def log_conversion_errors(
     first_target_key: str,
-
+    conversion_errors: MutableMapping[str, str],
     extras: Any = None,
     op: Union[list[ConversionOps], ConversionOps, None] = None,
 ):
-
+    """Catch all exceptions during `convert` calls, and log the errors for later. Re-raise a `SkipParameters` exception
+    that will be caught later to skip the parameters that raised the original Exception."""
     try:
         yield
     except Exception as e:
@@ -539,19 +762,21 @@ def log_to_misc(

         op_name = _format_op_name(op)
         if isinstance(extras, tuple) and len(extras) == 2:
-
+            length, target_keys = extras
             descriptor = f"{op_name} " if op_name else ""
-
-            f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {
+            conversion_errors[first_target_key] = (
+                f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {length}"
             )
         elif isinstance(extras, str):
             suffix = f" via {op_name}" if op_name else ""
-
+            conversion_errors[first_target_key] = f"{e}\nError{suffix} when processing parameter {extras}"
         elif extras is None and op_name:
-
+            conversion_errors[first_target_key] = f"{op_name}: {e}"
         else:
-
-
+            conversion_errors[first_target_key] = f"{extras} |Error: {e}"
+
+        # Raise a specific Exception that we can catch easily
+        raise SkipParameters()

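Note: a sketch of the intended control flow, assuming `log_conversion_errors` and `SkipParameters` (defined further below) are in scope. The context manager records the failure under the first target key, then raises the sentinel so the loading loop can skip just those parameters:

    conversion_errors: dict[str, str] = {}
    try:
        with log_conversion_errors("model.layers.0.q_proj.weight", conversion_errors, extras="q_proj.weight"):
            raise RuntimeError("shape mismatch")
    except SkipParameters:
        pass  # the loading loop would `continue` here

    # conversion_errors now maps the target key to a message describing the failure
    assert "model.layers.0.q_proj.weight" in conversion_errors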

 def set_param_for_module(
@@ -560,22 +785,20 @@ def set_param_for_module(
     param_value: torch.Tensor,
     mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
     missing_keys: MutableSet[str],
-    misc: MutableMapping[str, Any],
     unexpected_keys: MutableSet[str],
     distributed_operation: Optional[TensorParallelLayer],
     hf_quantizer: HfQuantizer,
 ):
-
-
-
-
-
-
-
-
-
-        if distributed_operation is not None:
+    module_path, _, param_name = target_name.rpartition(".")
+    module_obj = model.get_submodule(module_path) if module_path else model
+
+    ref = getattr(module_obj, param_name)
+    if ref is None:
+        unexpected_keys.add(target_name)
+    else:
+        if not isinstance(param_value, torch.nn.Parameter):
+            if distributed_operation is not None:
+                if getattr(distributed_operation, "use_dtensor", False):
                     param_value = DTensor.from_local(
                         param_value,
                         distributed_operation.device_mesh,
@@ -584,20 +807,17 @@ def set_param_for_module(
                         shape=ref.size(),
                         stride=ref.stride(),
                     )
-
-
-
-
-
-
-
-
-
-
-
-                # super important otherwise _init_weight will re-init the param
-                param_value._is_hf_initialized = True
-                setattr(module_obj, param_name, param_value)
+            if param_name not in module_obj._buffers:
+                param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
+
+        # Remove from missing keys (it's either mismatched, or all good)
+        missing_keys.discard(target_name)
+        if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
+            mismatch_keys.add((target_name, param_value.shape, ref.shape))
+        else:
+            # super important otherwise _init_weight will re-init the param
+            param_value._is_hf_initialized = True
+            setattr(module_obj, param_name, param_value)


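Note: the core mechanics above (resolve the owning module from the dotted name, wrap in `nn.Parameter`, mark as initialized, then `setattr`) can be sketched standalone:

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(2, 2))
    target_name = "0.weight"

    module_path, _, param_name = target_name.rpartition(".")
    module_obj = model.get_submodule(module_path) if module_path else model

    param_value = torch.zeros(2, 2)
    if param_name not in module_obj._buffers:
        param_value = nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
    param_value._is_hf_initialized = True  # prevents later re-initialization
    setattr(module_obj, param_name, param_value)
    assert torch.equal(model[0].weight, torch.zeros(2, 2))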
 def offload_and_maybe_resave_param(
@@ -619,8 +839,9 @@ def offload_and_maybe_resave_param(
     return disk_offload_index


-class
-    """Control-flow sentinel: abort processing of the current
+class SkipParameters(Exception):
+    """Control-flow sentinel: abort processing of the current parameters only (that were supposed to be created
+    by a WeightConverter)."""

     pass

@@ -675,6 +896,7 @@ def convert_and_load_state_dict_in_model(
     device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
     disk_offload_index: dict | None = None,
     disk_offload_folder: str | None = None,
+    offload_buffers: bool = False,
 ):
     r"""
     We build a mapping from the keys obtained by renaming each of the checkpoint keys according to the weight_mapping rules.
@@ -688,7 +910,7 @@ def convert_and_load_state_dict_in_model(
             target_patterns=["q", "k", "v"],
             operations=[Chunk(dim=0, chunks=3)]),
             collected_tensors={
-                "qkv": [Future
+                "qkv": [Future]},
             layer_targets={
                 "model.layers.0.attention.q.weight": {"model.layers.0.attention.qkv.weight"},
                 "model.layers.0.attention.k.weight": {"model.layers.0.attention.qkv.weight"},
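Note: for the mapping shown in this docstring, a `Chunk(dim=0, chunks=3)` operation presumably reduces to `torch.chunk` along the given dimension; a toy sketch of the split it performs:

    import torch

    qkv = torch.randn(9, 4)  # fused projection: three stacked blocks of 3 rows
    q, k, v = qkv.chunk(3, dim=0)
    assert q.shape == k.shape == v.shape == (3, 4)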
@@ -765,25 +987,26 @@ def convert_and_load_state_dict_in_model(
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}
-    # Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
-    device_map_regex = re.compile(
-        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
-    )
     dtype_plan = dtype_plan or {}
     weight_mapping = weight_mapping or []
     meta_model_state_dict = model.state_dict()
-
+    model_buffers = {k for k, _ in model.named_buffers()}

-
+    missing_keys = set(meta_model_state_dict.keys())
+    conversion_errors = {}
     mismatch_keys = set()
     unexpected_keys = set()
-
-
+
+    # We use threading by default, if not explicitly deactivated via env variable. If we have to offload,
+    # we cannot use it either, since we need to control memory tightly, so we load sequentially
+    if is_env_variable_true("HF_DEACTIVATE_ASYNC_LOAD") or "disk" in device_map.values():
+        thread_pool = None
+    else:
+        thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)

     renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
     converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
-
-    param_name_to_load: dict[str, Union[WeightRenaming | WeightConverter]] = {}
+    param_name_to_load: dict[str, WeightRenaming | WeightConverter] = {}

     # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
     # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
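Note: the same named-group alternation is used for the dtype and TP plans below, where `match.lastgroup` recovers which pattern fired. A sketch with a hypothetical dtype plan:

    import re

    dtype_plan = {"*.mlp.*": "bfloat16"}  # hypothetical plan
    by_group_name = {}
    branches = []
    for i, pattern in enumerate(dtype_plan):
        regex = re.escape(pattern).replace(r"\*", ".*")  # glob -> regex
        branches.append(f"(?P<g{i}>{regex})")
        by_group_name[f"g{i}"] = pattern
    alt = re.compile("|".join(branches))

    m = alt.search("model.layers.0.mlp.gate.weight")
    assert m is not None and dtype_plan[by_group_name[m.lastgroup]] == "bfloat16"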
@@ -826,41 +1049,40 @@ def convert_and_load_state_dict_in_model(
         if hf_quantizer and hf_quantizer.pre_quantized and original_key != renamed_key:
             # if the key was renamed as it is not available in the state dict otherwise, it means that we are deserializing it,
             # so we need to make sure to load the tensor with the same dtype from the checkpoint
+            # TODO: make the condition more strict for native fp8 models such as qwen2moe fp8
             _dtype = None
         elif dtype_plan != {} and dtype_policy_alt.search(renamed_key):
             matched_dtype_pattern = dtype_policy_alt.search(renamed_key)
             if matched_dtype_pattern is not None:
-                _dtype = dtype_plan[matched_dtype_pattern.
+                _dtype = dtype_plan[dtype_policy_by_group_name[matched_dtype_pattern.lastgroup]]
         elif empty_param is not None and empty_param.dtype != _dtype:
             _dtype = empty_param.dtype  # usually correct when initializing

-        # 4. Handle TP sharding or device_map placement
-
+        # 4. Handle TP sharding or device_map placement
+        future_or_tensor = None
         if device_mesh:
             if matched_tp_pattern := tp_plan_alt.search(renamed_key):
                 matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup]
                 if getattr(mapping, "distributed_operation", None) is None:
                     tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
                     mapping.distributed_operation = tp_layer(
-                        device_mesh=device_mesh, rank=
+                        device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone()
                     )
                 shard_index = len(mapping.collected_tensors.get(original_key, []))
-
+                future_or_tensor = spawn_tp_materialize(
                     thread_pool,
                     tensor,
                     mapping.distributed_operation,
                     shard_index,
+                    device_map[""],
                     _dtype,
                 )

-        if
-
-
-            # If disk, we need to materialize on cpu first
-            param_device = "cpu" if param_device == "disk" else param_device
-            future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
+        if future_or_tensor is None:
+            param_device = get_device(device_map, renamed_key, valid_torch_device=True)
+            future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype)

-        mapping.add_tensor(renamed_key, original_key, source_pattern,
+        mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
     elif source_pattern is not None:  # add all target keys as unexpected
         mapping = pattern_to_converter[source_pattern]
         for k in mapping.target_patterns:
@@ -868,52 +1090,57 @@ def convert_and_load_state_dict_in_model(
     else:
         unexpected_keys.add(renamed_key)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        total_entries = len(param_name_to_load)
+        with logging.tqdm(total=total_entries, desc="Loading weights") as pbar:
+            for first_param_name, mapping in param_name_to_load.items():
+                pbar.update(1)
+                pbar.set_postfix({"Materializing param": first_param_name})
+                pbar.refresh()
+                try:
+                    realized_value, conversion_errors = mapping.convert(
+                        first_param_name,
+                        model=model,
+                        config=model.config,
+                        hf_quantizer=hf_quantizer,
+                        missing_keys=missing_keys,
+                        conversion_errors=conversion_errors,
+                    )
+                    for target_name, param in realized_value.items():
+                        param = param[0] if isinstance(param, list) else param
+                        param_device = get_device(device_map, target_name)
+                        # Offloading support
+                        if param_device == "disk" and (target_name not in model_buffers or offload_buffers):
+                            disk_offload_index = offload_and_maybe_resave_param(
+                                target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
+                            )
+                        else:
+                            set_param_for_module(
+                                model,
+                                target_name,
+                                param,
+                                mismatch_keys,
+                                missing_keys,
+                                unexpected_keys,
+                                mapping.distributed_operation,
+                                hf_quantizer,
+                            )
+
+                    # Cleanup all the tensors that were gathered before next iteration
+                    del realized_value
+
+                except SkipParameters:
+                    continue
+
+    # Close the pool, independently of whether the code was interrupted or finished successfully
+    finally:
+        if thread_pool is not None:
+            # `cancel_futures=True` in case the program was interrupted, to avoid wasting time on exit
+            thread_pool.shutdown(wait=False, cancel_futures=True)

     # Keep the current weight conversion mapping for later saving (in case it was coming directly from the user)
     model._weight_conversions = weight_mapping
-
-    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, misc
+    return missing_keys, unexpected_keys, mismatch_keys, disk_offload_index, conversion_errors


 def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch.Tensor]):
@@ -960,7 +1187,7 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch.Tensor]):
     new_state_dict = {}
     for first_param_name, reversed_converter in conversion_mapping.items():
         # Apply the reverse converter
-        realized_value,
+        realized_value, _ = reversed_converter.convert(first_param_name, model=model, config=model.config)
         for target_name, param in realized_value.items():
             param = param[0] if isinstance(param, list) else param
             new_state_dict[target_name] = param