transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/core_model_loading.py:

```diff
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import math
 import os
 import re
 from abc import abstractmethod
@@ -25,11 +26,12 @@ from concurrent.futures import Future, ThreadPoolExecutor
 from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
+from itertools import chain
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 
-from .integrations.accelerate import offload_weight
+from .integrations.accelerate import get_device, offload_weight
 from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
 from .utils import is_env_variable_true, is_torch_greater_or_equal, logging
 
```
```diff
@@ -278,6 +280,166 @@ class PermuteForRope(ConversionOps):
         return output
 
 
+class ErnieFuseAndSplitTextVisionExperts(ConversionOps):
+    r"""
+    Special operation that splits a module list over all keys and fuses over the number of original modules.
+
+    Example with 2 original modules "Gate" and "Up" with 2 target keys "Text" and "Vision":
+
+             ModuleList 1                     ModuleList 2
+             [   Gate   ]                     [    Up    ]
+              |        |                       |        |
+      [Gate_Text]  [Gate_Vision]       [Up_Text]  [Up_Vision]
+            \            \              /            /
+             \            \            /            /
+              \            /          \            /
+               \          /            \          /
+            [GateUp_Text]            [GateUp_Vision]
+
+    The splits are equal and are defined by the amount of target keys.
+    The final fusions are defined by the amount of original module lists.
+    """
+
+    def __init__(self, stack_dim: int = 0, concat_dim: int = 1):
+        self.stack_dim = stack_dim
+        self.concat_dim = concat_dim
+
+    def split_list_into_chunks(self, tensor_list: list[torch.Tensor], chunks: int = 2):
+        split_size = math.ceil(len(tensor_list) / chunks)  # best effort split size
+        return [tensor_list[i * split_size : (i + 1) * split_size] for i in range(chunks)]
+
+    @torch.no_grad()
+    def convert(
+        self,
+        input_dict: dict[str, list[torch.Tensor]],
+        source_patterns: list[str],
+        target_patterns: list[str],
+        config,
+        **kwargs,
+    ) -> dict[str, list[torch.Tensor]]:
+        valid_keys = input_dict.keys()
+        split_and_fused = defaultdict(list)
+        for key in source_patterns:
+            if key not in valid_keys:
+                raise ValueError(
+                    f"Expected pattern {key} in collected tensors but only found tensors for: {valid_keys}"
+                )
+
+            tensors = input_dict.get(key, [])
+            split_tensor_lists = self.split_list_into_chunks(tensors, chunks=len(target_patterns))
+            stacked_tensors = (torch.stack(tensor_group, dim=self.stack_dim) for tensor_group in split_tensor_lists)
+            for idx, tensor_group in enumerate(stacked_tensors):
+                split_and_fused[target_patterns[idx]].append(tensor_group)
+
+        for k, v in split_and_fused.items():
+            split_and_fused[k] = torch.cat(v, dim=self.concat_dim)
+
+        return split_and_fused
+
+    @property
+    def reverse_op(self) -> ConversionOps:
+        return ErnieSplitAndDecoupleTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim)
+
+
+class ErnieSplitAndDecoupleTextVisionExperts(ConversionOps):
+    r"""
+    Special operation that splits a fused module list over all original modules and
+    then decouples them into a mixed module list each over all keys.
+
+    Example with 2 original modules "Gate" and "Up" with 2 target keys "Text" and "Vision":
+
+            [GateUp_Text]            [GateUp_Vision]
+               /          \            /          \
+              /            \          /            \
+             /            /          \              \
+            /            /            \              \
+      [Gate_Text]  [Gate_Vision]       [Up_Text]  [Up_Vision]
+              |        |                       |        |
+             [   Gate   ]                     [    Up    ]
+             ModuleList 1                     ModuleList 2
+
+    The splits are equal and are defined by the amount of original module lists.
+    The final decoupled module lists are defined by the amount of keys.
+    """
+
+    def __init__(self, stack_dim: int = 0, concat_dim: int = 1):
+        self.stack_dim = stack_dim
+        self.concat_dim = concat_dim
+
+    @torch.no_grad()
+    def convert(
+        self,
+        input_dict: dict[str, list[torch.Tensor]],
+        source_patterns: list[str],
+        target_patterns: list[str],
+        config,
+        **kwargs,
+    ) -> dict[str, list[torch.Tensor]]:
+        fused_modules = len(target_patterns)
+        valid_keys = input_dict.keys()
+        split_tensors = []
+        for key in source_patterns:
+            if key not in valid_keys:
+                raise ValueError(
+                    f"Expected pattern {key} in collected tensors but only found tensors for: {valid_keys}"
+                )
+
+            # Assuming that we get single sized lists here to index with 0
+            split_tensors.append(input_dict[key][0].chunk(fused_modules, dim=self.concat_dim))
+
+        decoupled = {}
+        for idx, key in enumerate(target_patterns):
+            tensor_groups = [
+                list(torch.unbind(tensor_group[idx], dim=self.stack_dim)) for tensor_group in split_tensors
+            ]
+            tensor_list = list(chain.from_iterable(tensor_groups))
+            targets = [key.replace("*", f"{i}") for i in range(len(tensor_list))]
+            decoupled |= dict(zip(targets, tensor_list))
+
+        return decoupled
+
+    @property
+    def reverse_op(self) -> ConversionOps:
+        return ErnieFuseAndSplitTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim)
+
+
+class Transpose(ConversionOps):
+    """
+    Transposes the given tensor along dim0 and dim1.
+    """
+
+    def __init__(self, dim0: int = 0, dim1: int = 1):
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    @torch.no_grad()
+    def convert(
+        self,
+        input_dict: dict[str, list[torch.Tensor]],
+        source_patterns: list[str],
+        target_patterns: list[str],
+        config,
+        **kwargs,
+    ) -> dict[str, list[torch.Tensor]]:
+        if len(input_dict) != len(target_patterns):
+            raise ValueError(
+                f"Transpose conversion can only happen on each key ({len(input_dict)}) "
+                f"and should match exact one target ({len(target_patterns)})."
+            )
+
+        output: dict[str, list[torch.Tensor]] = {}
+        for key, target_pattern in zip(input_dict.keys(), target_patterns):
+            tensor = input_dict.get(key, [])
+            if len(tensor) != 1:
+                raise ValueError(f"Transpose conversion requires exactly one tensor, found {len(tensor)}.")
+            output[target_pattern] = torch.transpose(tensor[0], dim0=self.dim0, dim1=self.dim1).contiguous()
+        return output
+
+    @property
+    def reverse_op(self) -> ConversionOps:
+        return Transpose(dim0=self.dim1, dim1=self.dim0)
+
+
 @dataclass(slots=True)
 class WeightTransform:
     source_patterns: Union[str, list[str]] = field(init=True)
```
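To make the fuse-and-split operation above concrete, here is a minimal standalone sketch of what `ErnieFuseAndSplitTextVisionExperts.convert` computes for two source projections and two target keys. It reimplements the flow with plain torch, outside the `ConversionOps` machinery; the key names and tensor shapes are invented for the example.

```python
import math
from collections import defaultdict

import torch


def fuse_and_split(input_dict, source_patterns, target_patterns, stack_dim=0, concat_dim=1):
    # Same flow as the op above: split each source's expert list evenly across
    # the target keys, stack each split, then fuse the sources along concat_dim.
    split_and_fused = defaultdict(list)
    for key in source_patterns:
        tensors = input_dict[key]
        split_size = math.ceil(len(tensors) / len(target_patterns))
        chunks = [tensors[i * split_size : (i + 1) * split_size] for i in range(len(target_patterns))]
        for idx, group in enumerate(chunks):
            split_and_fused[target_patterns[idx]].append(torch.stack(group, dim=stack_dim))
    return {k: torch.cat(v, dim=concat_dim) for k, v in split_and_fused.items()}


# 4 experts per projection: the first half becomes "text", the second half "vision".
experts = {
    "gate": [torch.randn(8, 16) for _ in range(4)],
    "up": [torch.randn(8, 16) for _ in range(4)],
}
fused = fuse_and_split(experts, ["gate", "up"], ["text", "vision"])
print(fused["text"].shape)  # torch.Size([2, 16, 16]): 2 stacked experts, gate/up fused on dim 1
```

`ErnieSplitAndDecoupleTextVisionExperts` is the exact inverse: chunk along `concat_dim`, unbind along `stack_dim`, and reassemble the per-expert lists.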
```diff
@@ -302,8 +464,11 @@ class WeightTransform:
         for i, pattern in enumerate(self.target_patterns):
             # Some mapping contains `^` to notify start of string when matching -> remove it during reverse mapping
             pattern = pattern.removeprefix("^")
-            #
-            pattern =
+            # Some mapping contains `$` to notify end of string when matching -> remove it during reverse mapping
+            pattern = pattern.removesuffix("$")
+            # Remove negative lookahead/behind if any. This is ugly but needed for reverse mapping of
+            # Qwen2.5, Sam3, Ernie4.5 VL MoE!
+            pattern = re.sub(r"\(\?.+\)", "", pattern)
             # Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
             if r"(.+)" in pattern:
                 pattern = pattern.replace(r"(.+)", r"\1")
```
```diff
@@ -338,19 +503,19 @@ class WeightTransform:
         match_object = self.compiled_sources.search(source_key)
         if match_object is None:
             return source_key, None
+
         # Find the source that produced the match (it's the first group that matched, as the search stops after first branch match)
         matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None)
         source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
         # If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
         replacement = self.target_patterns[0]
-        #
+        # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
         if r"\1" in replacement:
             # The index of the internal group we need to replace is the index of the matched named group as it comes
             # inside that matched named group
             replaced_group_idx = self.compiled_sources.groupindex[matching_group_name] + 1
             replacement = replacement.replace(r"\1", match_object.group(replaced_group_idx))
         renamed_key = source_key.replace(match_object.group(0), replacement)
-
         return renamed_key, source_pattern_that_matched
 
     def reverse_transform(self) -> WeightTransform:
```
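The pattern cleanup in the two hunks above is plain regex-on-regex manipulation: strip the `^`/`$` anchors and any lookahead group so a forward-mapping pattern can be reused for reverse mapping. A small sketch with a hypothetical target pattern (the pattern string is invented, not one of the library's real mappings):

```python
import re

# Hypothetical target pattern with anchors and a negative lookahead:
pattern = r"^model\.layers\.(?!.*vision).+\.gate_proj$"

pattern = pattern.removeprefix("^")  # drop the start-of-string anchor
pattern = pattern.removesuffix("$")  # drop the end-of-string anchor
pattern = re.sub(r"\(\?.+\)", "", pattern)  # strip the lookahead/lookbehind group

print(pattern)  # model\.layers\..+\.gate_proj
```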
```diff
@@ -437,6 +602,13 @@ class WeightRenaming(WeightTransform):
         return collected_tensors, conversion_errors
 
 
+# List of classes that are known to be able to use m:n
+_INTERNAL_MANY_TO_MANY_CONVERSIONS = (
+    ErnieFuseAndSplitTextVisionExperts,
+    ErnieSplitAndDecoupleTextVisionExperts,
+)
+
+
 @dataclass(slots=True)
 class WeightConverter(WeightTransform):
     operations: list[ConversionOps] = field(default_factory=list, repr=False)
```
```diff
@@ -444,9 +616,11 @@ class WeightConverter(WeightTransform):
     def __post_init__(self):
         WeightTransform.__post_init__(self)
         if bool(len(self.source_patterns) - 1) + bool(len(self.target_patterns) - 1) >= 2:
-            raise ValueError(
-                f"source keys={self.source_patterns}, target_patterns={self.target_patterns} but you can only have one to many, one to one or many to one."
-            )
+            # We allow many-to-many only if we use an internal operation that can handle it
+            if not any(isinstance(op, _INTERNAL_MANY_TO_MANY_CONVERSIONS) for op in self.operations):
+                raise ValueError(
+                    f"source keys={self.source_patterns}, target_patterns={self.target_patterns} but you can only have one to many, one to one or many to one."
+                )
         if not self.operations:
             raise ValueError("WeightConverter requires at least one operation.")
 
```
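The relaxed `__post_init__` rule above can be restated outside the dataclass: a converter may be one-to-one, one-to-many, or many-to-one, and many-to-many is accepted only when one of the listed internal operations is attached. A standalone sketch of that check (the function name is invented):

```python
def is_valid_converter(source_patterns, target_patterns, has_many_to_many_op):
    # Mirrors the rule above: N:M mappings need an op that knows how to handle them.
    many_to_many = len(source_patterns) > 1 and len(target_patterns) > 1
    return (not many_to_many) or has_many_to_many_op


print(is_valid_converter(["a"], ["b", "c"], False))       # True: one-to-many is always fine
print(is_valid_converter(["a", "b"], ["c", "d"], False))  # False: many-to-many without an internal op
print(is_valid_converter(["a", "b"], ["c", "d"], True))   # True: e.g. ErnieFuseAndSplitTextVisionExperts
```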
```diff
@@ -538,13 +712,13 @@ def spawn_materialize(
 
 
 def spawn_tp_materialize(
-    thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, sharding_method, tensor_idx, dtype=None
+    thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, sharding_method, tensor_idx, device=None, dtype=None
 ) -> Future | Callable:
     """Materialize and shard a tensor (according to the TP-plan) from file asynchronously if `thread_pool` is provided, or
     return a Callable that will load the tensor synchronously when called."""
 
     def _job():
-        return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, dtype=dtype)
+        return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype)
 
     if thread_pool is not None:
         return thread_pool.submit(_job)
```
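`spawn_materialize` and `spawn_tp_materialize` share a small concurrency idiom: the same zero-argument job is either submitted to a pool (returning a `Future`) or handed back as a callable for synchronous execution. A generic, self-contained sketch of the pattern (names are illustrative, not the library's API):

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Optional, Union


def spawn(thread_pool: Optional[ThreadPoolExecutor], load_fn: Callable[[], object]) -> Union[Future, Callable]:
    if thread_pool is not None:
        # Asynchronous path: schedule now, caller resolves with .result().
        return thread_pool.submit(load_fn)
    # Synchronous path: hand the callable back; the caller invokes it itself.
    return load_fn


with ThreadPoolExecutor(max_workers=4) as pool:
    job = spawn(pool, lambda: sum(range(10)))
    print(job.result())  # 45, computed on a worker thread

job = spawn(None, lambda: sum(range(10)))
print(job())  # 45, computed synchronously at call time
```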
```diff
@@ -622,20 +796,17 @@ def set_param_for_module(
     if ref is None:
         unexpected_keys.add(target_name)
     else:
-        use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
         if not isinstance(param_value, torch.nn.Parameter):
             if distributed_operation is not None:
-                param_value = DTensor.from_local(
-                    param_value,
-                    distributed_operation.device_mesh,
-                    getattr(distributed_operation, "shard", Replicate()),
-                    run_check=False,
-                    shape=ref.size(),
-                    stride=ref.stride(),
-                )
-                if not use_dtensor:
-                    # we convert to local
-                    param_value = param_value.to_local()
+                if getattr(distributed_operation, "use_dtensor", False):
+                    param_value = DTensor.from_local(
+                        param_value,
+                        distributed_operation.device_mesh,
+                        getattr(distributed_operation, "shard", Replicate()),
+                        run_check=False,
+                        shape=ref.size(),
+                        stride=ref.stride(),
+                    )
         if param_name not in module_obj._buffers:
             param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
 
```
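The replacement above only builds a `DTensor` when the parallel style asks for one, via `DTensor.from_local(..., run_check=False, shape=..., stride=...)`, which wraps an already-sharded local tensor without any collective communication. A single-process sketch of that call, assuming a recent PyTorch where `torch.distributed.tensor` is public; the one-device gloo mesh exists only to make the snippet self-contained:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Single-process bootstrap so the sketch runs standalone (world_size=1).
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

local_shard = torch.randn(4, 8)  # this rank's shard of the global weight
param = DTensor.from_local(
    local_shard,
    mesh,
    [Shard(0)],       # sharded along dim 0 across the mesh
    run_check=False,  # trust the caller: skip cross-rank metadata verification
    shape=(4, 8),     # global shape (equals the local one when world_size=1)
    stride=(8, 1),
)
print(param.shape)  # torch.Size([4, 8])
dist.destroy_process_group()
```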
```diff
@@ -725,6 +896,7 @@ def convert_and_load_state_dict_in_model(
     device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
     disk_offload_index: dict | None = None,
     disk_offload_folder: str | None = None,
+    offload_buffers: bool = False,
 ):
     r"""
     We build a mapping from the keys obtained by renaming each of the checkpoint keys according to the weight_mapping rules.
```
```diff
@@ -815,15 +987,12 @@ def convert_and_load_state_dict_in_model(
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}
-    # Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
-    device_map_regex = re.compile(
-        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
-    )
     dtype_plan = dtype_plan or {}
     weight_mapping = weight_mapping or []
     meta_model_state_dict = model.state_dict()
-
+    model_buffers = {k for k, _ in model.named_buffers()}
 
+    missing_keys = set(meta_model_state_dict.keys())
     conversion_errors = {}
     mismatch_keys = set()
     unexpected_keys = set()
```
```diff
@@ -897,7 +1066,7 @@ def convert_and_load_state_dict_in_model(
                 if getattr(mapping, "distributed_operation", None) is None:
                     tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
                     mapping.distributed_operation = tp_layer(
-                        device_mesh=device_mesh, rank=
+                        device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone()
                     )
                 shard_index = len(mapping.collected_tensors.get(original_key, []))
                 future_or_tensor = spawn_tp_materialize(
```
@@ -905,14 +1074,12 @@ def convert_and_load_state_dict_in_model(
                 tensor,
                 mapping.distributed_operation,
                 shard_index,
+                device_map[""],
                 _dtype,
             )
 
             if future_or_tensor is None:
-
-                param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
-                # If disk, we need to materialize on cpu first
-                param_device = "cpu" if param_device == "disk" else param_device
+                param_device = get_device(device_map, renamed_key, valid_torch_device=True)
                 future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype)
 
             mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
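The ad-hoc regex matching against `device_map` (removed above) is replaced by a `get_device` helper whose implementation is not shown in this diff. The sketch below is an assumption, not the library's code; it only captures the behavior the call sites rely on: resolve a parameter name by longest matching module prefix, and remap "disk" to "cpu" when a real torch device is required for materialization.

def get_device_sketch(device_map: dict, param_name: str, valid_torch_device: bool = False):
    # Prefer the most specific entry: more submodule dots first, then longer names.
    for key in sorted(device_map, key=lambda k: (k.count("."), len(k)), reverse=True):
        if key == "" or param_name == key or param_name.startswith(key + "."):
            device = device_map[key]
            break
    else:
        device = device_map.get("", "cpu")
    # Disk-offloaded params must be materialized on CPU first.
    if valid_torch_device and device == "disk":
        device = "cpu"
    return device

assert get_device_sketch({"": "cpu", "model.layers.0": "disk"}, "model.layers.0.q_proj.weight", True) == "cpu"
assert get_device_sketch({"": "cpu", "model.layers.0": "disk"}, "model.layers.1.q_proj.weight") == "cpu"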
@@ -941,10 +1108,9 @@ def convert_and_load_state_dict_in_model(
             )
             for target_name, param in realized_value.items():
                 param = param[0] if isinstance(param, list) else param
-
-                param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+                param_device = get_device(device_map, target_name)
                 # Offloading support
-                if param_device == "disk":
+                if param_device == "disk" and (target_name not in model_buffers or offload_buffers):
                     disk_offload_index = offload_and_maybe_resave_param(
                         target_name, param, missing_keys, disk_offload_folder, disk_offload_index, mapping
                     )
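The extra condition means a registered buffer mapped to "disk" is only written out when the new `offload_buffers` flag is set; otherwise it stays in memory, since buffers are typically small and needed eagerly. A toy illustration of the guard (the buffer name is hypothetical):

model_buffers = {"model.rotary_emb.inv_freq"}  # hypothetical buffer name
offload_buffers = False

def should_offload(target_name: str, param_device: str) -> bool:
    # Mirrors the new guard: buffers are only offloaded when explicitly requested.
    return param_device == "disk" and (target_name not in model_buffers or offload_buffers)

assert not should_offload("model.rotary_emb.inv_freq", "disk")      # buffer kept in memory
assert should_offload("model.layers.0.mlp.up_proj.weight", "disk")  # weight goes to disk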
transformers/dependency_versions_table.py
@@ -75,7 +75,7 @@ deps = {
     "tensorboard": "tensorboard",
     "timeout-decorator": "timeout-decorator",
     "tiktoken": "tiktoken",
-    "timm": "timm
+    "timm": "timm>=1.0.23",
     "tokenizers": "tokenizers>=0.22.0,<=0.23.0",
     "torch": "torch>=2.2",
     "torchaudio": "torchaudio",
transformers/feature_extraction_utils.py
@@ -30,6 +30,7 @@ from .utils import (
     PROCESSOR_NAME,
     PushToHubMixin,
     TensorType,
+    _is_tensor_or_array_like,
     copy_func,
     is_numpy_array,
     is_torch_available,
@@ -67,11 +68,18 @@ class BatchFeature(UserDict):
         tensor_type (`Union[None, str, TensorType]`, *optional*):
             You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at
             initialization.
+        skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
+            List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        data: Optional[dict[str, Any]] = None,
+        tensor_type: Union[None, str, TensorType] = None,
+        skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
+    ):
         super().__init__(data)
-        self.convert_to_tensors(tensor_type=tensor_type)
+        self.convert_to_tensors(tensor_type=tensor_type, skip_tensor_conversion=skip_tensor_conversion)
 
     def __getitem__(self, item: str) -> Any:
         """
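With the expanded signature, callers can keep selected keys as plain Python objects while everything else is converted. A small usage sketch (key names invented for the example):

import numpy as np
from transformers import BatchFeature

features = BatchFeature(
    data={"input_values": [[0.1, 0.2], [0.3, 0.4]], "raw_lengths": [2, 2]},
    tensor_type="np",
    skip_tensor_conversion=["raw_lengths"],
)
assert isinstance(features["input_values"], np.ndarray)
assert isinstance(features["raw_lengths"], list)  # explicitly skipped, left as-is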
@@ -110,6 +118,14 @@ class BatchFeature(UserDict):
             import torch
 
             def as_tensor(value):
+                if torch.is_tensor(value):
+                    return value
+
+                # stack list of tensors if tensor_type is PyTorch (torch.tensor() does not support lists of tensors)
+                if isinstance(value, (list, tuple)) and len(value) > 0 and torch.is_tensor(value[0]):
+                    return torch.stack(value)
+
+                # convert list of numpy arrays to numpy array (stack) if tensor_type is Numpy
                 if isinstance(value, (list, tuple)) and len(value) > 0:
                     if isinstance(value[0], np.ndarray):
                         value = np.array(value)
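The new early branches matter because `torch.tensor()` cannot consume a list of tensors; stacking is the correct batching operation there. For instance:

import torch

frames = [torch.randn(3, 224, 224) for _ in range(4)]
batch = torch.stack(frames)  # shape: (4, 3, 224, 224)
# torch.tensor(frames) would raise:
# ValueError: only one element tensors can be converted to Python scalars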
@@ -138,7 +154,11 @@ class BatchFeature(UserDict):
             is_tensor = is_numpy_array
         return is_tensor, as_tensor
 
-    def convert_to_tensors(
+    def convert_to_tensors(
+        self,
+        tensor_type: Optional[Union[str, TensorType]] = None,
+        skip_tensor_conversion: Optional[Union[list[str], set[str]]] = None,
+    ):
         """
         Convert the inner content to tensors.
 
@@ -146,6 +166,13 @@ class BatchFeature(UserDict):
             tensor_type (`str` or [`~utils.TensorType`], *optional*):
                 The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
                 `None`, no modification is done.
+            skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
+                List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
+
+                Note:
+                Values that don't have an array-like structure (e.g., strings, dicts, lists of strings) are
+                automatically skipped and won't be converted to tensors. Ragged arrays (lists of arrays with
+                different lengths) are still attempted, though they may raise errors during conversion.
         """
         if tensor_type is None:
             return self
@@ -154,18 +181,30 @@ class BatchFeature(UserDict):
 
         # Do the tensor conversion in batch
        for key, value in self.items():
+            # Skip keys explicitly marked for no conversion
+            if skip_tensor_conversion and key in skip_tensor_conversion:
+                continue
+
+            # Skip values that are not array-like
+            if not _is_tensor_or_array_like(value):
+                continue
+
             try:
                 if not is_tensor(value):
                     tensor = as_tensor(value)
-
                     self[key] = tensor
-            except
+            except Exception as e:
                 if key == "overflowing_values":
-                    raise ValueError(
+                    raise ValueError(
+                        f"Unable to create tensor for '{key}' with overflowing values of different lengths. "
+                        f"Original error: {str(e)}"
+                    ) from e
                 raise ValueError(
-                    "Unable to
-                    "
-
+                    f"Unable to convert output '{key}' (type: {type(value).__name__}) to tensor: {str(e)}\n"
+                    f"You can try:\n"
+                    f"  1. Use padding=True to ensure all outputs have the same shape\n"
+                    f"  2. Set return_tensors=None to return Python objects instead of tensors"
+                ) from e
 
         return self
 
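The reworded errors now surface the original exception and suggest concrete fixes. A minimal reproduction of the ragged case (behavior depends on the numpy version; recent numpy refuses inhomogeneous arrays, which triggers the new message):

from transformers import BatchFeature

ragged = BatchFeature({"input_ids": [[1, 2, 3], [4, 5]]})  # rows of unequal length
try:
    ragged.convert_to_tensors(tensor_type="np")
except ValueError as err:
    print(err)  # names 'input_ids' and suggests padding=True or return_tensors=None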
@@ -204,12 +243,15 @@ class BatchFeature(UserDict):
 
         # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
         def maybe_to(v):
-            # check if v is a floating point
+            # check if v is a floating point tensor
            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                 # cast and send to device
                 return v.to(*args, **kwargs)
             elif isinstance(v, torch.Tensor) and device is not None:
                 return v.to(device=device, non_blocking=non_blocking)
+            # recursively handle lists and tuples
+            elif isinstance(v, (list, tuple)):
+                return type(v)(maybe_to(item) for item in v)
             else:
                 return v
 
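Because `maybe_to` now recurses into lists and tuples, `BatchFeature.to()` reaches tensors that were deliberately left unstacked, for example per-crop pixel values of different sizes. A sketch of the effect (the key name is illustrative):

import torch
from transformers import BatchFeature

feat = BatchFeature({"pixel_values": [torch.randn(3, 8, 8), torch.randn(3, 16, 16)]})
feat = feat.to(torch.float16)  # previously, tensors inside the list were left untouched
assert all(t.dtype == torch.float16 for t in feat["pixel_values"])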
@@ -227,8 +269,8 @@ class FeatureExtractionMixin(PushToHubMixin):
 
     def __init__(self, **kwargs):
         """Set elements of `kwargs` as attributes."""
-        # Pop "processor_class"
-
+        # Pop "processor_class", it should not be saved in feature extractor config
+        kwargs.pop("processor_class", None)
         # Additional attributes without default values
         for key, value in kwargs.items():
             try:
@@ -237,10 +279,6 @@ class FeatureExtractionMixin(PushToHubMixin):
             logger.error(f"Can't set {key} with value {value} for {self}")
             raise err
 
-    def _set_processor_class(self, processor_class: str):
-        """Sets processor class as an attribute."""
-        self._processor_class = processor_class
-
     @classmethod
     def from_pretrained(
         cls: type[SpecificFeatureExtractorType],
@@ -584,12 +622,6 @@ class FeatureExtractionMixin(PushToHubMixin):
         if isinstance(value, np.ndarray):
             dictionary[key] = value.tolist()
 
-        # make sure private name "_processor_class" is correctly
-        # saved as "processor_class"
-        _processor_class = dictionary.pop("_processor_class", None)
-        if _processor_class is not None:
-            dictionary["processor_class"] = _processor_class
-
         return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
 
     def to_json_file(self, json_file_path: Union[str, os.PathLike]):