transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -19,6 +19,9 @@ import os
|
|
|
19
19
|
import re
|
|
20
20
|
from functools import partial, reduce
|
|
21
21
|
|
|
22
|
+
from ..distributed import DistributedConfig
|
|
23
|
+
from ..utils import is_torch_greater_or_equal, logging
|
|
24
|
+
from ..utils.generic import GeneralInterface
|
|
22
25
|
from ..utils.import_utils import is_torch_available
|
|
23
26
|
|
|
24
27
|
|
|
@@ -27,14 +30,6 @@ if is_torch_available():
|
|
|
27
30
|
import torch.distributed as dist
|
|
28
31
|
from torch import nn
|
|
29
32
|
|
|
30
|
-
from ..distributed import DistributedConfig
|
|
31
|
-
from ..utils import is_torch_greater_or_equal, logging
|
|
32
|
-
from ..utils.generic import GeneralInterface
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
logger = logging.get_logger(__name__)
|
|
36
|
-
|
|
37
|
-
if is_torch_available():
|
|
38
33
|
# Cache this result has it's a C FFI call which can be pretty time-consuming
|
|
39
34
|
_torch_distributed_available = torch.distributed.is_available()
|
|
40
35
|
|
|
@@ -42,6 +37,9 @@ if is_torch_available():
|
|
|
42
37
|
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
|
|
43
38
|
|
|
44
39
|
|
|
40
|
+
logger = logging.get_logger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
45
43
|
def initialize_tensor_parallelism(
|
|
46
44
|
tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
|
|
47
45
|
):
|
|
@@ -470,7 +468,12 @@ class TensorParallelLayer:
|
|
|
470
468
|
@staticmethod
|
|
471
469
|
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): ...
|
|
472
470
|
|
|
473
|
-
def
|
|
471
|
+
def shard_tensor(
|
|
472
|
+
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
473
|
+
) -> torch.Tensor:
|
|
474
|
+
raise NotImplementedError
|
|
475
|
+
|
|
476
|
+
def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
|
|
474
477
|
raise NotImplementedError
|
|
475
478
|
|
|
476
479
|
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
|
@@ -519,19 +522,10 @@ class GatherParallel(TensorParallelLayer):
|
|
|
519
522
|
return outputs
|
|
520
523
|
|
|
521
524
|
def shard_tensor(
|
|
522
|
-
self,
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
to_contiguous=None,
|
|
527
|
-
rank=None,
|
|
528
|
-
device_mesh=None,
|
|
529
|
-
tensor_idx=None,
|
|
530
|
-
):
|
|
531
|
-
shard = [Replicate()]
|
|
532
|
-
parameter = param[...].to(param_casting_dtype)
|
|
533
|
-
self.shard = shard
|
|
534
|
-
return parameter, shard
|
|
525
|
+
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
526
|
+
) -> torch.Tensor:
|
|
527
|
+
self.shard = [Replicate()]
|
|
528
|
+
return param[...].to(device=device, dtype=dtype)
|
|
535
529
|
|
|
536
530
|
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
|
537
531
|
distribute_module(
|
|
@@ -562,29 +556,20 @@ class IsolatedParallel(TensorParallelLayer):
|
|
|
562
556
|
return outputs
|
|
563
557
|
|
|
564
558
|
def shard_tensor(
|
|
565
|
-
self,
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
rank=None,
|
|
571
|
-
device_mesh=None,
|
|
572
|
-
tensor_idx=None,
|
|
573
|
-
):
|
|
574
|
-
mesh = device_mesh or self.device_mesh
|
|
575
|
-
parameter = param[...].to(param_casting_dtype)
|
|
576
|
-
if mesh is not None:
|
|
577
|
-
parameter = parameter / mesh.size()
|
|
559
|
+
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
560
|
+
) -> torch.Tensor:
|
|
561
|
+
parameter = param[...].to(device=device, dtype=dtype)
|
|
562
|
+
if self.device_mesh is not None:
|
|
563
|
+
parameter = parameter / self.device_mesh.size()
|
|
578
564
|
self.shard = None
|
|
579
|
-
return parameter
|
|
565
|
+
return parameter
|
|
580
566
|
|
|
581
|
-
def partition_tensor(self, param
|
|
582
|
-
|
|
567
|
+
def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
|
|
568
|
+
parameter = self.shard_tensor(param, dtype=dtype)
|
|
583
569
|
if to_contiguous:
|
|
584
|
-
|
|
585
|
-
param = param / device_mesh.size() # TODO should be optionable
|
|
570
|
+
parameter = parameter.contiguous()
|
|
586
571
|
# TODO: assumes parent module will allreduce the output afterwards (e.g rowlinear bias is IsolatedParallel and parent module is GatherParallel)
|
|
587
|
-
return
|
|
572
|
+
return parameter
|
|
588
573
|
|
|
589
574
|
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
|
590
575
|
distribute_module(
|
|
@@ -623,31 +608,15 @@ class ReplicateParallel(TensorParallelLayer):
|
|
|
623
608
|
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
|
|
624
609
|
|
|
625
610
|
def shard_tensor(
|
|
626
|
-
self,
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
tensor_idx=None,
|
|
634
|
-
):
|
|
635
|
-
parameter = param[...].to(param_casting_dtype)
|
|
636
|
-
shard = [Replicate()]
|
|
637
|
-
self.shard = shard
|
|
638
|
-
return parameter, shard
|
|
639
|
-
|
|
640
|
-
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
|
641
|
-
parameter, shard = self.shard_tensor(
|
|
642
|
-
param,
|
|
643
|
-
param_type=param_type,
|
|
644
|
-
param_casting_dtype=param_casting_dtype,
|
|
645
|
-
to_contiguous=to_contiguous,
|
|
646
|
-
rank=rank,
|
|
647
|
-
device_mesh=device_mesh,
|
|
648
|
-
)
|
|
611
|
+
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
612
|
+
) -> torch.Tensor:
|
|
613
|
+
self.shard = [Replicate()]
|
|
614
|
+
return param[...].to(device=device, dtype=dtype)
|
|
615
|
+
|
|
616
|
+
def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
|
|
617
|
+
parameter = self.shard_tensor(param, dtype=dtype)
|
|
649
618
|
if self.use_dtensor:
|
|
650
|
-
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
|
|
619
|
+
parameter = DTensor.from_local(parameter, self.device_mesh, self.shard, run_check=False)
|
|
651
620
|
return parameter
|
|
652
621
|
|
|
653
622
|
|
|
@@ -685,38 +654,34 @@ class ColwiseParallel(TensorParallelLayer):
|
|
|
685
654
|
return input_tensor
|
|
686
655
|
|
|
687
656
|
def shard_tensor(
|
|
688
|
-
self,
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
device_mesh=None,
|
|
695
|
-
tensor_idx=None,
|
|
696
|
-
):
|
|
697
|
-
device_mesh = self.device_mesh
|
|
698
|
-
empty_param = self.empty_param
|
|
699
|
-
rank = self.rank
|
|
700
|
-
if param_type == "bias":
|
|
701
|
-
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
|
|
657
|
+
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
658
|
+
) -> torch.Tensor:
|
|
659
|
+
# If only 1 dim, shard this one (usually it's a `bias`)
|
|
660
|
+
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
|
|
661
|
+
if dim == 1:
|
|
662
|
+
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1, tensor_idx)
|
|
702
663
|
shard = [Shard(-1)]
|
|
703
664
|
else:
|
|
704
665
|
shard = [Shard(-2)]
|
|
705
|
-
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
|
|
706
|
-
parameter = parameter.to(param_casting_dtype)
|
|
666
|
+
parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2, tensor_idx)
|
|
707
667
|
self.shard = shard
|
|
708
|
-
return parameter,
|
|
668
|
+
return parameter.to(device=device, dtype=dtype)
|
|
709
669
|
|
|
710
|
-
def partition_tensor(self, param
|
|
670
|
+
def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
|
|
711
671
|
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
|
712
672
|
# means Colwise as Linear is input * weight^T + bias, where
|
|
713
673
|
# weight would become Shard(1)
|
|
714
|
-
parameter
|
|
674
|
+
parameter = self.shard_tensor(param, dtype=dtype)
|
|
715
675
|
if to_contiguous:
|
|
716
676
|
parameter = parameter.contiguous()
|
|
717
677
|
if self.use_dtensor:
|
|
718
678
|
parameter = DTensor.from_local(
|
|
719
|
-
parameter,
|
|
679
|
+
parameter,
|
|
680
|
+
self.device_mesh,
|
|
681
|
+
self.shard,
|
|
682
|
+
run_check=False,
|
|
683
|
+
shape=self.empty_param.size(),
|
|
684
|
+
stride=self.empty_param.stride(),
|
|
720
685
|
)
|
|
721
686
|
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
|
722
687
|
|
|
@@ -731,33 +696,41 @@ class ColwiseParallel(TensorParallelLayer):
|
|
|
731
696
|
|
|
732
697
|
class PackedColwiseParallel(ColwiseParallel):
|
|
733
698
|
def shard_tensor(
|
|
734
|
-
self,
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
to_contiguous=None,
|
|
739
|
-
rank=None,
|
|
740
|
-
device_mesh=None,
|
|
741
|
-
tensor_idx=None,
|
|
742
|
-
):
|
|
743
|
-
device_mesh = device_mesh or self.device_mesh
|
|
744
|
-
empty_param = self.empty_param
|
|
745
|
-
rank = rank if rank is not None else self.rank
|
|
746
|
-
return get_packed_weights(param, empty_param, device_mesh, rank, -2).to(param_casting_dtype), [Shard(-2)]
|
|
699
|
+
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
700
|
+
) -> torch.Tensor:
|
|
701
|
+
parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -2)
|
|
702
|
+
return parameter.to(device=device, dtype=dtype)
|
|
747
703
|
|
|
748
|
-
def partition_tensor(self, param
|
|
704
|
+
def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
|
|
749
705
|
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
|
750
706
|
# means Colwise as Linear is input * weight^T + bias, where
|
|
751
707
|
# weight would become Shard(1)
|
|
752
|
-
parameter =
|
|
753
|
-
parameter = parameter.to(param_casting_dtype)
|
|
708
|
+
parameter = self.shard_tensor(param, dtype=dtype)
|
|
754
709
|
if to_contiguous:
|
|
755
710
|
parameter = parameter.contiguous()
|
|
756
711
|
if self.use_dtensor:
|
|
757
|
-
parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
|
|
712
|
+
parameter = DTensor.from_local(parameter, self.device_mesh, [Shard(-2)], run_check=False)
|
|
758
713
|
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
|
759
714
|
|
|
760
715
|
|
|
716
|
+
class LocalColwiseParallel(ColwiseParallel):
|
|
717
|
+
"""
|
|
718
|
+
Colwise parallel with use_dtensor=False for local tensor operations.
|
|
719
|
+
"""
|
|
720
|
+
|
|
721
|
+
def __init__(self, **kwargs):
|
|
722
|
+
super().__init__(use_dtensor=False, **kwargs)
|
|
723
|
+
|
|
724
|
+
|
|
725
|
+
class ColwiseParallelReplicate(ColwiseParallel):
|
|
726
|
+
"""
|
|
727
|
+
Colwise parallel with output layouts replicated.
|
|
728
|
+
"""
|
|
729
|
+
|
|
730
|
+
def __init__(self, **kwargs):
|
|
731
|
+
super().__init__(output_layouts=Replicate(), **kwargs)
|
|
732
|
+
|
|
733
|
+
|
|
761
734
|
class RowwiseParallel(TensorParallelLayer):
|
|
762
735
|
"""
|
|
763
736
|
Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
|
|
@@ -782,7 +755,7 @@ class RowwiseParallel(TensorParallelLayer):
|
|
|
782
755
|
input_layouts: Placement | None = None,
|
|
783
756
|
output_layouts: Placement | None = None,
|
|
784
757
|
use_local_output: bool = True,
|
|
785
|
-
use_dtensor=True,
|
|
758
|
+
use_dtensor: bool = True,
|
|
786
759
|
**kwargs,
|
|
787
760
|
):
|
|
788
761
|
super().__init__(**kwargs)
|
|
@@ -792,45 +765,36 @@ class RowwiseParallel(TensorParallelLayer):
|
|
|
792
765
|
self.use_dtensor = use_dtensor
|
|
793
766
|
|
|
794
767
|
def shard_tensor(
|
|
795
|
-
self,
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
rank=None,
|
|
801
|
-
device_mesh=None,
|
|
802
|
-
tensor_idx=None,
|
|
803
|
-
):
|
|
804
|
-
device_mesh = device_mesh or self.device_mesh
|
|
805
|
-
empty_param = self.empty_param
|
|
806
|
-
rank = rank if rank is not None else self.rank
|
|
807
|
-
if param_type == "bias":
|
|
768
|
+
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
769
|
+
) -> torch.Tensor:
|
|
770
|
+
# If only 1 dim, it should not be sharded (usually it's a `bias`)
|
|
771
|
+
dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
|
|
772
|
+
if dim == 1:
|
|
808
773
|
shard = [Replicate()]
|
|
809
774
|
parameter = param[...]
|
|
810
775
|
else:
|
|
811
|
-
parameter = get_tensor_shard(
|
|
776
|
+
parameter = get_tensor_shard(
|
|
777
|
+
param, self.empty_param, self.device_mesh, self.rank, -1, tensor_idx=tensor_idx
|
|
778
|
+
)
|
|
812
779
|
shard = [Shard(-1)]
|
|
813
|
-
parameter = parameter.to(param_casting_dtype)
|
|
814
780
|
self.shard = shard
|
|
815
|
-
return parameter,
|
|
781
|
+
return parameter.to(device=device, dtype=dtype)
|
|
816
782
|
|
|
817
|
-
def partition_tensor(self, param
|
|
783
|
+
def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
|
|
818
784
|
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
|
|
819
785
|
# means Rowwise as nn.Linear is input * weight^T + bias, where
|
|
820
786
|
# weight would become Shard(0)
|
|
821
|
-
|
|
822
|
-
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
|
|
823
|
-
shard = [Shard(-1)]
|
|
824
|
-
else:
|
|
825
|
-
shard = [Replicate()]
|
|
826
|
-
parameter = param[:]
|
|
827
|
-
|
|
828
|
-
parameter = parameter.to(param_casting_dtype)
|
|
787
|
+
parameter = self.shard_tensor(param, dtype=dtype)
|
|
829
788
|
if to_contiguous:
|
|
830
789
|
parameter = parameter.contiguous()
|
|
831
790
|
if self.use_dtensor:
|
|
832
791
|
parameter = DTensor.from_local(
|
|
833
|
-
parameter,
|
|
792
|
+
parameter,
|
|
793
|
+
self.device_mesh,
|
|
794
|
+
self.shard,
|
|
795
|
+
run_check=False,
|
|
796
|
+
shape=self.empty_param.size(),
|
|
797
|
+
stride=self.empty_param.stride(),
|
|
834
798
|
)
|
|
835
799
|
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
|
836
800
|
|
|
@@ -886,33 +850,50 @@ class RowwiseParallel(TensorParallelLayer):
|
|
|
886
850
|
|
|
887
851
|
class PackedRowwiseParallel(RowwiseParallel):
|
|
888
852
|
def shard_tensor(
|
|
889
|
-
self,
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
to_contiguous=None,
|
|
894
|
-
rank=None,
|
|
895
|
-
device_mesh=None,
|
|
896
|
-
tensor_idx=None,
|
|
897
|
-
):
|
|
898
|
-
device_mesh = device_mesh or self.device_mesh
|
|
899
|
-
empty_param = self.empty_param
|
|
900
|
-
rank = rank if rank is not None else self.rank
|
|
901
|
-
return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)]
|
|
853
|
+
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
854
|
+
) -> torch.Tensor:
|
|
855
|
+
parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -1)
|
|
856
|
+
return parameter.to(device=device, dtype=dtype)
|
|
902
857
|
|
|
903
|
-
def partition_tensor(self, param
|
|
858
|
+
def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
|
|
904
859
|
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
|
905
860
|
# means Colwise as Linear is input * weight^T + bias, where
|
|
906
861
|
# weight would become Shard(1)
|
|
907
|
-
parameter =
|
|
908
|
-
parameter = parameter.to(param_casting_dtype)
|
|
862
|
+
parameter = self.shard_tensor(param, dtype=dtype)
|
|
909
863
|
if to_contiguous:
|
|
910
864
|
parameter = parameter.contiguous()
|
|
911
865
|
if self.use_dtensor:
|
|
912
|
-
parameter = DTensor.from_local(parameter, device_mesh, [Shard(-1)], run_check=False)
|
|
866
|
+
parameter = DTensor.from_local(parameter, self.device_mesh, [Shard(-1)], run_check=False)
|
|
913
867
|
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
|
914
868
|
|
|
915
869
|
|
|
870
|
+
class LocalRowwiseParallel(RowwiseParallel):
|
|
871
|
+
"""
|
|
872
|
+
Rowwise parallel with use_dtensor=False for local tensor operations.
|
|
873
|
+
"""
|
|
874
|
+
|
|
875
|
+
def __init__(self, **kwargs):
|
|
876
|
+
super().__init__(use_dtensor=False, **kwargs)
|
|
877
|
+
|
|
878
|
+
|
|
879
|
+
class LocalPackedRowwiseParallel(PackedRowwiseParallel):
|
|
880
|
+
"""
|
|
881
|
+
Packed rowwise parallel with use_dtensor=False for local tensor operations.
|
|
882
|
+
"""
|
|
883
|
+
|
|
884
|
+
def __init__(self, **kwargs):
|
|
885
|
+
super().__init__(use_dtensor=False, **kwargs)
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
class RowwiseParallelReplicate(RowwiseParallel):
|
|
889
|
+
"""
|
|
890
|
+
Rowwise parallel with input layouts replicated.
|
|
891
|
+
"""
|
|
892
|
+
|
|
893
|
+
def __init__(self, **kwargs):
|
|
894
|
+
super().__init__(input_layouts=Replicate(), **kwargs)
|
|
895
|
+
|
|
896
|
+
|
|
916
897
|
class SequenceParallel(TensorParallelLayer):
|
|
917
898
|
"""
|
|
918
899
|
SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with
|
|
@@ -970,18 +951,13 @@ class SequenceParallel(TensorParallelLayer):
|
|
|
970
951
|
|
|
971
952
|
def shard_tensor(
|
|
972
953
|
self,
|
|
973
|
-
param,
|
|
974
|
-
param_type=None,
|
|
975
|
-
param_casting_dtype=None,
|
|
976
|
-
to_contiguous=None,
|
|
977
|
-
rank=None,
|
|
978
|
-
device_mesh=None,
|
|
954
|
+
param: torch.Tensor,
|
|
979
955
|
tensor_idx=None,
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
self.shard =
|
|
984
|
-
return
|
|
956
|
+
device=None,
|
|
957
|
+
dtype=None,
|
|
958
|
+
) -> torch.Tensor:
|
|
959
|
+
self.shard = [Replicate()]
|
|
960
|
+
return param[...].to(device=device, dtype=dtype)
|
|
985
961
|
|
|
986
962
|
@staticmethod
|
|
987
963
|
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
|
@@ -999,16 +975,15 @@ class SequenceParallel(TensorParallelLayer):
|
|
|
999
975
|
) # maybe we have to replicate ? because next layer is not sharded
|
|
1000
976
|
return outputs.to_local() # if use_local_output else outputs
|
|
1001
977
|
|
|
1002
|
-
def partition_tensor(self, param
|
|
978
|
+
def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
|
|
1003
979
|
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
|
1004
980
|
# means Colwise as Linear is input * weight^T + bias, where
|
|
1005
981
|
# weight would become Shard(1)
|
|
1006
|
-
parameter = param
|
|
1007
|
-
parameter = parameter.to(param_casting_dtype)
|
|
982
|
+
parameter = self.shard_tensor(param, dtype=dtype)
|
|
1008
983
|
if to_contiguous:
|
|
1009
984
|
parameter = parameter.contiguous()
|
|
1010
985
|
if self.use_dtensor:
|
|
1011
|
-
parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
|
|
986
|
+
parameter = DTensor.from_local(parameter, self.device_mesh, [Replicate()], run_check=False)
|
|
1012
987
|
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
|
1013
988
|
|
|
1014
989
|
|
|
@@ -1022,41 +997,23 @@ class GroupedGemmParallel(TensorParallelLayer):
|
|
|
1022
997
|
self.use_dtensor = False
|
|
1023
998
|
|
|
1024
999
|
def shard_tensor(
|
|
1025
|
-
self,
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
to_contiguous=None,
|
|
1030
|
-
rank=None,
|
|
1031
|
-
device_mesh=None,
|
|
1032
|
-
tensor_idx=None,
|
|
1033
|
-
):
|
|
1034
|
-
empty_param = self.empty_param
|
|
1035
|
-
ep_rank = self.rank
|
|
1036
|
-
device_mesh = self.device_mesh
|
|
1037
|
-
|
|
1038
|
-
global_num_experts = empty_param.shape[0]
|
|
1039
|
-
if global_num_experts % device_mesh.size() != 0:
|
|
1000
|
+
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
1001
|
+
) -> torch.Tensor:
|
|
1002
|
+
global_num_experts = self.empty_param.shape[0]
|
|
1003
|
+
if global_num_experts % self.device_mesh.size() != 0:
|
|
1040
1004
|
raise ValueError(
|
|
1041
|
-
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
|
|
1005
|
+
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {self.device_mesh.size()} != 0"
|
|
1042
1006
|
)
|
|
1043
|
-
local_num_experts = global_num_experts // device_mesh.size()
|
|
1044
|
-
parameter = param[
|
|
1007
|
+
local_num_experts = global_num_experts // self.device_mesh.size()
|
|
1008
|
+
parameter = param[self.rank * local_num_experts : (self.rank + 1) * local_num_experts]
|
|
1045
1009
|
self.shard = None
|
|
1046
|
-
return parameter,
|
|
1010
|
+
return parameter.to(device=device, dtype=dtype)
|
|
1047
1011
|
|
|
1048
|
-
def partition_tensor(self, param
|
|
1049
|
-
|
|
1050
|
-
global_num_experts = empty_param.shape[0]
|
|
1051
|
-
if global_num_experts % device_mesh.size() != 0:
|
|
1052
|
-
raise ValueError(
|
|
1053
|
-
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
|
|
1054
|
-
)
|
|
1055
|
-
local_num_experts = global_num_experts // device_mesh.size()
|
|
1056
|
-
param = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
|
|
1012
|
+
def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
|
|
1013
|
+
parameter = self.shard_tensor(param, dtype=dtype)
|
|
1057
1014
|
if to_contiguous:
|
|
1058
|
-
|
|
1059
|
-
return
|
|
1015
|
+
parameter = parameter.contiguous()
|
|
1016
|
+
return parameter
|
|
1060
1017
|
|
|
1061
1018
|
|
|
1062
1019
|
class RouterParallel(TensorParallelLayer):
|
|
@@ -1064,10 +1021,10 @@ class RouterParallel(TensorParallelLayer):
|
|
|
1064
1021
|
Allows to reshape the router scores to support running expert parallel.
|
|
1065
1022
|
"""
|
|
1066
1023
|
|
|
1067
|
-
def __init__(self, *args, **kwargs):
|
|
1024
|
+
def __init__(self, use_dtensor: bool = False, *args, **kwargs):
|
|
1068
1025
|
super().__init__(**kwargs)
|
|
1069
1026
|
self.args = args
|
|
1070
|
-
self.use_dtensor =
|
|
1027
|
+
self.use_dtensor = use_dtensor
|
|
1071
1028
|
|
|
1072
1029
|
@staticmethod
|
|
1073
1030
|
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
|
@@ -1118,7 +1075,7 @@ class RouterParallel(TensorParallelLayer):
|
|
|
1118
1075
|
f"The number of experts must be divisible by number of ep_size: {mod.num_experts} % {ep_size} != 0"
|
|
1119
1076
|
)
|
|
1120
1077
|
num_local_experts = mod.num_experts // ep_size
|
|
1121
|
-
router_scores, router_indices = outputs
|
|
1078
|
+
router_logits, router_scores, router_indices = outputs
|
|
1122
1079
|
router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
|
|
1123
1080
|
router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
|
|
1124
1081
|
# As -1 % 1 is 0, we can only use mask fill when num_local_experts is 1
|
|
@@ -1129,28 +1086,20 @@ class RouterParallel(TensorParallelLayer):
|
|
|
1129
1086
|
router_indices = router_indices.masked_fill(
|
|
1130
1087
|
router_indices == -1, num_local_experts
|
|
1131
1088
|
) # masking class for one hot
|
|
1132
|
-
return router_scores, router_indices
|
|
1089
|
+
return router_logits, router_scores, router_indices
|
|
1133
1090
|
|
|
1134
1091
|
def shard_tensor(
|
|
1135
|
-
self,
|
|
1136
|
-
|
|
1137
|
-
param_type=None,
|
|
1138
|
-
param_casting_dtype=None,
|
|
1139
|
-
to_contiguous=None,
|
|
1140
|
-
rank=None,
|
|
1141
|
-
device_mesh=None,
|
|
1142
|
-
tensor_idx=None,
|
|
1143
|
-
):
|
|
1144
|
-
parameter = param[...].to(param_casting_dtype)
|
|
1092
|
+
self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
|
|
1093
|
+
) -> torch.Tensor:
|
|
1145
1094
|
self.shard = None
|
|
1146
|
-
return
|
|
1095
|
+
return param[...].to(device=device, dtype=dtype)
|
|
1147
1096
|
|
|
1148
|
-
def partition_tensor(self, param
|
|
1097
|
+
def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
|
|
1149
1098
|
# TODO: i'd like for this to be the default
|
|
1150
|
-
|
|
1099
|
+
parameter = self.shard_tensor(param, dtype=dtype)
|
|
1151
1100
|
if to_contiguous:
|
|
1152
|
-
|
|
1153
|
-
return
|
|
1101
|
+
parameter = parameter.contiguous()
|
|
1102
|
+
return parameter
|
|
1154
1103
|
|
|
1155
1104
|
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
|
1156
1105
|
# TODO: need an abstract Parallel class that is different from TensorParallelLayer
|
|
@@ -1169,13 +1118,13 @@ class ParallelInterface(GeneralInterface):
|
|
|
1169
1118
|
{
|
|
1170
1119
|
"colwise": ColwiseParallel(),
|
|
1171
1120
|
"rowwise": RowwiseParallel(),
|
|
1172
|
-
"colwise_rep":
|
|
1173
|
-
"rowwise_rep":
|
|
1174
|
-
"local_colwise":
|
|
1175
|
-
"local_rowwise":
|
|
1121
|
+
"colwise_rep": ColwiseParallelReplicate(),
|
|
1122
|
+
"rowwise_rep": RowwiseParallelReplicate(),
|
|
1123
|
+
"local_colwise": LocalColwiseParallel(),
|
|
1124
|
+
"local_rowwise": LocalRowwiseParallel(),
|
|
1176
1125
|
"local": IsolatedParallel(),
|
|
1177
1126
|
"gather": GatherParallel(),
|
|
1178
|
-
"local_packed_rowwise":
|
|
1127
|
+
"local_packed_rowwise": LocalPackedRowwiseParallel(),
|
|
1179
1128
|
"sequence_parallel": SequenceParallel(),
|
|
1180
1129
|
"replicate": ReplicateParallel(),
|
|
1181
1130
|
"grouped_gemm": GroupedGemmParallel(),
|
|
@@ -1286,13 +1235,10 @@ def shard_and_distribute_module(
|
|
|
1286
1235
|
|
|
1287
1236
|
if current_shard_plan is not None:
|
|
1288
1237
|
try:
|
|
1289
|
-
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
|
|
1290
|
-
|
|
1291
|
-
tp_layer.device_mesh = device_mesh
|
|
1292
|
-
tp_layer.rank = rank
|
|
1293
|
-
param = tp_layer.partition_tensor(
|
|
1294
|
-
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
|
1238
|
+
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan](
|
|
1239
|
+
empty_param=empty_param, device_mesh=device_mesh, rank=rank
|
|
1295
1240
|
)
|
|
1241
|
+
param = tp_layer.partition_tensor(param, param_casting_dtype, is_contiguous)
|
|
1296
1242
|
except NotImplementedError as e:
|
|
1297
1243
|
print(
|
|
1298
1244
|
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
|
|
@@ -14,13 +14,11 @@
|
|
|
14
14
|
"VPTQ (Vector Post-Training Quantization) integration file"
|
|
15
15
|
|
|
16
16
|
from ..quantizers.quantizers_utils import should_convert_module
|
|
17
|
-
from ..utils import
|
|
17
|
+
from ..utils import is_torch_available, logging
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
if is_accelerate_available():
|
|
21
|
-
from accelerate import init_empty_weights
|
|
22
|
-
|
|
23
20
|
if is_torch_available():
|
|
21
|
+
import torch
|
|
24
22
|
import torch.nn as nn
|
|
25
23
|
|
|
26
24
|
logger = logging.get_logger(__name__)
|
|
@@ -48,7 +46,7 @@ def replace_with_vptq_linear(model, modules_to_not_convert: list[str] | None = N
|
|
|
48
46
|
for module_name, module in model.named_modules():
|
|
49
47
|
if not should_convert_module(module_name, modules_to_not_convert):
|
|
50
48
|
continue
|
|
51
|
-
with
|
|
49
|
+
with torch.device("meta"):
|
|
52
50
|
if isinstance(module, nn.Linear):
|
|
53
51
|
layer_params = config_for_layers.get(module_name, None) or shared_layer_config.get(
|
|
54
52
|
module_name.rsplit(".")[1], None
|