transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
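The per-file `+added -removed` counts below are the kind of summary that can be reproduced locally from the two published wheels. As a minimal sketch (assuming `pip` and network access to PyPI; the counting here is a plain unified diff, so the tallies may not match the registry tool's exactly, and files that were removed or renamed between versions are not handled):

```python
# Minimal sketch: download both wheels, unpack them, and report per-file
# +added / -removed line counts. Directory names are arbitrary choices.
import difflib
import pathlib
import subprocess
import zipfile


def fetch_wheel(spec: str, dest: pathlib.Path) -> pathlib.Path:
    """Download a single wheel (no dependencies) into `dest` via pip."""
    dest.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        ["pip", "download", "--no-deps", "--only-binary", ":all:", spec, "-d", str(dest)],
        check=True,
    )
    return next(dest.glob("*.whl"))


def extract(whl: pathlib.Path, out: pathlib.Path) -> pathlib.Path:
    """Unpack a wheel (it is a zip archive) into `out`."""
    with zipfile.ZipFile(whl) as zf:
        zf.extractall(out)
    return out


old = extract(fetch_wheel("transformers==5.0.0rc2", pathlib.Path("old")), pathlib.Path("old_src"))
new = extract(fetch_wheel("transformers==5.1.0", pathlib.Path("new")), pathlib.Path("new_src"))

# Walk the new package tree; files absent from the old wheel show up as all-added.
# (Files deleted in the new version are not visited by this simple loop.)
for new_file in sorted((new / "transformers").rglob("*.py")):
    rel = new_file.relative_to(new)
    old_file = old / rel
    old_lines = old_file.read_text().splitlines() if old_file.exists() else []
    diff = list(difflib.unified_diff(old_lines, new_file.read_text().splitlines()))
    added = sum(1 for line in diff if line.startswith("+") and not line.startswith("+++"))
    removed = sum(1 for line in diff if line.startswith("-") and not line.startswith("---"))
    if added or removed:
        print(f"{rel} +{added} -{removed}")
```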
- transformers/__init__.py +11 -37
- transformers/activations.py +2 -2
- transformers/audio_utils.py +32 -32
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +26 -126
- transformers/cli/chat.py +3 -3
- transformers/cli/serve.py +13 -10
- transformers/cli/transformers.py +2 -1
- transformers/configuration_utils.py +22 -92
- transformers/conversion_mapping.py +150 -26
- transformers/convert_slow_tokenizer.py +9 -12
- transformers/core_model_loading.py +217 -129
- transformers/data/processors/glue.py +0 -1
- transformers/data/processors/utils.py +0 -1
- transformers/data/processors/xnli.py +0 -1
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +10 -11
- transformers/distributed/configuration_utils.py +1 -2
- transformers/dynamic_module_utils.py +23 -23
- transformers/feature_extraction_sequence_utils.py +19 -23
- transformers/feature_extraction_utils.py +14 -14
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +2 -4
- transformers/generation/configuration_utils.py +54 -39
- transformers/generation/continuous_batching/__init__.py +0 -1
- transformers/generation/continuous_batching/cache.py +74 -44
- transformers/generation/continuous_batching/cache_manager.py +28 -28
- transformers/generation/continuous_batching/continuous_api.py +133 -414
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +77 -19
- transformers/generation/continuous_batching/scheduler.py +154 -104
- transformers/generation/logits_process.py +10 -133
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/streamers.py +0 -1
- transformers/generation/utils.py +91 -121
- transformers/generation/watermarking.py +2 -3
- transformers/hf_argparser.py +9 -13
- transformers/hyperparameter_search.py +1 -2
- transformers/image_processing_base.py +9 -9
- transformers/image_processing_utils.py +11 -15
- transformers/image_processing_utils_fast.py +70 -71
- transformers/image_transforms.py +73 -42
- transformers/image_utils.py +30 -37
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/awq.py +1 -3
- transformers/integrations/deepspeed.py +146 -4
- transformers/integrations/eetq.py +0 -1
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/fbgemm_fp8.py +1 -2
- transformers/integrations/finegrained_fp8.py +149 -13
- transformers/integrations/flash_attention.py +3 -8
- transformers/integrations/flex_attention.py +1 -1
- transformers/integrations/fp_quant.py +4 -6
- transformers/integrations/ggml.py +0 -1
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/integration_utils.py +2 -3
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +52 -40
- transformers/integrations/peft.py +488 -176
- transformers/integrations/quark.py +2 -4
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/integrations/torchao.py +4 -6
- transformers/loss/loss_lw_detr.py +356 -0
- transformers/loss/loss_utils.py +2 -0
- transformers/masking_utils.py +199 -59
- transformers/model_debugging_utils.py +4 -5
- transformers/modelcard.py +14 -192
- transformers/modeling_attn_mask_utils.py +19 -19
- transformers/modeling_flash_attention_utils.py +28 -29
- transformers/modeling_gguf_pytorch_utils.py +5 -5
- transformers/modeling_layers.py +21 -22
- transformers/modeling_outputs.py +242 -253
- transformers/modeling_rope_utils.py +32 -32
- transformers/modeling_utils.py +416 -438
- transformers/models/__init__.py +10 -0
- transformers/models/afmoe/configuration_afmoe.py +40 -33
- transformers/models/afmoe/modeling_afmoe.py +38 -41
- transformers/models/afmoe/modular_afmoe.py +23 -25
- transformers/models/aimv2/configuration_aimv2.py +2 -10
- transformers/models/aimv2/modeling_aimv2.py +46 -45
- transformers/models/aimv2/modular_aimv2.py +13 -19
- transformers/models/albert/configuration_albert.py +8 -2
- transformers/models/albert/modeling_albert.py +70 -72
- transformers/models/albert/tokenization_albert.py +1 -4
- transformers/models/align/configuration_align.py +8 -6
- transformers/models/align/modeling_align.py +83 -86
- transformers/models/align/processing_align.py +2 -30
- transformers/models/altclip/configuration_altclip.py +4 -7
- transformers/models/altclip/modeling_altclip.py +106 -103
- transformers/models/altclip/processing_altclip.py +2 -15
- transformers/models/apertus/__init__.py +0 -1
- transformers/models/apertus/configuration_apertus.py +23 -28
- transformers/models/apertus/modeling_apertus.py +35 -38
- transformers/models/apertus/modular_apertus.py +36 -40
- transformers/models/arcee/configuration_arcee.py +25 -30
- transformers/models/arcee/modeling_arcee.py +35 -38
- transformers/models/arcee/modular_arcee.py +20 -23
- transformers/models/aria/configuration_aria.py +31 -44
- transformers/models/aria/image_processing_aria.py +25 -27
- transformers/models/aria/modeling_aria.py +102 -102
- transformers/models/aria/modular_aria.py +111 -124
- transformers/models/aria/processing_aria.py +28 -35
- transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
- transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
- transformers/models/audioflamingo3/__init__.py +0 -1
- transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
- transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
- transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
- transformers/models/auto/auto_factory.py +12 -11
- transformers/models/auto/configuration_auto.py +48 -5
- transformers/models/auto/feature_extraction_auto.py +5 -7
- transformers/models/auto/image_processing_auto.py +30 -39
- transformers/models/auto/modeling_auto.py +33 -199
- transformers/models/auto/processing_auto.py +11 -19
- transformers/models/auto/tokenization_auto.py +38 -37
- transformers/models/auto/video_processing_auto.py +7 -8
- transformers/models/autoformer/configuration_autoformer.py +4 -7
- transformers/models/autoformer/modeling_autoformer.py +100 -101
- transformers/models/aya_vision/configuration_aya_vision.py +4 -1
- transformers/models/aya_vision/modeling_aya_vision.py +64 -99
- transformers/models/aya_vision/modular_aya_vision.py +46 -74
- transformers/models/aya_vision/processing_aya_vision.py +25 -53
- transformers/models/bamba/configuration_bamba.py +46 -39
- transformers/models/bamba/modeling_bamba.py +83 -119
- transformers/models/bamba/modular_bamba.py +70 -109
- transformers/models/bark/configuration_bark.py +6 -8
- transformers/models/bark/generation_configuration_bark.py +3 -5
- transformers/models/bark/modeling_bark.py +64 -65
- transformers/models/bark/processing_bark.py +19 -41
- transformers/models/bart/configuration_bart.py +9 -5
- transformers/models/bart/modeling_bart.py +124 -129
- transformers/models/barthez/tokenization_barthez.py +1 -4
- transformers/models/bartpho/tokenization_bartpho.py +6 -7
- transformers/models/beit/configuration_beit.py +2 -15
- transformers/models/beit/image_processing_beit.py +53 -56
- transformers/models/beit/image_processing_beit_fast.py +11 -12
- transformers/models/beit/modeling_beit.py +65 -62
- transformers/models/bert/configuration_bert.py +12 -2
- transformers/models/bert/modeling_bert.py +117 -152
- transformers/models/bert/tokenization_bert.py +2 -4
- transformers/models/bert/tokenization_bert_legacy.py +3 -5
- transformers/models/bert_generation/configuration_bert_generation.py +17 -2
- transformers/models/bert_generation/modeling_bert_generation.py +53 -55
- transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
- transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
- transformers/models/bertweet/tokenization_bertweet.py +1 -3
- transformers/models/big_bird/configuration_big_bird.py +12 -9
- transformers/models/big_bird/modeling_big_bird.py +107 -124
- transformers/models/big_bird/tokenization_big_bird.py +1 -4
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
- transformers/models/biogpt/configuration_biogpt.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +73 -79
- transformers/models/biogpt/modular_biogpt.py +60 -66
- transformers/models/biogpt/tokenization_biogpt.py +3 -5
- transformers/models/bit/configuration_bit.py +2 -5
- transformers/models/bit/image_processing_bit.py +21 -24
- transformers/models/bit/image_processing_bit_fast.py +0 -1
- transformers/models/bit/modeling_bit.py +15 -16
- transformers/models/bitnet/configuration_bitnet.py +23 -28
- transformers/models/bitnet/modeling_bitnet.py +34 -38
- transformers/models/bitnet/modular_bitnet.py +7 -10
- transformers/models/blenderbot/configuration_blenderbot.py +8 -5
- transformers/models/blenderbot/modeling_blenderbot.py +68 -99
- transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
- transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
- transformers/models/blip/configuration_blip.py +9 -10
- transformers/models/blip/image_processing_blip.py +17 -20
- transformers/models/blip/image_processing_blip_fast.py +0 -1
- transformers/models/blip/modeling_blip.py +115 -108
- transformers/models/blip/modeling_blip_text.py +63 -65
- transformers/models/blip/processing_blip.py +5 -36
- transformers/models/blip_2/configuration_blip_2.py +2 -2
- transformers/models/blip_2/modeling_blip_2.py +145 -121
- transformers/models/blip_2/processing_blip_2.py +8 -38
- transformers/models/bloom/configuration_bloom.py +5 -2
- transformers/models/bloom/modeling_bloom.py +60 -60
- transformers/models/blt/configuration_blt.py +94 -86
- transformers/models/blt/modeling_blt.py +93 -90
- transformers/models/blt/modular_blt.py +127 -69
- transformers/models/bridgetower/configuration_bridgetower.py +7 -2
- transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
- transformers/models/bridgetower/modeling_bridgetower.py +136 -124
- transformers/models/bridgetower/processing_bridgetower.py +2 -16
- transformers/models/bros/configuration_bros.py +24 -18
- transformers/models/bros/modeling_bros.py +78 -80
- transformers/models/bros/processing_bros.py +2 -12
- transformers/models/byt5/tokenization_byt5.py +4 -6
- transformers/models/camembert/configuration_camembert.py +8 -2
- transformers/models/camembert/modeling_camembert.py +97 -99
- transformers/models/camembert/modular_camembert.py +51 -54
- transformers/models/camembert/tokenization_camembert.py +1 -4
- transformers/models/canine/configuration_canine.py +4 -2
- transformers/models/canine/modeling_canine.py +73 -75
- transformers/models/canine/tokenization_canine.py +0 -1
- transformers/models/chameleon/configuration_chameleon.py +29 -34
- transformers/models/chameleon/image_processing_chameleon.py +21 -24
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
- transformers/models/chameleon/modeling_chameleon.py +135 -92
- transformers/models/chameleon/processing_chameleon.py +16 -41
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
- transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
- transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
- transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
- transformers/models/clap/configuration_clap.py +4 -9
- transformers/models/clap/feature_extraction_clap.py +9 -10
- transformers/models/clap/modeling_clap.py +109 -111
- transformers/models/clap/processing_clap.py +2 -15
- transformers/models/clip/configuration_clip.py +4 -2
- transformers/models/clip/image_processing_clip.py +21 -24
- transformers/models/clip/image_processing_clip_fast.py +9 -1
- transformers/models/clip/modeling_clip.py +70 -68
- transformers/models/clip/processing_clip.py +2 -14
- transformers/models/clip/tokenization_clip.py +2 -5
- transformers/models/clipseg/configuration_clipseg.py +4 -2
- transformers/models/clipseg/modeling_clipseg.py +113 -112
- transformers/models/clipseg/processing_clipseg.py +19 -42
- transformers/models/clvp/configuration_clvp.py +15 -5
- transformers/models/clvp/feature_extraction_clvp.py +7 -10
- transformers/models/clvp/modeling_clvp.py +138 -145
- transformers/models/clvp/number_normalizer.py +1 -2
- transformers/models/clvp/processing_clvp.py +3 -20
- transformers/models/clvp/tokenization_clvp.py +0 -1
- transformers/models/code_llama/tokenization_code_llama.py +3 -6
- transformers/models/codegen/configuration_codegen.py +4 -4
- transformers/models/codegen/modeling_codegen.py +50 -49
- transformers/models/codegen/tokenization_codegen.py +5 -6
- transformers/models/cohere/configuration_cohere.py +25 -30
- transformers/models/cohere/modeling_cohere.py +39 -42
- transformers/models/cohere/modular_cohere.py +27 -31
- transformers/models/cohere/tokenization_cohere.py +5 -6
- transformers/models/cohere2/configuration_cohere2.py +27 -32
- transformers/models/cohere2/modeling_cohere2.py +38 -41
- transformers/models/cohere2/modular_cohere2.py +48 -52
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
- transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
- transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
- transformers/models/colpali/configuration_colpali.py +0 -1
- transformers/models/colpali/modeling_colpali.py +14 -16
- transformers/models/colpali/modular_colpali.py +11 -51
- transformers/models/colpali/processing_colpali.py +14 -52
- transformers/models/colqwen2/modeling_colqwen2.py +27 -28
- transformers/models/colqwen2/modular_colqwen2.py +36 -74
- transformers/models/colqwen2/processing_colqwen2.py +16 -52
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
- transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
- transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
- transformers/models/convbert/configuration_convbert.py +11 -8
- transformers/models/convbert/modeling_convbert.py +85 -87
- transformers/models/convbert/tokenization_convbert.py +0 -1
- transformers/models/convnext/configuration_convnext.py +2 -5
- transformers/models/convnext/image_processing_convnext.py +18 -21
- transformers/models/convnext/image_processing_convnext_fast.py +7 -8
- transformers/models/convnext/modeling_convnext.py +12 -14
- transformers/models/convnextv2/configuration_convnextv2.py +2 -5
- transformers/models/convnextv2/modeling_convnextv2.py +12 -14
- transformers/models/cpm/tokenization_cpm.py +6 -7
- transformers/models/cpm/tokenization_cpm_fast.py +3 -5
- transformers/models/cpmant/configuration_cpmant.py +4 -1
- transformers/models/cpmant/modeling_cpmant.py +38 -40
- transformers/models/cpmant/tokenization_cpmant.py +1 -3
- transformers/models/csm/configuration_csm.py +58 -66
- transformers/models/csm/generation_csm.py +13 -14
- transformers/models/csm/modeling_csm.py +81 -84
- transformers/models/csm/modular_csm.py +56 -58
- transformers/models/csm/processing_csm.py +25 -68
- transformers/models/ctrl/configuration_ctrl.py +16 -1
- transformers/models/ctrl/modeling_ctrl.py +51 -66
- transformers/models/ctrl/tokenization_ctrl.py +0 -1
- transformers/models/cvt/configuration_cvt.py +0 -1
- transformers/models/cvt/modeling_cvt.py +13 -15
- transformers/models/cwm/__init__.py +0 -1
- transformers/models/cwm/configuration_cwm.py +8 -12
- transformers/models/cwm/modeling_cwm.py +36 -38
- transformers/models/cwm/modular_cwm.py +10 -12
- transformers/models/d_fine/configuration_d_fine.py +10 -57
- transformers/models/d_fine/modeling_d_fine.py +786 -927
- transformers/models/d_fine/modular_d_fine.py +339 -417
- transformers/models/dab_detr/configuration_dab_detr.py +22 -49
- transformers/models/dab_detr/modeling_dab_detr.py +79 -77
- transformers/models/dac/configuration_dac.py +0 -1
- transformers/models/dac/feature_extraction_dac.py +6 -9
- transformers/models/dac/modeling_dac.py +22 -24
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
- transformers/models/data2vec/configuration_data2vec_text.py +11 -3
- transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
- transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
- transformers/models/data2vec/modeling_data2vec_text.py +97 -99
- transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
- transformers/models/data2vec/modular_data2vec_audio.py +6 -1
- transformers/models/data2vec/modular_data2vec_text.py +51 -54
- transformers/models/dbrx/configuration_dbrx.py +29 -22
- transformers/models/dbrx/modeling_dbrx.py +45 -48
- transformers/models/dbrx/modular_dbrx.py +37 -39
- transformers/models/deberta/configuration_deberta.py +6 -1
- transformers/models/deberta/modeling_deberta.py +57 -60
- transformers/models/deberta/tokenization_deberta.py +2 -5
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
- transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
- transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
- transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
- transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
- transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
- transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
- transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
- transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
- transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
- transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
- transformers/models/deit/configuration_deit.py +0 -1
- transformers/models/deit/image_processing_deit.py +18 -21
- transformers/models/deit/image_processing_deit_fast.py +0 -1
- transformers/models/deit/modeling_deit.py +27 -25
- transformers/models/depth_anything/configuration_depth_anything.py +12 -43
- transformers/models/depth_anything/modeling_depth_anything.py +10 -11
- transformers/models/depth_pro/configuration_depth_pro.py +0 -1
- transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
- transformers/models/depth_pro/modeling_depth_pro.py +29 -27
- transformers/models/detr/configuration_detr.py +18 -50
- transformers/models/detr/image_processing_detr.py +64 -66
- transformers/models/detr/image_processing_detr_fast.py +33 -34
- transformers/models/detr/modeling_detr.py +748 -789
- transformers/models/dia/configuration_dia.py +9 -15
- transformers/models/dia/feature_extraction_dia.py +6 -9
- transformers/models/dia/generation_dia.py +48 -53
- transformers/models/dia/modeling_dia.py +68 -71
- transformers/models/dia/modular_dia.py +56 -58
- transformers/models/dia/processing_dia.py +39 -29
- transformers/models/dia/tokenization_dia.py +3 -6
- transformers/models/diffllama/configuration_diffllama.py +25 -30
- transformers/models/diffllama/modeling_diffllama.py +45 -53
- transformers/models/diffllama/modular_diffllama.py +18 -25
- transformers/models/dinat/configuration_dinat.py +2 -5
- transformers/models/dinat/modeling_dinat.py +47 -48
- transformers/models/dinov2/configuration_dinov2.py +2 -5
- transformers/models/dinov2/modeling_dinov2.py +20 -21
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
- transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
- transformers/models/distilbert/configuration_distilbert.py +8 -2
- transformers/models/distilbert/modeling_distilbert.py +47 -49
- transformers/models/distilbert/tokenization_distilbert.py +0 -1
- transformers/models/doge/__init__.py +0 -1
- transformers/models/doge/configuration_doge.py +42 -35
- transformers/models/doge/modeling_doge.py +46 -49
- transformers/models/doge/modular_doge.py +77 -68
- transformers/models/donut/configuration_donut_swin.py +0 -1
- transformers/models/donut/image_processing_donut.py +26 -29
- transformers/models/donut/image_processing_donut_fast.py +9 -14
- transformers/models/donut/modeling_donut_swin.py +44 -46
- transformers/models/donut/processing_donut.py +5 -26
- transformers/models/dots1/configuration_dots1.py +43 -36
- transformers/models/dots1/modeling_dots1.py +35 -38
- transformers/models/dots1/modular_dots1.py +0 -1
- transformers/models/dpr/configuration_dpr.py +19 -2
- transformers/models/dpr/modeling_dpr.py +37 -39
- transformers/models/dpr/tokenization_dpr.py +7 -9
- transformers/models/dpr/tokenization_dpr_fast.py +7 -9
- transformers/models/dpt/configuration_dpt.py +23 -66
- transformers/models/dpt/image_processing_dpt.py +65 -66
- transformers/models/dpt/image_processing_dpt_fast.py +18 -19
- transformers/models/dpt/modeling_dpt.py +38 -36
- transformers/models/dpt/modular_dpt.py +14 -15
- transformers/models/edgetam/configuration_edgetam.py +1 -2
- transformers/models/edgetam/modeling_edgetam.py +87 -89
- transformers/models/edgetam/modular_edgetam.py +7 -13
- transformers/models/edgetam_video/__init__.py +0 -1
- transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
- transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
- transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
- transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
- transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
- transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
- transformers/models/efficientnet/configuration_efficientnet.py +0 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
- transformers/models/efficientnet/modeling_efficientnet.py +12 -14
- transformers/models/electra/configuration_electra.py +13 -3
- transformers/models/electra/modeling_electra.py +107 -109
- transformers/models/emu3/configuration_emu3.py +17 -17
- transformers/models/emu3/image_processing_emu3.py +44 -39
- transformers/models/emu3/modeling_emu3.py +143 -109
- transformers/models/emu3/modular_emu3.py +109 -73
- transformers/models/emu3/processing_emu3.py +18 -43
- transformers/models/encodec/configuration_encodec.py +2 -4
- transformers/models/encodec/feature_extraction_encodec.py +10 -13
- transformers/models/encodec/modeling_encodec.py +25 -29
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
- transformers/models/eomt/configuration_eomt.py +12 -14
- transformers/models/eomt/image_processing_eomt.py +53 -55
- transformers/models/eomt/image_processing_eomt_fast.py +18 -19
- transformers/models/eomt/modeling_eomt.py +19 -21
- transformers/models/eomt/modular_eomt.py +28 -30
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -3
- transformers/models/ernie/modeling_ernie.py +127 -162
- transformers/models/ernie/modular_ernie.py +91 -103
- transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
- transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
- transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
- transformers/models/esm/configuration_esm.py +11 -15
- transformers/models/esm/modeling_esm.py +35 -37
- transformers/models/esm/modeling_esmfold.py +43 -50
- transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
- transformers/models/esm/openfold_utils/loss.py +1 -2
- transformers/models/esm/openfold_utils/protein.py +15 -16
- transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
- transformers/models/esm/tokenization_esm.py +2 -4
- transformers/models/evolla/configuration_evolla.py +50 -40
- transformers/models/evolla/modeling_evolla.py +69 -68
- transformers/models/evolla/modular_evolla.py +50 -48
- transformers/models/evolla/processing_evolla.py +23 -35
- transformers/models/exaone4/configuration_exaone4.py +27 -27
- transformers/models/exaone4/modeling_exaone4.py +36 -39
- transformers/models/exaone4/modular_exaone4.py +51 -50
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +31 -26
- transformers/models/falcon/modeling_falcon.py +76 -84
- transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
- transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
- transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
- transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
- transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
- transformers/models/flaubert/configuration_flaubert.py +10 -5
- transformers/models/flaubert/modeling_flaubert.py +125 -129
- transformers/models/flaubert/tokenization_flaubert.py +3 -5
- transformers/models/flava/configuration_flava.py +9 -9
- transformers/models/flava/image_processing_flava.py +66 -67
- transformers/models/flava/image_processing_flava_fast.py +46 -47
- transformers/models/flava/modeling_flava.py +144 -135
- transformers/models/flava/processing_flava.py +2 -12
- transformers/models/flex_olmo/__init__.py +0 -1
- transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
- transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
- transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
- transformers/models/florence2/configuration_florence2.py +4 -1
- transformers/models/florence2/modeling_florence2.py +96 -72
- transformers/models/florence2/modular_florence2.py +100 -107
- transformers/models/florence2/processing_florence2.py +18 -47
- transformers/models/fnet/configuration_fnet.py +6 -2
- transformers/models/fnet/modeling_fnet.py +69 -80
- transformers/models/fnet/tokenization_fnet.py +0 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -5
- transformers/models/focalnet/modeling_focalnet.py +49 -48
- transformers/models/fsmt/configuration_fsmt.py +12 -17
- transformers/models/fsmt/modeling_fsmt.py +47 -48
- transformers/models/fsmt/tokenization_fsmt.py +3 -5
- transformers/models/funnel/configuration_funnel.py +8 -1
- transformers/models/funnel/modeling_funnel.py +91 -93
- transformers/models/funnel/tokenization_funnel.py +2 -5
- transformers/models/fuyu/configuration_fuyu.py +28 -34
- transformers/models/fuyu/image_processing_fuyu.py +29 -31
- transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
- transformers/models/fuyu/modeling_fuyu.py +50 -52
- transformers/models/fuyu/processing_fuyu.py +9 -36
- transformers/models/gemma/configuration_gemma.py +25 -30
- transformers/models/gemma/modeling_gemma.py +36 -38
- transformers/models/gemma/modular_gemma.py +33 -36
- transformers/models/gemma/tokenization_gemma.py +3 -6
- transformers/models/gemma2/configuration_gemma2.py +30 -35
- transformers/models/gemma2/modeling_gemma2.py +38 -41
- transformers/models/gemma2/modular_gemma2.py +63 -67
- transformers/models/gemma3/configuration_gemma3.py +53 -48
- transformers/models/gemma3/image_processing_gemma3.py +29 -31
- transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
- transformers/models/gemma3/modeling_gemma3.py +123 -122
- transformers/models/gemma3/modular_gemma3.py +128 -125
- transformers/models/gemma3/processing_gemma3.py +5 -5
- transformers/models/gemma3n/configuration_gemma3n.py +42 -30
- transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
- transformers/models/gemma3n/modeling_gemma3n.py +166 -147
- transformers/models/gemma3n/modular_gemma3n.py +176 -148
- transformers/models/gemma3n/processing_gemma3n.py +12 -26
- transformers/models/git/configuration_git.py +5 -8
- transformers/models/git/modeling_git.py +115 -127
- transformers/models/git/processing_git.py +2 -14
- transformers/models/glm/configuration_glm.py +26 -30
- transformers/models/glm/modeling_glm.py +36 -39
- transformers/models/glm/modular_glm.py +4 -7
- transformers/models/glm4/configuration_glm4.py +26 -30
- transformers/models/glm4/modeling_glm4.py +39 -41
- transformers/models/glm4/modular_glm4.py +8 -10
- transformers/models/glm46v/configuration_glm46v.py +4 -1
- transformers/models/glm46v/image_processing_glm46v.py +40 -38
- transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
- transformers/models/glm46v/modeling_glm46v.py +138 -93
- transformers/models/glm46v/modular_glm46v.py +5 -3
- transformers/models/glm46v/processing_glm46v.py +7 -41
- transformers/models/glm46v/video_processing_glm46v.py +9 -11
- transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
- transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
- transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
- transformers/models/glm4_moe_lite/__init__.py +28 -0
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
- transformers/models/glm4v/configuration_glm4v.py +25 -24
- transformers/models/glm4v/image_processing_glm4v.py +39 -38
- transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
- transformers/models/glm4v/modeling_glm4v.py +249 -210
- transformers/models/glm4v/modular_glm4v.py +211 -230
- transformers/models/glm4v/processing_glm4v.py +7 -41
- transformers/models/glm4v/video_processing_glm4v.py +9 -11
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
- transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
- transformers/models/glm_image/__init__.py +31 -0
- transformers/models/glm_image/configuration_glm_image.py +358 -0
- transformers/models/glm_image/image_processing_glm_image.py +503 -0
- transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
- transformers/models/glm_image/modeling_glm_image.py +1691 -0
- transformers/models/glm_image/modular_glm_image.py +1640 -0
- transformers/models/glm_image/processing_glm_image.py +265 -0
- transformers/models/glm_ocr/__init__.py +28 -0
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/__init__.py +0 -1
- transformers/models/glmasr/configuration_glmasr.py +0 -1
- transformers/models/glmasr/modeling_glmasr.py +51 -46
- transformers/models/glmasr/modular_glmasr.py +39 -29
- transformers/models/glmasr/processing_glmasr.py +7 -8
- transformers/models/glpn/configuration_glpn.py +0 -1
- transformers/models/glpn/image_processing_glpn.py +11 -12
- transformers/models/glpn/image_processing_glpn_fast.py +11 -12
- transformers/models/glpn/modeling_glpn.py +14 -14
- transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
- transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
- transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
- transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
- transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
- transformers/models/gpt2/configuration_gpt2.py +13 -2
- transformers/models/gpt2/modeling_gpt2.py +111 -113
- transformers/models/gpt2/tokenization_gpt2.py +6 -9
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
- transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
- transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
- transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
- transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
- transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
- transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
- transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
- transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
- transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
- transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
- transformers/models/gptj/configuration_gptj.py +4 -5
- transformers/models/gptj/modeling_gptj.py +85 -88
- transformers/models/granite/configuration_granite.py +28 -33
- transformers/models/granite/modeling_granite.py +43 -45
- transformers/models/granite/modular_granite.py +29 -31
- transformers/models/granite_speech/configuration_granite_speech.py +0 -1
- transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
- transformers/models/granite_speech/modeling_granite_speech.py +84 -60
- transformers/models/granite_speech/processing_granite_speech.py +11 -4
- transformers/models/granitemoe/configuration_granitemoe.py +31 -36
- transformers/models/granitemoe/modeling_granitemoe.py +39 -41
- transformers/models/granitemoe/modular_granitemoe.py +21 -23
- transformers/models/granitemoehybrid/__init__.py +0 -1
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
- transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
- transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
- transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
- transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
- transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
- transformers/models/groupvit/configuration_groupvit.py +4 -2
- transformers/models/groupvit/modeling_groupvit.py +98 -92
- transformers/models/helium/configuration_helium.py +25 -29
- transformers/models/helium/modeling_helium.py +37 -40
- transformers/models/helium/modular_helium.py +3 -7
- transformers/models/herbert/tokenization_herbert.py +4 -6
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
- transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
- transformers/models/hiera/configuration_hiera.py +2 -5
- transformers/models/hiera/modeling_hiera.py +71 -70
- transformers/models/hubert/configuration_hubert.py +4 -2
- transformers/models/hubert/modeling_hubert.py +42 -41
- transformers/models/hubert/modular_hubert.py +8 -11
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
- transformers/models/ibert/configuration_ibert.py +4 -2
- transformers/models/ibert/modeling_ibert.py +60 -62
- transformers/models/ibert/quant_modules.py +0 -1
- transformers/models/idefics/configuration_idefics.py +5 -8
- transformers/models/idefics/image_processing_idefics.py +13 -15
- transformers/models/idefics/modeling_idefics.py +63 -65
- transformers/models/idefics/perceiver.py +1 -3
- transformers/models/idefics/processing_idefics.py +32 -48
- transformers/models/idefics/vision.py +27 -28
- transformers/models/idefics2/configuration_idefics2.py +1 -3
- transformers/models/idefics2/image_processing_idefics2.py +31 -32
- transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
- transformers/models/idefics2/modeling_idefics2.py +126 -106
- transformers/models/idefics2/processing_idefics2.py +10 -68
- transformers/models/idefics3/configuration_idefics3.py +1 -4
- transformers/models/idefics3/image_processing_idefics3.py +42 -43
- transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
- transformers/models/idefics3/modeling_idefics3.py +113 -92
- transformers/models/idefics3/processing_idefics3.py +15 -69
- transformers/models/ijepa/configuration_ijepa.py +0 -1
- transformers/models/ijepa/modeling_ijepa.py +13 -14
- transformers/models/ijepa/modular_ijepa.py +5 -7
- transformers/models/imagegpt/configuration_imagegpt.py +9 -2
- transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
- transformers/models/imagegpt/modeling_imagegpt.py +65 -62
- transformers/models/informer/configuration_informer.py +6 -9
- transformers/models/informer/modeling_informer.py +87 -89
- transformers/models/informer/modular_informer.py +13 -16
- transformers/models/instructblip/configuration_instructblip.py +2 -2
- transformers/models/instructblip/modeling_instructblip.py +104 -79
- transformers/models/instructblip/processing_instructblip.py +10 -36
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
- transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
- transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
- transformers/models/internvl/configuration_internvl.py +5 -1
- transformers/models/internvl/modeling_internvl.py +76 -98
- transformers/models/internvl/modular_internvl.py +45 -59
- transformers/models/internvl/processing_internvl.py +12 -45
- transformers/models/internvl/video_processing_internvl.py +10 -11
- transformers/models/jais2/configuration_jais2.py +25 -29
- transformers/models/jais2/modeling_jais2.py +36 -38
- transformers/models/jais2/modular_jais2.py +20 -22
- transformers/models/jamba/configuration_jamba.py +5 -8
- transformers/models/jamba/modeling_jamba.py +47 -50
- transformers/models/jamba/modular_jamba.py +40 -41
- transformers/models/janus/configuration_janus.py +0 -1
- transformers/models/janus/image_processing_janus.py +37 -39
- transformers/models/janus/image_processing_janus_fast.py +20 -21
- transformers/models/janus/modeling_janus.py +103 -188
- transformers/models/janus/modular_janus.py +122 -83
- transformers/models/janus/processing_janus.py +17 -43
- transformers/models/jetmoe/configuration_jetmoe.py +26 -27
- transformers/models/jetmoe/modeling_jetmoe.py +42 -45
- transformers/models/jetmoe/modular_jetmoe.py +33 -36
- transformers/models/kosmos2/configuration_kosmos2.py +10 -9
- transformers/models/kosmos2/modeling_kosmos2.py +199 -178
- transformers/models/kosmos2/processing_kosmos2.py +40 -55
- transformers/models/kosmos2_5/__init__.py +0 -1
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
- transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
- transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
- transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
- transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
- transformers/models/lasr/configuration_lasr.py +3 -7
- transformers/models/lasr/feature_extraction_lasr.py +10 -12
- transformers/models/lasr/modeling_lasr.py +21 -24
- transformers/models/lasr/modular_lasr.py +11 -13
- transformers/models/lasr/processing_lasr.py +12 -6
- transformers/models/lasr/tokenization_lasr.py +2 -4
- transformers/models/layoutlm/configuration_layoutlm.py +14 -2
- transformers/models/layoutlm/modeling_layoutlm.py +70 -72
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
- transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
- transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
- transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
- transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
- transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
- transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
- transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
- transformers/models/led/configuration_led.py +8 -12
- transformers/models/led/modeling_led.py +113 -267
- transformers/models/levit/configuration_levit.py +0 -1
- transformers/models/levit/image_processing_levit.py +19 -21
- transformers/models/levit/image_processing_levit_fast.py +4 -5
- transformers/models/levit/modeling_levit.py +17 -19
- transformers/models/lfm2/configuration_lfm2.py +27 -30
- transformers/models/lfm2/modeling_lfm2.py +46 -48
- transformers/models/lfm2/modular_lfm2.py +32 -32
- transformers/models/lfm2_moe/__init__.py +0 -1
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
- transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
- transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
- transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
- transformers/models/lightglue/image_processing_lightglue.py +16 -15
- transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
- transformers/models/lightglue/modeling_lightglue.py +31 -33
- transformers/models/lightglue/modular_lightglue.py +31 -31
- transformers/models/lighton_ocr/__init__.py +28 -0
- transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
- transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
- transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
- transformers/models/lilt/configuration_lilt.py +6 -2
- transformers/models/lilt/modeling_lilt.py +53 -55
- transformers/models/llama/configuration_llama.py +26 -31
- transformers/models/llama/modeling_llama.py +35 -38
- transformers/models/llama/tokenization_llama.py +2 -4
- transformers/models/llama4/configuration_llama4.py +87 -69
- transformers/models/llama4/image_processing_llama4_fast.py +11 -12
- transformers/models/llama4/modeling_llama4.py +116 -115
- transformers/models/llama4/processing_llama4.py +33 -57
- transformers/models/llava/configuration_llava.py +10 -1
- transformers/models/llava/image_processing_llava.py +25 -28
- transformers/models/llava/image_processing_llava_fast.py +9 -10
- transformers/models/llava/modeling_llava.py +73 -102
- transformers/models/llava/processing_llava.py +18 -51
- transformers/models/llava_next/configuration_llava_next.py +2 -2
- transformers/models/llava_next/image_processing_llava_next.py +43 -45
- transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
- transformers/models/llava_next/modeling_llava_next.py +103 -104
- transformers/models/llava_next/processing_llava_next.py +18 -47
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
- transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
- transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
- transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
- transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
- transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
- transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
- transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
- transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
- transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
- transformers/models/longcat_flash/__init__.py +0 -1
- transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
- transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
- transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
- transformers/models/longformer/configuration_longformer.py +5 -5
- transformers/models/longformer/modeling_longformer.py +99 -101
- transformers/models/longt5/configuration_longt5.py +9 -7
- transformers/models/longt5/modeling_longt5.py +45 -45
- transformers/models/luke/configuration_luke.py +8 -2
- transformers/models/luke/modeling_luke.py +179 -181
- transformers/models/luke/tokenization_luke.py +99 -105
- transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
- transformers/models/lw_detr/configuration_lw_detr.py +362 -0
- transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
- transformers/models/lw_detr/modular_lw_detr.py +1609 -0
- transformers/models/lxmert/configuration_lxmert.py +16 -1
- transformers/models/lxmert/modeling_lxmert.py +63 -74
- transformers/models/m2m_100/configuration_m2m_100.py +7 -9
- transformers/models/m2m_100/modeling_m2m_100.py +72 -74
- transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
- transformers/models/mamba/configuration_mamba.py +5 -3
- transformers/models/mamba/modeling_mamba.py +61 -70
- transformers/models/mamba2/configuration_mamba2.py +5 -8
- transformers/models/mamba2/modeling_mamba2.py +66 -79
- transformers/models/marian/configuration_marian.py +10 -5
- transformers/models/marian/modeling_marian.py +88 -90
- transformers/models/marian/tokenization_marian.py +6 -6
- transformers/models/markuplm/configuration_markuplm.py +4 -7
- transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
- transformers/models/markuplm/modeling_markuplm.py +63 -65
- transformers/models/markuplm/processing_markuplm.py +31 -38
- transformers/models/markuplm/tokenization_markuplm.py +67 -77
- transformers/models/mask2former/configuration_mask2former.py +14 -52
- transformers/models/mask2former/image_processing_mask2former.py +84 -85
- transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
- transformers/models/mask2former/modeling_mask2former.py +108 -104
- transformers/models/mask2former/modular_mask2former.py +6 -8
- transformers/models/maskformer/configuration_maskformer.py +17 -51
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
- transformers/models/maskformer/image_processing_maskformer.py +84 -85
- transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
- transformers/models/maskformer/modeling_maskformer.py +71 -67
- transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
- transformers/models/mbart/configuration_mbart.py +9 -5
- transformers/models/mbart/modeling_mbart.py +120 -119
- transformers/models/mbart/tokenization_mbart.py +2 -4
- transformers/models/mbart50/tokenization_mbart50.py +3 -5
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
- transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
- transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
- transformers/models/mgp_str/configuration_mgp_str.py +0 -1
- transformers/models/mgp_str/modeling_mgp_str.py +18 -18
- transformers/models/mgp_str/processing_mgp_str.py +3 -20
- transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
- transformers/models/mimi/configuration_mimi.py +42 -40
- transformers/models/mimi/modeling_mimi.py +116 -115
- transformers/models/minimax/__init__.py +0 -1
- transformers/models/minimax/configuration_minimax.py +40 -47
- transformers/models/minimax/modeling_minimax.py +46 -49
- transformers/models/minimax/modular_minimax.py +59 -65
- transformers/models/minimax_m2/__init__.py +28 -0
- transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
- transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
- transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
- transformers/models/ministral/configuration_ministral.py +25 -29
- transformers/models/ministral/modeling_ministral.py +35 -37
- transformers/models/ministral/modular_ministral.py +32 -37
- transformers/models/ministral3/configuration_ministral3.py +23 -26
- transformers/models/ministral3/modeling_ministral3.py +35 -37
- transformers/models/ministral3/modular_ministral3.py +7 -8
- transformers/models/mistral/configuration_mistral.py +24 -29
- transformers/models/mistral/modeling_mistral.py +35 -37
- transformers/models/mistral/modular_mistral.py +14 -15
- transformers/models/mistral3/configuration_mistral3.py +4 -1
- transformers/models/mistral3/modeling_mistral3.py +79 -82
- transformers/models/mistral3/modular_mistral3.py +66 -67
- transformers/models/mixtral/configuration_mixtral.py +32 -38
- transformers/models/mixtral/modeling_mixtral.py +39 -42
- transformers/models/mixtral/modular_mixtral.py +26 -29
- transformers/models/mlcd/configuration_mlcd.py +0 -1
- transformers/models/mlcd/modeling_mlcd.py +17 -17
- transformers/models/mlcd/modular_mlcd.py +16 -16
- transformers/models/mllama/configuration_mllama.py +10 -15
- transformers/models/mllama/image_processing_mllama.py +23 -25
- transformers/models/mllama/image_processing_mllama_fast.py +11 -11
- transformers/models/mllama/modeling_mllama.py +100 -103
- transformers/models/mllama/processing_mllama.py +6 -55
- transformers/models/mluke/tokenization_mluke.py +97 -103
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
- transformers/models/mobilebert/configuration_mobilebert.py +4 -2
- transformers/models/mobilebert/modeling_mobilebert.py +78 -88
- transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
- transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
- transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
- transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
- transformers/models/mobilevit/configuration_mobilevit.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
- transformers/models/mobilevit/modeling_mobilevit.py +21 -21
- transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
- transformers/models/modernbert/configuration_modernbert.py +76 -51
- transformers/models/modernbert/modeling_modernbert.py +188 -943
- transformers/models/modernbert/modular_modernbert.py +255 -978
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
- transformers/models/moonshine/configuration_moonshine.py +34 -31
- transformers/models/moonshine/modeling_moonshine.py +70 -72
- transformers/models/moonshine/modular_moonshine.py +91 -86
- transformers/models/moshi/configuration_moshi.py +46 -23
- transformers/models/moshi/modeling_moshi.py +134 -142
- transformers/models/mpnet/configuration_mpnet.py +6 -2
- transformers/models/mpnet/modeling_mpnet.py +55 -57
- transformers/models/mpnet/tokenization_mpnet.py +1 -4
- transformers/models/mpt/configuration_mpt.py +17 -9
- transformers/models/mpt/modeling_mpt.py +58 -60
- transformers/models/mra/configuration_mra.py +8 -2
- transformers/models/mra/modeling_mra.py +54 -56
- transformers/models/mt5/configuration_mt5.py +9 -6
- transformers/models/mt5/modeling_mt5.py +80 -85
- transformers/models/musicgen/configuration_musicgen.py +12 -8
- transformers/models/musicgen/modeling_musicgen.py +114 -116
- transformers/models/musicgen/processing_musicgen.py +3 -21
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
- transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
- transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
- transformers/models/mvp/configuration_mvp.py +8 -5
- transformers/models/mvp/modeling_mvp.py +121 -123
- transformers/models/myt5/tokenization_myt5.py +8 -10
- transformers/models/nanochat/configuration_nanochat.py +5 -8
- transformers/models/nanochat/modeling_nanochat.py +36 -39
- transformers/models/nanochat/modular_nanochat.py +16 -18
- transformers/models/nemotron/configuration_nemotron.py +25 -30
- transformers/models/nemotron/modeling_nemotron.py +53 -66
- transformers/models/nllb/tokenization_nllb.py +14 -14
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
- transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
- transformers/models/nougat/image_processing_nougat.py +29 -32
- transformers/models/nougat/image_processing_nougat_fast.py +12 -13
- transformers/models/nougat/processing_nougat.py +37 -39
- transformers/models/nougat/tokenization_nougat.py +5 -7
- transformers/models/nystromformer/configuration_nystromformer.py +8 -2
- transformers/models/nystromformer/modeling_nystromformer.py +61 -63
- transformers/models/olmo/configuration_olmo.py +23 -28
- transformers/models/olmo/modeling_olmo.py +35 -38
- transformers/models/olmo/modular_olmo.py +8 -12
- transformers/models/olmo2/configuration_olmo2.py +27 -32
- transformers/models/olmo2/modeling_olmo2.py +36 -39
- transformers/models/olmo2/modular_olmo2.py +36 -38
- transformers/models/olmo3/__init__.py +0 -1
- transformers/models/olmo3/configuration_olmo3.py +30 -34
- transformers/models/olmo3/modeling_olmo3.py +35 -38
- transformers/models/olmo3/modular_olmo3.py +44 -47
- transformers/models/olmoe/configuration_olmoe.py +29 -33
- transformers/models/olmoe/modeling_olmoe.py +41 -43
- transformers/models/olmoe/modular_olmoe.py +15 -16
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
- transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
- transformers/models/oneformer/configuration_oneformer.py +11 -51
- transformers/models/oneformer/image_processing_oneformer.py +83 -84
- transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
- transformers/models/oneformer/modeling_oneformer.py +137 -133
- transformers/models/oneformer/processing_oneformer.py +28 -43
- transformers/models/openai/configuration_openai.py +16 -1
- transformers/models/openai/modeling_openai.py +50 -51
- transformers/models/openai/tokenization_openai.py +2 -5
- transformers/models/opt/configuration_opt.py +6 -7
- transformers/models/opt/modeling_opt.py +79 -80
- transformers/models/ovis2/__init__.py +0 -1
- transformers/models/ovis2/configuration_ovis2.py +4 -1
- transformers/models/ovis2/image_processing_ovis2.py +22 -24
- transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
- transformers/models/ovis2/modeling_ovis2.py +99 -142
- transformers/models/ovis2/modular_ovis2.py +82 -45
- transformers/models/ovis2/processing_ovis2.py +12 -40
- transformers/models/owlv2/configuration_owlv2.py +4 -2
- transformers/models/owlv2/image_processing_owlv2.py +20 -21
- transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
- transformers/models/owlv2/modeling_owlv2.py +122 -114
- transformers/models/owlv2/modular_owlv2.py +11 -12
- transformers/models/owlv2/processing_owlv2.py +20 -49
- transformers/models/owlvit/configuration_owlvit.py +4 -2
- transformers/models/owlvit/image_processing_owlvit.py +21 -22
- transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
- transformers/models/owlvit/modeling_owlvit.py +121 -113
- transformers/models/owlvit/processing_owlvit.py +20 -48
- transformers/models/paddleocr_vl/__init__.py +0 -1
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
- transformers/models/paligemma/configuration_paligemma.py +4 -1
- transformers/models/paligemma/modeling_paligemma.py +81 -79
- transformers/models/paligemma/processing_paligemma.py +13 -66
- transformers/models/parakeet/configuration_parakeet.py +3 -8
- transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
- transformers/models/parakeet/modeling_parakeet.py +21 -25
- transformers/models/parakeet/modular_parakeet.py +19 -21
- transformers/models/parakeet/processing_parakeet.py +12 -5
- transformers/models/parakeet/tokenization_parakeet.py +2 -4
- transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
- transformers/models/patchtst/configuration_patchtst.py +6 -9
- transformers/models/patchtst/modeling_patchtst.py +75 -77
- transformers/models/pe_audio/__init__.py +0 -1
- transformers/models/pe_audio/configuration_pe_audio.py +14 -16
- transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
- transformers/models/pe_audio/modeling_pe_audio.py +30 -31
- transformers/models/pe_audio/modular_pe_audio.py +17 -18
- transformers/models/pe_audio/processing_pe_audio.py +0 -1
- transformers/models/pe_audio_video/__init__.py +0 -1
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
- transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
- transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
- transformers/models/pe_video/__init__.py +0 -1
- transformers/models/pe_video/configuration_pe_video.py +14 -16
- transformers/models/pe_video/modeling_pe_video.py +57 -46
- transformers/models/pe_video/modular_pe_video.py +47 -35
- transformers/models/pe_video/video_processing_pe_video.py +2 -4
- transformers/models/pegasus/configuration_pegasus.py +8 -6
- transformers/models/pegasus/modeling_pegasus.py +67 -69
- transformers/models/pegasus/tokenization_pegasus.py +1 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
- transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
- transformers/models/perceiver/configuration_perceiver.py +0 -1
- transformers/models/perceiver/image_processing_perceiver.py +22 -25
- transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
- transformers/models/perceiver/modeling_perceiver.py +152 -145
- transformers/models/perceiver/tokenization_perceiver.py +3 -6
- transformers/models/perception_lm/configuration_perception_lm.py +0 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
- transformers/models/perception_lm/modeling_perception_lm.py +64 -67
- transformers/models/perception_lm/modular_perception_lm.py +58 -58
- transformers/models/perception_lm/processing_perception_lm.py +13 -47
- transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
- transformers/models/persimmon/configuration_persimmon.py +23 -28
- transformers/models/persimmon/modeling_persimmon.py +44 -47
- transformers/models/phi/configuration_phi.py +27 -28
- transformers/models/phi/modeling_phi.py +39 -41
- transformers/models/phi/modular_phi.py +26 -26
- transformers/models/phi3/configuration_phi3.py +32 -37
- transformers/models/phi3/modeling_phi3.py +37 -40
- transformers/models/phi3/modular_phi3.py +16 -20
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
- transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
- transformers/models/phimoe/configuration_phimoe.py +31 -36
- transformers/models/phimoe/modeling_phimoe.py +50 -77
- transformers/models/phimoe/modular_phimoe.py +12 -8
- transformers/models/phobert/tokenization_phobert.py +4 -6
- transformers/models/pix2struct/configuration_pix2struct.py +12 -10
- transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
- transformers/models/pix2struct/modeling_pix2struct.py +56 -52
- transformers/models/pix2struct/processing_pix2struct.py +5 -26
- transformers/models/pixio/__init__.py +0 -1
- transformers/models/pixio/configuration_pixio.py +2 -5
- transformers/models/pixio/modeling_pixio.py +16 -17
- transformers/models/pixio/modular_pixio.py +7 -8
- transformers/models/pixtral/configuration_pixtral.py +11 -14
- transformers/models/pixtral/image_processing_pixtral.py +26 -28
- transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
- transformers/models/pixtral/modeling_pixtral.py +31 -37
- transformers/models/pixtral/processing_pixtral.py +18 -52
- transformers/models/plbart/configuration_plbart.py +8 -6
- transformers/models/plbart/modeling_plbart.py +109 -109
- transformers/models/plbart/modular_plbart.py +31 -33
- transformers/models/plbart/tokenization_plbart.py +4 -5
- transformers/models/poolformer/configuration_poolformer.py +0 -1
- transformers/models/poolformer/image_processing_poolformer.py +21 -24
- transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
- transformers/models/poolformer/modeling_poolformer.py +10 -12
- transformers/models/pop2piano/configuration_pop2piano.py +7 -7
- transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
- transformers/models/pop2piano/modeling_pop2piano.py +24 -24
- transformers/models/pop2piano/processing_pop2piano.py +25 -33
- transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
- transformers/models/prophetnet/configuration_prophetnet.py +37 -38
- transformers/models/prophetnet/modeling_prophetnet.py +121 -153
- transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
- transformers/models/pvt/configuration_pvt.py +0 -1
- transformers/models/pvt/image_processing_pvt.py +24 -27
- transformers/models/pvt/image_processing_pvt_fast.py +1 -2
- transformers/models/pvt/modeling_pvt.py +19 -21
- transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
- transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
- transformers/models/qwen2/configuration_qwen2.py +32 -25
- transformers/models/qwen2/modeling_qwen2.py +35 -37
- transformers/models/qwen2/modular_qwen2.py +14 -15
- transformers/models/qwen2/tokenization_qwen2.py +2 -9
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
- transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
- transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
- transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
- transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
- transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
- transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
- transformers/models/qwen3/configuration_qwen3.py +34 -27
- transformers/models/qwen3/modeling_qwen3.py +35 -38
- transformers/models/qwen3/modular_qwen3.py +7 -9
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
- transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
- transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
- transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
- transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
- transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
- transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
- transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
- transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
- transformers/models/rag/configuration_rag.py +6 -7
- transformers/models/rag/modeling_rag.py +119 -121
- transformers/models/rag/retrieval_rag.py +3 -5
- transformers/models/rag/tokenization_rag.py +0 -50
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
- transformers/models/reformer/configuration_reformer.py +7 -8
- transformers/models/reformer/modeling_reformer.py +67 -68
- transformers/models/reformer/tokenization_reformer.py +3 -6
- transformers/models/regnet/configuration_regnet.py +0 -1
- transformers/models/regnet/modeling_regnet.py +7 -9
- transformers/models/rembert/configuration_rembert.py +8 -2
- transformers/models/rembert/modeling_rembert.py +108 -132
- transformers/models/rembert/tokenization_rembert.py +1 -4
- transformers/models/resnet/configuration_resnet.py +2 -5
- transformers/models/resnet/modeling_resnet.py +14 -15
- transformers/models/roberta/configuration_roberta.py +11 -3
- transformers/models/roberta/modeling_roberta.py +97 -99
- transformers/models/roberta/modular_roberta.py +55 -58
- transformers/models/roberta/tokenization_roberta.py +2 -5
- transformers/models/roberta/tokenization_roberta_old.py +2 -4
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
- transformers/models/roc_bert/configuration_roc_bert.py +8 -2
- transformers/models/roc_bert/modeling_roc_bert.py +125 -162
- transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
- transformers/models/roformer/configuration_roformer.py +13 -3
- transformers/models/roformer/modeling_roformer.py +79 -95
- transformers/models/roformer/tokenization_roformer.py +3 -6
- transformers/models/roformer/tokenization_utils.py +0 -1
- transformers/models/rt_detr/configuration_rt_detr.py +8 -50
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
- transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
- transformers/models/rt_detr/modeling_rt_detr.py +643 -804
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
- transformers/models/rt_detr/modular_rt_detr.py +1522 -20
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
- transformers/models/rwkv/configuration_rwkv.py +2 -4
- transformers/models/rwkv/modeling_rwkv.py +29 -54
- transformers/models/sam/configuration_sam.py +2 -1
- transformers/models/sam/image_processing_sam.py +59 -60
- transformers/models/sam/image_processing_sam_fast.py +25 -26
- transformers/models/sam/modeling_sam.py +46 -43
- transformers/models/sam/processing_sam.py +39 -27
- transformers/models/sam2/configuration_sam2.py +1 -2
- transformers/models/sam2/image_processing_sam2_fast.py +14 -15
- transformers/models/sam2/modeling_sam2.py +96 -94
- transformers/models/sam2/modular_sam2.py +85 -94
- transformers/models/sam2/processing_sam2.py +31 -47
- transformers/models/sam2_video/configuration_sam2_video.py +0 -1
- transformers/models/sam2_video/modeling_sam2_video.py +114 -116
- transformers/models/sam2_video/modular_sam2_video.py +72 -89
- transformers/models/sam2_video/processing_sam2_video.py +49 -66
- transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
- transformers/models/sam3/configuration_sam3.py +0 -1
- transformers/models/sam3/image_processing_sam3_fast.py +17 -20
- transformers/models/sam3/modeling_sam3.py +94 -100
- transformers/models/sam3/modular_sam3.py +3 -8
- transformers/models/sam3/processing_sam3.py +37 -52
- transformers/models/sam3_tracker/__init__.py +0 -1
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
- transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
- transformers/models/sam3_tracker_video/__init__.py +0 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
- transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
- transformers/models/sam3_video/configuration_sam3_video.py +0 -1
- transformers/models/sam3_video/modeling_sam3_video.py +56 -45
- transformers/models/sam3_video/processing_sam3_video.py +25 -45
- transformers/models/sam_hq/__init__.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +2 -1
- transformers/models/sam_hq/modeling_sam_hq.py +52 -50
- transformers/models/sam_hq/modular_sam_hq.py +23 -25
- transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
- transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
- transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
- transformers/models/seed_oss/configuration_seed_oss.py +30 -34
- transformers/models/seed_oss/modeling_seed_oss.py +34 -36
- transformers/models/seed_oss/modular_seed_oss.py +6 -7
- transformers/models/segformer/configuration_segformer.py +0 -10
- transformers/models/segformer/image_processing_segformer.py +39 -42
- transformers/models/segformer/image_processing_segformer_fast.py +11 -12
- transformers/models/segformer/modeling_segformer.py +28 -28
- transformers/models/segformer/modular_segformer.py +8 -9
- transformers/models/seggpt/configuration_seggpt.py +0 -1
- transformers/models/seggpt/image_processing_seggpt.py +38 -41
- transformers/models/seggpt/modeling_seggpt.py +48 -38
- transformers/models/sew/configuration_sew.py +4 -2
- transformers/models/sew/modeling_sew.py +42 -40
- transformers/models/sew/modular_sew.py +12 -13
- transformers/models/sew_d/configuration_sew_d.py +4 -2
- transformers/models/sew_d/modeling_sew_d.py +32 -31
- transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
- transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
- transformers/models/siglip/configuration_siglip.py +4 -2
- transformers/models/siglip/image_processing_siglip.py +17 -20
- transformers/models/siglip/image_processing_siglip_fast.py +0 -1
- transformers/models/siglip/modeling_siglip.py +65 -110
- transformers/models/siglip/processing_siglip.py +2 -14
- transformers/models/siglip/tokenization_siglip.py +6 -7
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2.py +15 -16
- transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
- transformers/models/siglip2/modeling_siglip2.py +89 -130
- transformers/models/siglip2/modular_siglip2.py +95 -48
- transformers/models/siglip2/processing_siglip2.py +2 -14
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +29 -32
- transformers/models/smollm3/modeling_smollm3.py +35 -38
- transformers/models/smollm3/modular_smollm3.py +36 -38
- transformers/models/smolvlm/configuration_smolvlm.py +2 -4
- transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
- transformers/models/smolvlm/modeling_smolvlm.py +124 -96
- transformers/models/smolvlm/modular_smolvlm.py +50 -39
- transformers/models/smolvlm/processing_smolvlm.py +15 -76
- transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
- transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
- transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
- transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
- transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
- transformers/models/speecht5/configuration_speecht5.py +7 -9
- transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
- transformers/models/speecht5/modeling_speecht5.py +172 -174
- transformers/models/speecht5/number_normalizer.py +0 -1
- transformers/models/speecht5/processing_speecht5.py +3 -37
- transformers/models/speecht5/tokenization_speecht5.py +4 -5
- transformers/models/splinter/configuration_splinter.py +6 -7
- transformers/models/splinter/modeling_splinter.py +62 -59
- transformers/models/splinter/tokenization_splinter.py +2 -4
- transformers/models/squeezebert/configuration_squeezebert.py +14 -2
- transformers/models/squeezebert/modeling_squeezebert.py +60 -62
- transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
- transformers/models/stablelm/configuration_stablelm.py +28 -29
- transformers/models/stablelm/modeling_stablelm.py +44 -47
- transformers/models/starcoder2/configuration_starcoder2.py +30 -27
- transformers/models/starcoder2/modeling_starcoder2.py +38 -41
- transformers/models/starcoder2/modular_starcoder2.py +17 -19
- transformers/models/superglue/configuration_superglue.py +7 -3
- transformers/models/superglue/image_processing_superglue.py +15 -15
- transformers/models/superglue/image_processing_superglue_fast.py +8 -8
- transformers/models/superglue/modeling_superglue.py +41 -37
- transformers/models/superpoint/image_processing_superpoint.py +15 -15
- transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
- transformers/models/superpoint/modeling_superpoint.py +17 -16
- transformers/models/swiftformer/configuration_swiftformer.py +0 -1
- transformers/models/swiftformer/modeling_swiftformer.py +12 -14
- transformers/models/swin/configuration_swin.py +2 -5
- transformers/models/swin/modeling_swin.py +69 -78
- transformers/models/swin2sr/configuration_swin2sr.py +0 -1
- transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
- transformers/models/swin2sr/modeling_swin2sr.py +30 -30
- transformers/models/swinv2/configuration_swinv2.py +2 -5
- transformers/models/swinv2/modeling_swinv2.py +65 -74
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
- transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
- transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
- transformers/models/t5/configuration_t5.py +9 -9
- transformers/models/t5/modeling_t5.py +80 -85
- transformers/models/t5/tokenization_t5.py +1 -3
- transformers/models/t5gemma/configuration_t5gemma.py +43 -59
- transformers/models/t5gemma/modeling_t5gemma.py +105 -108
- transformers/models/t5gemma/modular_t5gemma.py +128 -142
- transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
- transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
- transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
- transformers/models/table_transformer/configuration_table_transformer.py +18 -50
- transformers/models/table_transformer/modeling_table_transformer.py +73 -101
- transformers/models/tapas/configuration_tapas.py +12 -2
- transformers/models/tapas/modeling_tapas.py +65 -67
- transformers/models/tapas/tokenization_tapas.py +116 -153
- transformers/models/textnet/configuration_textnet.py +4 -7
- transformers/models/textnet/image_processing_textnet.py +22 -25
- transformers/models/textnet/image_processing_textnet_fast.py +8 -9
- transformers/models/textnet/modeling_textnet.py +28 -28
- transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
- transformers/models/timesfm/configuration_timesfm.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +22 -25
- transformers/models/timesfm/modular_timesfm.py +21 -24
- transformers/models/timesformer/configuration_timesformer.py +0 -1
- transformers/models/timesformer/modeling_timesformer.py +13 -16
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
- transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
- transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
- transformers/models/trocr/configuration_trocr.py +11 -8
- transformers/models/trocr/modeling_trocr.py +42 -42
- transformers/models/trocr/processing_trocr.py +5 -25
- transformers/models/tvp/configuration_tvp.py +10 -36
- transformers/models/tvp/image_processing_tvp.py +50 -52
- transformers/models/tvp/image_processing_tvp_fast.py +15 -15
- transformers/models/tvp/modeling_tvp.py +26 -28
- transformers/models/tvp/processing_tvp.py +2 -14
- transformers/models/udop/configuration_udop.py +16 -8
- transformers/models/udop/modeling_udop.py +73 -72
- transformers/models/udop/processing_udop.py +7 -26
- transformers/models/udop/tokenization_udop.py +80 -93
- transformers/models/umt5/configuration_umt5.py +8 -7
- transformers/models/umt5/modeling_umt5.py +87 -84
- transformers/models/unispeech/configuration_unispeech.py +4 -2
- transformers/models/unispeech/modeling_unispeech.py +54 -53
- transformers/models/unispeech/modular_unispeech.py +20 -22
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
- transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
- transformers/models/univnet/feature_extraction_univnet.py +14 -14
- transformers/models/univnet/modeling_univnet.py +7 -8
- transformers/models/upernet/configuration_upernet.py +8 -36
- transformers/models/upernet/modeling_upernet.py +11 -14
- transformers/models/vaultgemma/__init__.py +0 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
- transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
- transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
- transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
- transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
- transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
- transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
- transformers/models/video_llava/configuration_video_llava.py +4 -1
- transformers/models/video_llava/image_processing_video_llava.py +35 -38
- transformers/models/video_llava/modeling_video_llava.py +139 -143
- transformers/models/video_llava/processing_video_llava.py +38 -78
- transformers/models/video_llava/video_processing_video_llava.py +0 -1
- transformers/models/videomae/configuration_videomae.py +0 -1
- transformers/models/videomae/image_processing_videomae.py +31 -34
- transformers/models/videomae/modeling_videomae.py +17 -20
- transformers/models/videomae/video_processing_videomae.py +0 -1
- transformers/models/vilt/configuration_vilt.py +4 -2
- transformers/models/vilt/image_processing_vilt.py +29 -30
- transformers/models/vilt/image_processing_vilt_fast.py +15 -16
- transformers/models/vilt/modeling_vilt.py +103 -90
- transformers/models/vilt/processing_vilt.py +2 -14
- transformers/models/vipllava/configuration_vipllava.py +4 -1
- transformers/models/vipllava/modeling_vipllava.py +92 -67
- transformers/models/vipllava/modular_vipllava.py +78 -54
- transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
- transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
- transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
- transformers/models/visual_bert/configuration_visual_bert.py +6 -2
- transformers/models/visual_bert/modeling_visual_bert.py +90 -92
- transformers/models/vit/configuration_vit.py +2 -3
- transformers/models/vit/image_processing_vit.py +19 -22
- transformers/models/vit/image_processing_vit_fast.py +0 -1
- transformers/models/vit/modeling_vit.py +20 -20
- transformers/models/vit_mae/configuration_vit_mae.py +0 -1
- transformers/models/vit_mae/modeling_vit_mae.py +32 -30
- transformers/models/vit_msn/configuration_vit_msn.py +0 -1
- transformers/models/vit_msn/modeling_vit_msn.py +21 -19
- transformers/models/vitdet/configuration_vitdet.py +2 -5
- transformers/models/vitdet/modeling_vitdet.py +14 -17
- transformers/models/vitmatte/configuration_vitmatte.py +7 -39
- transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
- transformers/models/vitmatte/modeling_vitmatte.py +10 -12
- transformers/models/vitpose/configuration_vitpose.py +7 -47
- transformers/models/vitpose/image_processing_vitpose.py +24 -25
- transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
- transformers/models/vitpose/modeling_vitpose.py +15 -15
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
- transformers/models/vits/configuration_vits.py +4 -1
- transformers/models/vits/modeling_vits.py +43 -42
- transformers/models/vits/tokenization_vits.py +3 -4
- transformers/models/vivit/configuration_vivit.py +0 -1
- transformers/models/vivit/image_processing_vivit.py +36 -39
- transformers/models/vivit/modeling_vivit.py +9 -11
- transformers/models/vjepa2/__init__.py +0 -1
- transformers/models/vjepa2/configuration_vjepa2.py +0 -1
- transformers/models/vjepa2/modeling_vjepa2.py +39 -41
- transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
- transformers/models/voxtral/__init__.py +0 -1
- transformers/models/voxtral/configuration_voxtral.py +0 -2
- transformers/models/voxtral/modeling_voxtral.py +41 -48
- transformers/models/voxtral/modular_voxtral.py +35 -38
- transformers/models/voxtral/processing_voxtral.py +25 -48
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
- transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
- transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
- transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
- transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
- transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
- transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
- transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
- transformers/models/wavlm/configuration_wavlm.py +4 -2
- transformers/models/wavlm/modeling_wavlm.py +49 -49
- transformers/models/wavlm/modular_wavlm.py +4 -5
- transformers/models/whisper/configuration_whisper.py +6 -5
- transformers/models/whisper/english_normalizer.py +3 -4
- transformers/models/whisper/feature_extraction_whisper.py +9 -24
- transformers/models/whisper/generation_whisper.py +26 -49
- transformers/models/whisper/modeling_whisper.py +71 -73
- transformers/models/whisper/processing_whisper.py +3 -20
- transformers/models/whisper/tokenization_whisper.py +9 -30
- transformers/models/x_clip/configuration_x_clip.py +4 -2
- transformers/models/x_clip/modeling_x_clip.py +94 -96
- transformers/models/x_clip/processing_x_clip.py +2 -14
- transformers/models/xcodec/configuration_xcodec.py +4 -6
- transformers/models/xcodec/modeling_xcodec.py +15 -17
- transformers/models/xglm/configuration_xglm.py +9 -8
- transformers/models/xglm/modeling_xglm.py +49 -55
- transformers/models/xglm/tokenization_xglm.py +1 -4
- transformers/models/xlm/configuration_xlm.py +10 -8
- transformers/models/xlm/modeling_xlm.py +127 -131
- transformers/models/xlm/tokenization_xlm.py +3 -5
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
- transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
- transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
- transformers/models/xlnet/configuration_xlnet.py +3 -12
- transformers/models/xlnet/modeling_xlnet.py +149 -162
- transformers/models/xlnet/tokenization_xlnet.py +1 -4
- transformers/models/xlstm/configuration_xlstm.py +8 -12
- transformers/models/xlstm/modeling_xlstm.py +61 -96
- transformers/models/xmod/configuration_xmod.py +11 -3
- transformers/models/xmod/modeling_xmod.py +111 -116
- transformers/models/yolos/configuration_yolos.py +0 -1
- transformers/models/yolos/image_processing_yolos.py +60 -62
- transformers/models/yolos/image_processing_yolos_fast.py +42 -45
- transformers/models/yolos/modeling_yolos.py +19 -21
- transformers/models/yolos/modular_yolos.py +17 -19
- transformers/models/yoso/configuration_yoso.py +8 -2
- transformers/models/yoso/modeling_yoso.py +60 -62
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -8
- transformers/models/zamba/modeling_zamba.py +93 -125
- transformers/models/zamba2/configuration_zamba2.py +44 -50
- transformers/models/zamba2/modeling_zamba2.py +137 -165
- transformers/models/zamba2/modular_zamba2.py +79 -74
- transformers/models/zoedepth/configuration_zoedepth.py +17 -41
- transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
- transformers/models/zoedepth/modeling_zoedepth.py +19 -19
- transformers/pipelines/__init__.py +47 -106
- transformers/pipelines/any_to_any.py +15 -23
- transformers/pipelines/audio_utils.py +1 -2
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +13 -17
- transformers/pipelines/image_text_to_text.py +1 -2
- transformers/pipelines/question_answering.py +4 -43
- transformers/pipelines/text_classification.py +1 -14
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/pipelines/token_classification.py +1 -22
- transformers/pipelines/video_classification.py +1 -9
- transformers/pipelines/zero_shot_audio_classification.py +0 -1
- transformers/pipelines/zero_shot_classification.py +0 -6
- transformers/pipelines/zero_shot_image_classification.py +0 -7
- transformers/processing_utils.py +128 -137
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/base.py +10 -0
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_quark.py +0 -1
- transformers/quantizers/quantizer_torchao.py +3 -19
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +6 -65
- transformers/tokenization_mistral_common.py +563 -903
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +228 -341
- transformers/tokenization_utils_sentencepiece.py +5 -6
- transformers/tokenization_utils_tokenizers.py +36 -7
- transformers/trainer.py +30 -41
- transformers/trainer_jit_checkpoint.py +1 -2
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +414 -420
- transformers/utils/__init__.py +1 -4
- transformers/utils/attention_visualizer.py +1 -1
- transformers/utils/auto_docstring.py +567 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/dummy_pt_objects.py +0 -42
- transformers/utils/generic.py +70 -34
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +135 -107
- transformers/utils/quantization_config.py +8 -31
- transformers/video_processing_utils.py +24 -25
- transformers/video_utils.py +21 -23
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
- transformers-5.1.0.dist-info/RECORD +2092 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -229
- transformers-5.0.0rc2.dist-info/RECORD +0 -2042
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/integrations/tensor_parallel.py
@@ -17,7 +17,7 @@ import math
 import operator
 import os
 import re
-from functools import
+from functools import reduce
 
 from ..distributed import DistributedConfig
 from ..utils import is_torch_greater_or_equal, logging
@@ -33,9 +33,6 @@ if is_torch_available():
     # Cache this result has it's a C FFI call which can be pretty time-consuming
     _torch_distributed_available = torch.distributed.is_available()
 
-    if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
-        from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
-
 
 logger = logging.get_logger(__name__)
 
@@ -68,10 +65,6 @@ def initialize_tensor_parallelism(
 
     backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
     backend = backend_map.get(device_type)
-    if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", "0")):
-        backend = "ccl"
-    if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
-        backend = "ccl"
 
     torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
     current_device = getattr(torch, device_type)
@@ -116,32 +109,6 @@ def initialize_tensor_parallelism(
     return device_map, device_mesh, tp_size
 
 
-def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
-    """
-    Convert block count or proportions to block sizes.
-
-    This function accepts
-
-    - The number of blocks (int), in which case the block size is
-      total_size//blocks; or
-    - A list of block sizes (list[int]).
-
-    In the second case, if sum(blocks) < total_size, the ratios between
-    the block sizes will be preserved. For instance, if blocks is
-    [2, 1, 1] and total_size is 1024, the returned block sizes are
-    [512, 256, 256].
-    """
-    if isinstance(blocks, list):
-        total_blocks = sum(blocks)
-        assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
-        part_size = total_size // total_blocks
-        return [part_size * block for block in blocks]
-    else:
-        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
-        single_size = total_size // blocks
-        return [single_size] * blocks
-
-
 def replace_layer_number_by_wildcard(name: str) -> str:
     """
     Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
@@ -170,6 +137,11 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
     return None
 
 
+# =============================================================================
+# Tensor Sharding Utilities
+# =============================================================================
+
+
 if is_torch_available():
     str_to_dtype = {
         "BOOL": torch.bool,
@@ -186,6 +158,32 @@ if is_torch_available():
     }
 
 
+def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
+    """
+    Convert block count or proportions to block sizes.
+
+    This function accepts
+
+    - The number of blocks (int), in which case the block size is
+      total_size//blocks; or
+    - A list of block sizes (list[int]).
+
+    In the second case, if sum(blocks) < total_size, the ratios between
+    the block sizes will be preserved. For instance, if blocks is
+    [2, 1, 1] and total_size is 1024, the returned block sizes are
+    [512, 256, 256].
+    """
+    if isinstance(blocks, list):
+        total_blocks = sum(blocks)
+        assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
+        part_size = total_size // total_blocks
+        return [part_size * block for block in blocks]
+    else:
+        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+        single_size = total_size // blocks
+        return [single_size] * blocks
+
+
 def get_packed_weights(param, empty_param, device_mesh, rank, dim):
     """
     When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
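
The relocated `_blocks_to_block_sizes` helper is easiest to check against its own docstring. A minimal sketch of its two call modes; the assertions below are ours, derived from the hunk above, not part of the diff:

    # Mode 1: an int asks for an even split into that many blocks.
    assert _blocks_to_block_sizes(total_size=1024, blocks=4) == [256, 256, 256, 256]

    # Mode 2: a list is read as proportions; sum([2, 1, 1]) = 4, so one unit is 1024 // 4 = 256.
    assert _blocks_to_block_sizes(total_size=1024, blocks=[2, 1, 1]) == [512, 256, 256]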
@@ -372,19 +370,20 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int
         dim (int): Dimension along which to shard the tensor.
     """
     param_dim = empty_param.ndim
-    # Flatten the mesh to get the total number of devices
     mesh_shape = device_mesh.shape
     world_size = reduce(operator.mul, mesh_shape)
+    # Get param shape: works for both torch.Tensor and safetensors TensorInfo
+    param_shape = list(param.shape) if isinstance(param, torch.Tensor) else param.get_shape()
     if dim < 0:
         dim = param_dim + dim
-    if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2:
-        dim = 0
-    elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2:
+    if empty_param.dim() == 3 and dim == 1 and len(param_shape) == 2:
         dim = 0
+    elif empty_param.dim() == 3 and dim == 2 and len(param_shape) == 2:
+        dim = 1
 
-    shard_size = math.ceil(param.get_shape()[dim] / world_size)
+    shard_size = math.ceil(param_shape[dim] / world_size)
     start = rank * shard_size
-    end = min(start + shard_size, param.get_shape()[dim])
+    end = min(start + shard_size, param_shape[dim])
 
     if dim >= param_dim:
         raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
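
A standalone re-derivation of the shard bounds computed above; `shard_bounds` is a hypothetical helper for illustration, not part of the diff:

    import math

    def shard_bounds(length: int, world_size: int, rank: int) -> tuple[int, int]:
        # Mirrors get_tensor_shard: ceil-divide, so every rank but possibly the last is full.
        shard_size = math.ceil(length / world_size)
        start = rank * shard_size
        end = min(start + shard_size, length)
        return start, end

    # length=10 over 4 ranks -> shard_size = 3; the last shard is short.
    assert [shard_bounds(10, 4, r) for r in range(4)] == [(0, 3), (3, 6), (6, 9), (9, 10)]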
@@ -401,9 +400,7 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int
     # actually we still shard dim=0 does not change
     # so only case is if the dim of the empty param is 3 and the shard dim is 0 -> we put the
     # tensor on a certain device (with the input tensor_index)
-
-
-    if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2:
+    if tensor_idx is not None and empty_param.dim() == 3 and dim == 0 and len(param_shape) == 2:
         # special case we don't "shard" just send this entire tensor to the correct rank.
         if start <= tensor_idx < end:
             # this tensor does need to be materialized on this device:
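
The new `tensor_idx is not None` guard covers checkpoints that store a 3-D stacked parameter as individual 2-D tensors (per-expert weights is our reading, not stated in the diff): each 2-D tensor is sent whole to the one rank whose `[start, end)` window contains its index. With the ceil-divide bounds above and hypothetical numbers:

    # 8 stacked tensors over world_size=4 -> shard_size = math.ceil(8 / 4) = 2,
    # so rank r owns indices [2*r, 2*r + 2); tensor_idx=5 lands only on rank 2.
    shard_size = 2
    tensor_idx = 5
    assert tensor_idx // shard_size == 2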
@@ -411,17 +408,214 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int
|
|
|
411
408
|
else:
|
|
412
409
|
return torch.empty([], dtype=torch.int64, device=rank)
|
|
413
410
|
|
|
414
|
-
slice_indices = [slice(None)] * len(
|
|
411
|
+
slice_indices = [slice(None)] * len(param_shape)
|
|
415
412
|
|
|
416
|
-
if start <
|
|
413
|
+
if start < param_shape[dim]:
|
|
417
414
|
slice_indices[dim] = slice(start, end)
|
|
418
415
|
param = param[tuple(slice_indices)]
|
|
419
416
|
if isinstance(param, list): # TODO handle the modulelist case!
|
|
420
417
|
param = [p[:] for p in param]
|
|
421
418
|
return param
|
|
422
419
|
|
|
423
|
-
|
|
424
|
-
return torch.empty(tuple(
|
|
420
|
+
param_shape[dim] = 0
|
|
421
|
+
return torch.empty(tuple(param_shape), dtype=torch.int64) # empty allocates memory....
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def _split_along_last_dim(x, world_size):
|
|
425
|
+
"""Split tensor along last dimension into world_size chunks."""
|
|
426
|
+
return torch.chunk(x, world_size, dim=-1)
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
# =============================================================================
|
|
430
|
+
# Distributed Communication Primitives
|
|
431
|
+
# =============================================================================
|
|
432
|
+
#
|
|
433
|
+
# Naming convention:
|
|
434
|
+
# - Functions describe their FORWARD behavior
|
|
435
|
+
# - Backward behavior is the "conjugate" operation for gradient flow
|
|
436
|
+
#
|
|
437
|
+
# Available operations:
|
|
438
|
+
# ┌────────────────────┬─────────────────────┬─────────────────────┐
|
|
439
|
+
# │ Function │ Forward │ Backward │
|
|
440
|
+
# ├────────────────────┼─────────────────────┼─────────────────────┤
|
|
441
|
+
# │ all_reduce │ all-reduce (sum) │ identity │
|
|
442
|
+
# │ all_reduce_backward│ identity │ all-reduce (sum) │
|
|
443
|
+
# │ all_gather │ all-gather │ split (local chunk) │
|
|
444
|
+
# │ split │ split (local chunk) │ all-gather │
|
|
445
|
+
# │ reduce_scatter │ reduce-scatter │ all-gather │
|
|
446
|
+
# └────────────────────┴─────────────────────┴─────────────────────┘
|
|
447
|
+
# ===================
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
class _AllReduceBackward(torch.autograd.Function):
|
|
451
|
+
"""Identity forward, all-reduce backward. Used before colwise layers (f in Megatron)."""
|
|
452
|
+
|
|
453
|
+
@staticmethod
|
|
454
|
+
def forward(ctx, x, device_mesh):
|
|
455
|
+
ctx.device_mesh = device_mesh
|
|
456
|
+
return x
|
|
457
|
+
|
|
458
|
+
@staticmethod
|
|
459
|
+
def backward(ctx, grad_output):
|
|
460
|
+
device_mesh = ctx.device_mesh
|
|
461
|
+
if device_mesh.size() == 1:
|
|
462
|
+
return grad_output, None
|
|
463
|
+
dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
|
|
464
|
+
return grad_output, None
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
class _AllReduceForward(torch.autograd.Function):
|
|
468
|
+
"""All-reduce forward, identity backward. Used after rowwise layers (g in Megatron)."""
|
|
469
|
+
|
|
470
|
+
@staticmethod
|
|
471
|
+
def forward(ctx, x, device_mesh):
|
|
472
|
+
if device_mesh.size() == 1:
|
|
473
|
+
return x
|
|
474
|
+
dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
|
|
475
|
+
return x
|
|
476
|
+
|
|
477
|
+
@staticmethod
|
|
478
|
+
def backward(ctx, grad_output):
|
|
479
|
+
return grad_output, None
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
class _AllGather(torch.autograd.Function):
|
|
483
|
+
"""All-gather forward, split backward. Gathers sharded outputs."""
|
|
484
|
+
|
|
485
|
+
@staticmethod
|
|
486
|
+
+    def forward(ctx, x, device_mesh):
+        ctx.device_mesh = device_mesh
+        world_size = device_mesh.size()
+
+        if world_size == 1:
+            return x
+
+        last_dim = x.dim() - 1
+        rank = device_mesh.get_local_rank()
+        group = device_mesh.get_group()
+
+        x = x.contiguous()
+        tensor_list = [torch.empty_like(x) for _ in range(world_size)]
+        tensor_list[rank] = x
+        dist.all_gather(tensor_list, x, group=group)
+        return torch.cat(tensor_list, dim=last_dim).contiguous()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        device_mesh = ctx.device_mesh
+        world_size = device_mesh.size()
+
+        if world_size == 1:
+            return grad_output, None
+
+        rank = device_mesh.get_local_rank()
+        chunks = _split_along_last_dim(grad_output, world_size)
+        return chunks[rank].contiguous(), None
+
+
+class _Split(torch.autograd.Function):
+    """Split forward, all-gather backward. Scatters replicated input."""
+
+    @staticmethod
+    def forward(ctx, x, device_mesh):
+        ctx.device_mesh = device_mesh
+        world_size = device_mesh.size()
+
+        if world_size == 1:
+            return x
+
+        rank = device_mesh.get_local_rank()
+        chunks = _split_along_last_dim(x, world_size)
+        return chunks[rank].contiguous()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        device_mesh = ctx.device_mesh
+        world_size = device_mesh.size()
+
+        if world_size == 1:
+            return grad_output, None
+
+        last_dim = grad_output.dim() - 1
+        rank = device_mesh.get_local_rank()
+        group = device_mesh.get_group()
+
+        grad_output = grad_output.contiguous()
+        tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
+        tensor_list[rank] = grad_output
+        dist.all_gather(tensor_list, grad_output, group=group)
+        return torch.cat(tensor_list, dim=last_dim).contiguous(), None
+
+
+class _ReduceScatter(torch.autograd.Function):
+    """Reduce-scatter forward, all-gather backward. For sequence parallel."""
+
+    @staticmethod
+    def forward(ctx, x, device_mesh):
+        ctx.device_mesh = device_mesh
+        world_size = device_mesh.size()
+
+        if world_size == 1:
+            return x
+
+        last_dim = x.dim() - 1
+        group = device_mesh.get_group()
+
+        input_chunks = list(x.chunk(world_size, dim=last_dim))
+        output_shape = list(x.shape)
+        output_shape[last_dim] //= world_size
+        output = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+
+        dist.reduce_scatter(output, input_chunks, op=dist.ReduceOp.SUM, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        device_mesh = ctx.device_mesh
+        world_size = device_mesh.size()
+
+        if world_size == 1:
+            return grad_output, None
+
+        last_dim = grad_output.dim() - 1
+        rank = device_mesh.get_local_rank()
+        group = device_mesh.get_group()
+
+        grad_output = grad_output.contiguous()
+        tensor_list = [torch.empty_like(grad_output) for _ in range(world_size)]
+        tensor_list[rank] = grad_output
+        dist.all_gather(tensor_list, grad_output, group=group)
+        return torch.cat(tensor_list, dim=last_dim).contiguous(), None
+
+
+# =============================================================================
+# Convenience wrappers
+# =============================================================================
+
+
+def all_reduce_backward(x, device_mesh):
+    """Identity forward, all-reduce backward. Use before colwise layers."""
+    return _AllReduceBackward.apply(x, device_mesh)
+
+
+def all_reduce_forward(x, device_mesh):
+    """All-reduce forward, identity backward. Use after rowwise layers."""
+    return _AllReduceForward.apply(x, device_mesh)
+
+
+def all_gather(x, device_mesh):
+    """All-gather forward, split backward."""
+    return _AllGather.apply(x, device_mesh)
+
+
+def split(x, device_mesh):
+    """Split forward, all-gather backward."""
+    return _Split.apply(x, device_mesh)
+
+
+def reduce_scatter(x, device_mesh):
+    """Reduce-scatter forward, all-gather backward."""
+    return _ReduceScatter.apply(x, device_mesh)
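
The five wrappers above come in matched forward/backward pairs. As orientation, here is a minimal sketch of how a colwise-then-rowwise linear pair chains them in a TP MLP; the shapes, the `mesh` setup, and the `torchrun` launch are illustrative assumptions, not part of the diff.

```python
# Minimal sketch: how the wrappers pair up in a colwise -> rowwise MLP.
# Assumes `torchrun --nproc_per_node=2` and that the wrappers above are in scope.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
world = mesh.size()

hidden, inter, tokens = 8, 16, 4
w_up = torch.randn(inter // world, hidden, device="cuda", requires_grad=True)    # colwise shard (dim -2)
w_down = torch.randn(hidden, inter // world, device="cuda", requires_grad=True)  # rowwise shard (dim -1)

x = torch.randn(tokens, hidden, device="cuda", requires_grad=True)
x = all_reduce_backward(x, mesh)              # identity fwd; sums input grads in bwd
h = x @ w_up.t()                              # activation sharded on last dim
y = all_reduce_forward(h @ w_down.t(), mesh)  # sums the partial rowwise outputs
y.sum().backward()                            # grads flow through the mirrored collectives
```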
 
 
 def distribute_module(
@@ -434,224 +628,163 @@ def distribute_module(
     Copy pasted from torch's function but we remove the communications (partitioning)
     as well as buffer registering that is similarly not efficient.
     """
-    if
-
-
-
-        module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
+    if input_fn is not None:
+        module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
+    if output_fn is not None:
+        module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
     return module
 
 
 class TensorParallelLayer:
-    """
-    General tensor parallel layer for transformers.
-    """
+    """General tensor parallel layer for transformers"""
 
-    use_dtensor = True
     device_mesh = None
     rank = None
-
-    # Used to compare the shape of the original tensor
     empty_param = None
 
-    # Used to init the corresponding DTensor
-    shard = None
-
     def __init__(self, device_mesh=None, rank=None, empty_param=None):
         self.rank = rank
         self.device_mesh = device_mesh
         self.empty_param = empty_param
 
     @staticmethod
-    def _prepare_input_fn(
+    def _prepare_input_fn(mod, inputs, device_mesh): ...
 
     @staticmethod
-    def _prepare_output_fn(
+    def _prepare_output_fn(mod, outputs, device_mesh): ...
 
     def shard_tensor(
         self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
-        raise NotImplementedError
-
-    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
-        if self.use_dtensor:
-            distribute_module(
-                module,
-                device_mesh,
-                partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
-                partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
-            )
-
-
-# use_dtensor needs to be set to false for nn.Parameter when you want to view, chunk, slice
-# you name it. Whatever you want to do that is a bit unconventional, you need local tensors
-class GatherParallel(TensorParallelLayer):
-    """
-    Simple class used to define the hooks to add to a layer when we just want to gather the outputs
-    """
-
-    def __init__(
-        self,
-        input_layouts: Placement | None = None,
-        output_layouts: Placement | None = None,
-        use_local_output: bool = True,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.input_layouts = (input_layouts or Replicate(),)
-        self.output_layouts = output_layouts
-        self.desired_input_layouts = (Replicate(),)
-        self.use_local_output = use_local_output
-
-    @staticmethod
-    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
-        mod.expert_parallel_group = device_mesh.get_group()
-        if inputs and isinstance(inputs[0], DTensor):
-            inputs = inputs[0].to_local()
-        return inputs
-
-    @staticmethod
-    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        if isinstance(outputs, torch.Tensor):
-            dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False)
-        else:
-            dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
-        return outputs
-
-    def shard_tensor(
-        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
-    ) -> torch.Tensor:
-        self.shard = [Replicate()]
-        return param[...].to(device=device, dtype=dtype)
-
     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
         distribute_module(
             module,
             device_mesh,
-
-
+            self._prepare_input_fn,
+            self._prepare_output_fn,
         )
 
+    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
+        """
+        Compute the expected shape after TP sharding for a given full shape.
+
+        Args:
+            full_shape: The full (unsharded) parameter shape
 
-
+        Returns:
+            The expected sharded shape for this rank
+        """
+        # Default: no sharding, return full shape
+        return tuple(full_shape)
+
+
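
Since `distribute_module` now only registers forward hooks, a subclass of `TensorParallelLayer` just supplies the two hook functions and a `shard_tensor`. A minimal sketch of the reduced contract; `ReplicateDebug` is a hypothetical style invented for illustration, not part of the diff.

```python
# Hypothetical no-op style showing the TensorParallelLayer contract.
class ReplicateDebug(TensorParallelLayer):
    @staticmethod
    def _prepare_input_fn(mod, inputs, device_mesh):
        # forward pre-hook: returning the tuple replaces the module inputs
        return inputs

    @staticmethod
    def _prepare_output_fn(mod, outputs, device_mesh):
        # forward hook: no collective, weights stay replicated
        return outputs

    def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None):
        return param[...].to(device=device, dtype=dtype)  # keep the full tensor
```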
+class ColwiseParallel(TensorParallelLayer):
     """
-
-
+    Column-wise parallel: weight is sharded on dim -2 (output features).
+    Forward: input replicated -> output sharded on last dim.
+    If gather_output=True, output is all-gathered to produce full tensor.
     """
 
-
-
-
-        input_tensor = inputs[0]
-        if isinstance(input_tensor, DTensor):
-            input_tensor = input_tensor.to_local()
-        return input_tensor
+    def __init__(self, gather_output: bool = False, **kwargs):
+        super().__init__(**kwargs)
+        self.gather_output = gather_output
 
-
-
-
+    def _prepare_input_fn(self, mod, inputs, device_mesh):
+        input_tensor = inputs[0] if inputs else inputs
+        return all_reduce_backward(input_tensor, device_mesh)
+
+    def _prepare_output_fn(self, mod, outputs, device_mesh):
+        if self.gather_output:
+            return all_gather(outputs, device_mesh)
         return outputs
 
     def shard_tensor(
         self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
     ) -> torch.Tensor:
-
-        if
-
-
-
-
-
-        parameter = self.shard_tensor(param, dtype=dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        # TODO: assumes parent module will allreduce the output afterwards (e.g rowlinear bias is IsolatedParallel and parent module is GatherParallel)
-        return parameter
+        # If only 1 dim, shard this one (usually it's a `bias`)
+        dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
+        if dim == 1:
+            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
+        else:
+            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
+        return parameter.to(device=device, dtype=dtype)
 
-    def
-
-
-
-
-        )
+    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
+        world_size = self.device_mesh.size()
+        shape = list(full_shape)
+        # Colwise shards dim -2, but 1D tensors (bias) shard on dim -1
+        dim = -1 if len(shape) == 1 else -2
+        dim = len(shape) + dim if dim < 0 else dim
+        shard_size = math.ceil(shape[dim] / world_size)
+        start = self.rank * shard_size
+        end = min(start + shard_size, shape[dim])
+        shape[dim] = end - start
+        return tuple(shape)
 
 
-class
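
The ceil-division in `get_expected_sharded_shape` means the last rank can receive a shorter shard. A standalone re-derivation of the arithmetic; `expected_colwise_shape` is a hypothetical helper that mirrors the method above.

```python
import math

def expected_colwise_shape(full_shape, rank, world_size):
    # mirrors ColwiseParallel.get_expected_sharded_shape: weights shard dim -2,
    # 1-D biases shard dim -1, the last rank absorbs the remainder
    shape = list(full_shape)
    dim = len(shape) - (1 if len(shape) == 1 else 2)
    shard = math.ceil(shape[dim] / world_size)
    start = rank * shard
    shape[dim] = min(start + shard, shape[dim]) - start
    return tuple(shape)

assert expected_colwise_shape((10, 6), rank=0, world_size=4) == (3, 6)
assert expected_colwise_shape((10, 6), rank=3, world_size=4) == (1, 6)  # 10 - 3 * 3
assert expected_colwise_shape((10,), rank=1, world_size=4) == (3,)      # bias shards too
```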
+class RowwiseParallel(TensorParallelLayer):
     """
-
+    Row-wise parallel: weight is sharded on dim -1 (input features).
+    Forward: input (optionally split) -> output partial -> all-reduce to replicate.
+
+    Args:
+        split_input: If True, splits replicated input before matmul. Use when input
+            comes from a non-parallelizable operation (chunk/slice).
+            Default False (expects pre-sharded input from colwise layer).
     """
 
-    def __init__(self,
+    def __init__(self, split_input: bool = False, **kwargs):
         super().__init__(**kwargs)
-        self.
-        self.output_layouts = (Replicate(),)
-        self.desired_input_layouts = (Replicate(),)
-        self.use_local_output = use_local_output
-        self.use_dtensor = use_dtensor
+        self.split_input = split_input
 
-
-
-
-
-
-        if
-            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+    def _prepare_input_fn(self, mod, inputs, device_mesh):
+        if hasattr(mod, "bias") and mod.bias is not None:
+            mod._bias = mod.bias
+            mod.bias = None
+
+        input_tensor = inputs[0] if inputs else inputs
 
+        if self.split_input:
+            # Input is replicated, split it to match sharded weight
+            return split(input_tensor, device_mesh)
         return input_tensor
 
-
-
-
+    def _prepare_output_fn(self, mod, outputs, device_mesh):
+        outputs = all_reduce_forward(outputs, device_mesh)
+        if hasattr(mod, "_bias") and mod._bias is not None:
+            outputs = outputs + mod._bias
+        return outputs
 
     def shard_tensor(
         self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
     ) -> torch.Tensor:
-
-
-
-
-
-
-
-        return parameter
-
+        # If only 1 dim, it should not be sharded (usually it's a `bias`)
+        dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
+        if dim == 1:
+            parameter = param[...]
+        else:
+            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
+        return parameter.to(device=device, dtype=dtype)
 
-
-
-
-
+    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
+        # 1D tensors (bias) are NOT sharded in rowwise
+        if len(full_shape) == 1:
+            return tuple(full_shape)
+        world_size = self.device_mesh.size()
+        shape = list(full_shape)
+        dim = -1
+        dim = len(shape) + dim if dim < 0 else dim
+        shard_size = math.ceil(shape[dim] / world_size)
+        start = self.rank * shard_size
+        end = min(start + shard_size, shape[dim])
+        shape[dim] = end - start
+        return tuple(shape)
 
-    def __init__(
-        self,
-        input_layouts: Placement | None = None,
-        output_layouts: Placement | None = None,
-        use_local_output: bool = True,
-        use_dtensor=True,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.input_layouts = (input_layouts or Replicate(),)
-        self.output_layouts = (output_layouts or Shard(-1),)
-        self.desired_input_layouts = (Replicate(),)
-        self.use_local_output = use_local_output
-        self.use_dtensor = use_dtensor
 
-
-
-        # TODO: figure out dynamo support for instance method and switch this to instance method
-        # annotate module input placements/sharding with input_layouts
-        input_tensor = inputs[0]
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
-
-        # transform the input layouts to the desired layouts of ColwiseParallel
-        if input_layouts != desired_input_layouts:
-            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
-        return input_tensor
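
Note the bias handling in `RowwiseParallel`: the bias is detached in the input hook and only re-added after the all-reduce in the output hook. A toy calculation of why, with invented numbers:

```python
# Each rank produces a partial matmul result; the all-reduce sums partials.
# Adding the bias before the reduce would count it world_size times.
partials = [3.0, 4.0]   # per-rank x_i @ w_i^T contributions, world_size = 2
bias = 1.0
wrong = sum(p + bias for p in partials)  # 9.0 -> bias counted twice
right = sum(partials) + bias             # 8.0 -> bias added once, after the reduce
assert (wrong, right) == (9.0, 8.0)
```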
+class PackedColwiseParallel(ColwiseParallel):
+    """Packed column-wise parallel for fused weights like gate_up_proj."""
 
     def shard_tensor(
         self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
@@ -659,333 +792,144 @@ class ColwiseParallel(TensorParallelLayer):
         # If only 1 dim, shard this one (usually it's a `bias`)
         dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
         if dim == 1:
-            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1
-            shard = [Shard(-1)]
+            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
         else:
-
-
-
+            expected_shape = self.get_expected_sharded_shape(self.empty_param.shape)
+            if dim < len(expected_shape):
+                # Input is unpacked (e.g., gate_proj that will be concatenated to gate_up_proj)
+                # Use regular tensor shard - concatenation will happen after
+                parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -2)
+            else:
+                # Input is already packed, use packed sharding
+                parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -2)
         return parameter.to(device=device, dtype=dtype)
 
-    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
-        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
-        # means Colwise as Linear is input * weight^T + bias, where
-        # weight would become Shard(1)
-        parameter = self.shard_tensor(param, dtype=dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(
-                parameter,
-                self.device_mesh,
-                self.shard,
-                run_check=False,
-                shape=self.empty_param.size(),
-                stride=self.empty_param.stride(),
-            )
-        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
-
-    @staticmethod
-    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
-        if outputs.placements != output_layouts:
-            outputs = outputs.redistribute(placements=output_layouts, async_op=False)
-        # back to local tensor
-        return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
 
+class PackedRowwiseParallel(RowwiseParallel):
+    """Packed row-wise parallel for fused weights like gate_up_proj."""
 
-class PackedColwiseParallel(ColwiseParallel):
     def shard_tensor(
         self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
     ) -> torch.Tensor:
-
+        # If only 1 dim, it should not be sharded (usually it's a `bias`)
+        dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
+        if dim == 1:
+            parameter = param[...]
+        else:
+            # Check if input tensor is unpacked (shape mismatch with expected packed size)
+            # This happens when using MergeModulelist + Concatenate for fused weights like gate_up_proj
+            param_shape = param.shape if isinstance(param, torch.Tensor) else param.get_shape()
+            expected_packed_dim = self.empty_param.shape[-1] if self.empty_param.dim() >= 1 else 0
+            actual_dim = param_shape[-1] if len(param_shape) >= 1 else 0
+
+            if actual_dim < expected_packed_dim:
+                # Input is unpacked, use regular tensor shard
+                parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
+            else:
+                # Input is already packed, use packed sharding
+                parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -1)
         return parameter.to(device=device, dtype=dtype)
 
-    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
-        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
-        # means Colwise as Linear is input * weight^T + bias, where
-        # weight would become Shard(1)
-        parameter = self.shard_tensor(param, dtype=dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, self.device_mesh, [Shard(-2)], run_check=False)
-        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
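
Why the packed styles exist: a fused `gate_up_proj` stores the gate and up projections concatenated along one dimension, so a plain contiguous split would hand rank 0 only gate rows. A toy illustration of the idea, not the library's `get_packed_weights` implementation:

```python
import torch

gate = torch.arange(4)            # rows 0..3 of the fused weight
up = torch.arange(4, 8)           # rows 4..7
packed = torch.cat([gate, up])    # gate_up_proj layout
world_size = 2

naive_rank0 = packed.chunk(world_size)[0]  # tensor([0, 1, 2, 3]) -> all gate, no up
packed_rank0 = torch.cat([gate.chunk(world_size)[0], up.chunk(world_size)[0]])
print(packed_rank0.tolist())               # [0, 1, 4, 5] -> matching gate/up halves
```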
+class EmbeddingParallel(TensorParallelLayer):
+    """EmbeddingParallel: shards embedding table, handles masked lookups for vocab parallelism."""
 
-
-
-
-    """
+    def __init__(self, *, embedding_dim_sharding: int = 0, **kwargs):
+        super().__init__(**kwargs)
+        self.embedding_dim_sharding = embedding_dim_sharding
 
-    def
-
+    def _prepare_input_fn(self, mod, inputs, device_mesh):
+        input_tensor = inputs[0] if inputs else inputs
 
+        # For vocab-parallel (dim 0), we need to handle masking and offsetting
+        if self.embedding_dim_sharding == 0:
+            rank = device_mesh.get_local_rank()
 
-
-
-
-
+            # Get vocab range for this rank
+            # Use weight.shape[0] to get the actual local (sharded) size, not num_embeddings
+            # which may not be updated after sharding
+            per_partition_size = mod.weight.shape[0]
+            vocab_start_index = rank * per_partition_size
+            vocab_end_index = vocab_start_index + per_partition_size
 
-
-
+            # Build mask for out-of-vocabulary tokens
+            input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
+            mod._input_mask = input_mask
 
+            # Offset input to local indices and mask invalid ones
+            masked_input = input_tensor.clone() - vocab_start_index
+            masked_input[input_mask] = 0  # Set to valid local index
 
-
-    """
-    Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
-    Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules.
-    (i.e. MLP, Attention)
-
-    Keyword Args:
-        input_layouts (Placement, optional):
-            The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
-            become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.
-        output_layouts (Placement, optional):
-            The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
-            with the user desired layout. If not specified, the output tensor is replicated.
-        use_local_output (bool, optional):
-            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
-    Returns:
-        A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.
-    """
+            return masked_input
 
-
-
-
-
-
-
-
-
-
-
-
-        self.use_dtensor = use_dtensor
+        return input_tensor
+
+    def _prepare_output_fn(self, mod, outputs, device_mesh):
+        # For vocab-parallel (dim 0), zero out embeddings for out-of-range tokens before all-reduce
+        if self.embedding_dim_sharding == 0 and hasattr(mod, "_input_mask"):
+            input_mask = mod._input_mask
+            # Use multiplication instead of in-place assignment to preserve gradients
+            mask_expanded = input_mask.unsqueeze(-1).expand_as(outputs)
+            outputs = outputs * (~mask_expanded).float()
+            del mod._input_mask
+
+        return all_reduce_forward(outputs, device_mesh)
 
     def shard_tensor(
         self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
     ) -> torch.Tensor:
-        # If only 1 dim,
+        # If only 1 dim, shard this one (usually it's a `bias`)
         dim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
         if dim == 1:
-
-            parameter = param[...]
+            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
         else:
             parameter = get_tensor_shard(
-                param,
-
-        shard = [Shard(-1)]
-        self.shard = shard
-        return parameter.to(device=device, dtype=dtype)
-
-    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
-        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
-        # means Rowwise as nn.Linear is input * weight^T + bias, where
-        # weight would become Shard(0)
-        parameter = self.shard_tensor(param, dtype=dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(
-                parameter,
+                param,
+                self.empty_param,
                 self.device_mesh,
-                self.
-
-                shape=self.empty_param.size(),
-                stride=self.empty_param.stride(),
-            )
-        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
-
-    @staticmethod
-    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
-        if hasattr(mod, "bias") and mod.bias is not None:
-            mod._bias = mod.bias.to_local()
-            mod.bias = None
-
-        input_tensor = inputs[0]
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
-
-        if input_layouts != desired_input_layouts:
-            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
-        return input_tensor
-
-    @staticmethod
-    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        # Rowwise sharding produces partial output, depending on output layouts:
-        # 1. to replicate -> allreduce
-        # 2. to shard -> reduce_scatter
-        if outputs.placements != output_layouts:
-            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
-        outputs = outputs.to_local()  # otherwise the `+=` op will gather
-        if hasattr(mod, "_bias"):
-            outputs = outputs + mod._bias
-        # back to local tensor if use_local_output is True
-        return outputs
-
-    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
-        module._distribute_module_applied = True
-        if self.use_dtensor:
-            if isinstance(module, nn.Linear):
-                # rowwise linear runtime sharding requires input tensor shard on last dim
-                self.desired_input_layouts: tuple[Placement, ...] = (Shard(-1),)
-            elif isinstance(module, nn.Embedding):
-                # rowwise embedding runtime sharding requires input tensor replicated
-                self.desired_input_layouts = (Replicate(),)
-            elif isinstance(module, nn.Parameter):
-                # rowwise embedding runtime sharding requires input tensor replicated
-                self.desired_input_layouts = (Shard(-1),)
-            else:
-                raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!")
-
-        distribute_module(
-            module,
-            device_mesh,
-            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
-            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
+                self.rank,
+                self.embedding_dim_sharding,
         )
-
-
-class PackedRowwiseParallel(RowwiseParallel):
-    def shard_tensor(
-        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
-    ) -> torch.Tensor:
-        parameter = get_packed_weights(param, self.empty_param, self.device_mesh, self.rank, -1)
         return parameter.to(device=device, dtype=dtype)
 
-    def
-
-
-    #
-
-    if
-
-
-
-
-
-
-class LocalRowwiseParallel(RowwiseParallel):
-    """
-    Rowwise parallel with use_dtensor=False for local tensor operations.
-    """
-
-    def __init__(self, **kwargs):
-        super().__init__(use_dtensor=False, **kwargs)
-
-
-class LocalPackedRowwiseParallel(PackedRowwiseParallel):
-    """
-    Packed rowwise parallel with use_dtensor=False for local tensor operations.
-    """
-
-    def __init__(self, **kwargs):
-        super().__init__(use_dtensor=False, **kwargs)
-
-
-class RowwiseParallelReplicate(RowwiseParallel):
-    """
-    Rowwise parallel with input layouts replicated.
-    """
-
-    def __init__(self, **kwargs):
-        super().__init__(input_layouts=Replicate(), **kwargs)
+    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
+        world_size = self.device_mesh.size()
+        shape = list(full_shape)
+        # EmbeddingParallel shards on self.embedding_dim_sharding (default 0)
+        # 1D tensors (bias) shard on dim -1
+        dim = -1 if len(shape) == 1 else self.embedding_dim_sharding
+        dim = len(shape) + dim if dim < 0 else dim
+        shard_size = math.ceil(shape[dim] / world_size)
+        start = self.rank * shard_size
+        end = min(start + shard_size, shape[dim])
+        shape[dim] = end - start
+        return tuple(shape)
 
 
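
A worked toy example of the vocab-parallel lookup above, with a vocabulary of 8 split over 2 ranks (all numbers invented):

```python
import torch

input_ids = torch.tensor([1, 5, 3, 6])
rank, per_partition = 1, 4                            # this rank owns ids 4..7
start, end = rank * per_partition, (rank + 1) * per_partition
mask = (input_ids < start) | (input_ids >= end)       # [True, False, True, False]
local_ids = (input_ids - start).masked_fill(mask, 0)  # [0, 1, 0, 2], all valid rows

weight = torch.randn(per_partition, 3)                # this rank's embedding shard
out = weight[local_ids]
out = out * (~mask).unsqueeze(-1).float()             # zero the rows we don't own
# dist.all_reduce(out, ...) would now fill those rows in from rank 0
```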
 class SequenceParallel(TensorParallelLayer):
     """
-
-
-    `RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__
-
-    This style implements the operation that is described in the paper
-    `Reducing Activation Recomputation in Large Transformer Models <https://huggingface.co/papers/2205.05198>`__
-
-    If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded
-    on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input
-    passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would
-    redistribute the input to be sharded on the sequence dimension.
-
-    The output of the ``nn.Module`` will be sharded on the sequence dimension.
-
-    Keyword Args:
-        sequence_dim (int, optional):
-            The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to
-            become a DTensor that is sharded on the sequence dimension, default: 1.
-        use_local_output (bool, optional):
-            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
-    Returns:
-        A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
-
-    Example::
-        >>> # xdoctest: +SKIP(failing)
-        >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
-        >>> from torch.distributed.device_mesh import init_device_mesh
-        >>> ...
-        >>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
-        >>> tp_mesh = init_device_mesh("cuda", (8,))
-        >>>
-        >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
-        >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
-        >>>
-        >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
-        >>> ...
-
-    .. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.
-        ``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom
-        inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
-        to ensure that they are replicated.
+    Sequence Parallel: input/output sharded on sequence dimension.
+    Weights are replicated.
     """
 
     def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
         super().__init__(**kwargs)
-        self.
-
-
-
-
-
+        self.sequence_dim = sequence_dim
+
+    def _prepare_input_fn(self, mod, inputs, device_mesh):
+        input_tensor = inputs[0] if inputs else inputs
+        # For sequence parallel, input is sharded on sequence dim
+        # All-gather for the layer, then reduce-scatter after
+        return all_gather(input_tensor, device_mesh)
+
+    def _prepare_output_fn(self, mod, outputs, device_mesh):
+        return reduce_scatter(outputs, device_mesh)
 
     def shard_tensor(
-        self,
-        param: torch.Tensor,
-        tensor_idx=None,
-        device=None,
-        dtype=None,
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
     ) -> torch.Tensor:
-        self.shard = [Replicate()]
         return param[...].to(device=device, dtype=dtype)
 
-    @staticmethod
-    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
-        input_tensor = inputs[0]
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
-        if input_layouts != desired_input_layouts:
-            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
-        return input_tensor
-
-    @staticmethod
-    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        outputs = outputs.redistribute(
-            placements=(Replicate(),), async_op=True
-        )  # maybe we have to replicate ? because next layer is not sharded
-        return outputs.to_local()  # if use_local_output else outputs
-
-    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
-        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
-        # means Colwise as Linear is input * weight^T + bias, where
-        # weight would become Shard(1)
-        parameter = self.shard_tensor(param, dtype=dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, self.device_mesh, [Replicate()], run_check=False)
-        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
-
 
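
`SequenceParallel` brackets the layer with an all-gather on entry and a reduce-scatter on exit. A single-process numeric sketch of what the reduce-scatter computes (toy values, world_size = 2):

```python
import torch

x_rank0 = torch.tensor([1.0, 2.0, 3.0, 4.0])      # pre-scatter activations on rank 0
x_rank1 = torch.tensor([10.0, 20.0, 30.0, 40.0])  # same tensor position on rank 1
chunks0, chunks1 = x_rank0.chunk(2), x_rank1.chunk(2)
out_rank0 = chunks0[0] + chunks1[0]  # tensor([11., 22.]) kept by rank 0
out_rank1 = chunks0[1] + chunks1[1]  # tensor([33., 44.]) kept by rank 1
```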
 class GroupedGemmParallel(TensorParallelLayer):
     """
@@ -994,7 +938,6 @@ class GroupedGemmParallel(TensorParallelLayer):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.use_dtensor = False
 
     def shard_tensor(
         self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
@@ -1005,15 +948,30 @@ class GroupedGemmParallel(TensorParallelLayer):
                 f"Global number of experts must be divisible by number of devices: {global_num_experts} % {self.device_mesh.size()} != 0"
             )
         local_num_experts = global_num_experts // self.device_mesh.size()
-
-
-
-
-
-
-        if
-
-
+        shard_size = local_num_experts
+        if isinstance(device, torch.device):
+            device = device.index if device.index is not None else 0
+        start = device * shard_size
+        end = (device + 1) * shard_size
+        # special case we don't "shard" just send this entire tensor to the correct rank.
+        shape = param.get_shape() if not isinstance(param, torch.Tensor) else param.shape
+        if tensor_idx is not None and start <= tensor_idx < end:
+            # this tensor does need to be materialized on this device:
+            return param[:].to(device=device)
+        elif tensor_idx is None:  # a bias or a weight, but already merged
+            return param[start:end].to(device=device, dtype=dtype)
+        elif len(shape) >= 1 and tensor_idx is not None:
+            return None
+        else:  # bias case
+            return param[:].to(device=device, dtype=dtype)
+
+    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
+        # GroupedGemm shards on dim 0 (experts dimension)
+        world_size = self.device_mesh.size()
+        shape = list(full_shape)
+        local_num_experts = shape[0] // world_size
+        shape[0] = local_num_experts
+        return tuple(shape)
 
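
The expert-dimension arithmetic, with invented sizes:

```python
# GroupedGemmParallel shards only dim 0 (the experts dimension).
num_experts, world_size = 128, 8
assert num_experts % world_size == 0           # enforced by shard_tensor
local_num_experts = num_experts // world_size  # 16 experts per rank
full_shape = (num_experts, 4096, 14336)
local_shape = (local_num_experts, 4096, 14336)
```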
 class RouterParallel(TensorParallelLayer):
@@ -1021,20 +979,15 @@ class RouterParallel(TensorParallelLayer):
     Allows to reshape the router scores to support running expert parallel.
     """
 
-    def __init__(self,
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.args = args
-        self.use_dtensor = use_dtensor
 
     @staticmethod
-    def _prepare_input_fn(
-
-        if isinstance(input_tensor, DTensor):
-            raise NotImplementedError("RouterParallel does not support DTensor input for now")
-        return input_tensor
+    def _prepare_input_fn(mod, inputs, device_mesh):
+        return inputs[0] if inputs else inputs
 
     @staticmethod
-    def _prepare_output_fn(
+    def _prepare_output_fn(mod, outputs, device_mesh):
         """
         Imagine if you had 4 tokens, top_k = 4, and 128experts.
         With EP = 8. The num_local_expert should be 128/8 = 16
@@ -1076,6 +1029,7 @@ class RouterParallel(TensorParallelLayer):
         )
         num_local_experts = mod.num_experts // ep_size
         router_logits, router_scores, router_indices = outputs
+        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_scores)
         router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
         router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
         # As -1 % 1 is 0, we can only use mask fill when num_local_experts is 1
@@ -1083,32 +1037,54 @@ class RouterParallel(TensorParallelLayer):
             router_indices = torch.fmod(router_indices, num_local_experts)
         else:
             router_indices = router_indices.masked_fill(router_indices > 0, 0).masked_fill(router_indices < 0, -1)
-        router_indices = router_indices.masked_fill(
-            router_indices == -1, num_local_experts
-        )  # masking class for one hot
+        router_indices = router_indices.masked_fill(router_indices == -1, num_local_experts)
         return router_logits, router_scores, router_indices
 
     def shard_tensor(
         self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
     ) -> torch.Tensor:
-        self.shard = None
         return param[...].to(device=device, dtype=dtype)
 
-    def partition_tensor(self, param: torch.Tensor, dtype, to_contiguous: bool):
-        # TODO: i'd like for this to be the default
-        parameter = self.shard_tensor(param, dtype=dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        return parameter
 
-
-
-
-
-
-
-
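
A toy run of the index remapping in `_prepare_output_fn` above, with 8 experts, `ep_size = 2`, and `ep_rank = 1` owning experts 4..7 (numbers invented):

```python
import torch

router_indices = torch.tensor([[1, 5], [4, 7]])  # top-2 global expert ids per token
num_local_experts, ep_rank = 4, 1
masked = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
local = torch.fmod(masked, num_local_experts)    # [[-1, 1], [0, 3]]
local = local.masked_fill(local == -1, num_local_experts)  # -1 -> the one-hot "masking class"
print(local.tolist())                            # [[4, 1], [0, 3]]
```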
+class MoeTensorParalellExperts(TensorParallelLayer):
+    """
+    Note: For tensor parallel, the MoEExpertsParallel TP layer handles gradient sync:
+    - all_reduce_backward on hidden_states (for colwise gate_up_proj gradient)
+    - all_reduce_backward on top_k_weights (for router gradient)
+    - all_reduce_forward on output (for partial expert outputs)
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    @staticmethod
+    def _prepare_input_fn(mod, inputs, device_mesh):
+        # inputs = (hidden_states, top_k_index, top_k_weights)
+        hidden_states = inputs[0]
+        top_k_index = inputs[1]
+        top_k_weights = inputs[2]
+
+        # all_reduce_backward on hidden_states for correct colwise (gate_up_proj) gradient
+        hidden_states = all_reduce_backward(hidden_states, device_mesh)
+
+        # all_reduce_backward on routing weights for correct router gradient
+        # This is needed because ∂L/∂routing_weights = ∂L/∂output * partial_expert_output
+        # and partial_expert_output is different on each GPU before all-reduce
+        top_k_weights = all_reduce_backward(top_k_weights, device_mesh)
+
+        return (hidden_states, top_k_index, top_k_weights)
+
+    @staticmethod
+    def _prepare_output_fn(mod, outputs, device_mesh):
+        # all_reduce_forward to sum partial expert outputs across GPUs
+        return all_reduce_forward(outputs, device_mesh)
+
+    def shard_tensor(
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor:
+        # This class doesn't shard tensors - sharding is handled by packed_colwise/rowwise
+        # on the individual weight tensors (gate_up_proj/down_proj)
+        return param[...].to(device=device, dtype=dtype)
 
 
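
The gradient reasoning in the docstring above, reduced to scalars (toy values): if the expert output on each rank is a partial product, the routing-weight gradient needs the sum of the per-rank contributions, which is what the backward all-reduce provides.

```python
# y = w * (p0 + p1): the true dL/dw sums the per-rank partials.
p = [2.0, 3.0]                            # partial expert outputs on ranks 0 and 1
grad_y = 1.0                              # incoming dL/dy, identical on both ranks
grad_w_local = [grad_y * pi for pi in p]  # [2.0, 3.0], differs per rank
grad_w = sum(grad_w_local)                # 5.0 == d/dw of w * (p0 + p1)
```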
 class ParallelInterface(GeneralInterface):
@@ -1116,69 +1092,152 @@ class ParallelInterface(GeneralInterface):
     # a new instance is created (in order to locally override a given entry)
     _global_mapping = (
         {
+            "embedding_rowwise": EmbeddingParallel(embedding_dim_sharding=0),
+            "colwise_gather_output": ColwiseParallel(gather_output=True),
             "colwise": ColwiseParallel(),
             "rowwise": RowwiseParallel(),
-            "
-            "
-            "
-            "local_rowwise": LocalRowwiseParallel(),
-            "local": IsolatedParallel(),
-            "gather": GatherParallel(),
-            "local_packed_rowwise": LocalPackedRowwiseParallel(),
+            "rowwise_split_input": RowwiseParallel(split_input=True),
+            "packed_colwise": PackedColwiseParallel(),
+            "packed_rowwise": PackedRowwiseParallel(),
             "sequence_parallel": SequenceParallel(),
-            "replicate": ReplicateParallel(),
             "grouped_gemm": GroupedGemmParallel(),
             "ep_router": RouterParallel(),
+            "moe_tp_experts": MoeTensorParalellExperts(),
         }
-        if
+        if is_torch_available() and _torch_distributed_available
         else {}
     )
 
+    # Map plan names to sharding dimensions for weights
+    # For weights: colwise shards dim -2, rowwise shards dim -1
+    # For embedding: rowwise shards dim 0 (vocab), colwise shards dim -2 (hidden)
+    plan_to_weight_dim: dict[str, int | None] = {
+        "colwise": -2,
+        "colwise_gather_output": -2,
+        "packed_colwise": -2,
+        "rowwise": -1,
+        "rowwise_split_input": -1,
+        "packed_rowwise": -1,
+        "embedding_rowwise": 0,
+        "sequence_parallel": None,
+    }
+
+    # Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced)
+    plan_to_bias_dim: dict[str, int | None] = {
+        "colwise": -1,
+        "colwise_gather_output": -1,
+        "packed_colwise": -1,
+        "rowwise": None,
+        "rowwise_split_input": None,
+        "packed_rowwise": None,
+        "embedding_rowwise": None,
+        "sequence_parallel": None,
+    }
+
 
 ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
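
A usage sketch for the new mapping tables, resolving a plan style to its shard dimensions:

```python
style = "rowwise"
weight_dim = ALL_PARALLEL_STYLES.plan_to_weight_dim[style]  # -1: shard in_features
bias_dim = ALL_PARALLEL_STYLES.plan_to_bias_dim[style]      # None: bias replicated
```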
 
 
-
-
-
+# =============================================================================
+# High-Level API Functions
+# =============================================================================
+
+
+def gather_full_tensor(local_tensor: torch.Tensor, shard_dim: int, device_mesh) -> torch.Tensor:
     """
-
+    All-gather a sharded tensor along the specified dimension to reconstruct the full tensor.
+
+    Args:
+        local_tensor: The local shard of the tensor on this rank
+        shard_dim: The dimension along which the tensor was sharded
+        device_mesh: The device mesh for distributed communication
+
+    Returns:
+        The full reconstructed tensor (same on all ranks)
     """
-
-
-
-
-
-    if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
-        return parameter
-    # TODO: this logic should be wrapped in a function, this is copied from corresponding tp classes.
-    if tp_style == "local_packed_rowwise":
-        placements = [Shard(-1)]
-    elif tp_style == "local_rowwise":
-        if param_type == "bias":
-            placements = [Replicate()]
-        else:
-            placements = [Shard(-1)]
-    elif tp_style == "local_colwise":
-        if param_type == "bias":
-            placements = [Shard(-1)]
-        else:
-            placements = [Shard(-2)]
-    return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
+    world_size = device_mesh.size()
+
+    # Normalize negative dimension
+    if shard_dim < 0:
+        shard_dim = local_tensor.ndim + shard_dim
 
-
+    # Gather all shards
+    gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
+    dist.all_gather(gathered_tensors, local_tensor.contiguous())
 
+    # Concatenate along the shard dimension
+    return torch.cat(gathered_tensors, dim=shard_dim)
+
+
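
A usage sketch for `gather_full_tensor`, assuming a running process group; `mesh` and the module path are illustrative assumptions:

```python
# Rebuild a full rowwise-sharded weight on every rank.
local_w = model.model.layers[0].mlp.down_proj.weight  # (hidden, intermediate / tp)
full_w = gather_full_tensor(local_w, shard_dim=-1, device_mesh=mesh)
assert full_w.shape[-1] == local_w.shape[-1] * mesh.size()
```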
+def gather_state_dict_for_save(
     state_dict: dict[str, torch.Tensor],
     tp_plan: dict[str, str],
     device_mesh,
+    tp_size: int,
 ) -> dict[str, torch.Tensor]:
     """
-
+    Gather sharded tensors to reconstruct full tensors for saving.
+
+    This function all-gathers each sharded tensor along its shard dimension
+    to reconstruct the full unsharded tensor for checkpoint saving.
+
+    Args:
+        state_dict: The model state dict with local sharded tensors
+        tp_plan: The tensor parallel plan mapping layer patterns to shard styles
+        device_mesh: The device mesh for distributed communication
+        tp_size: The tensor parallel world size
+
+    Returns:
+        State dict with full (gathered) tensors
     """
-
-
-
-
+    # Use the global mappings from ParallelInterface (can be extended by users)
+    plan_to_weight_dim = ALL_PARALLEL_STYLES.plan_to_weight_dim
+    plan_to_bias_dim = ALL_PARALLEL_STYLES.plan_to_bias_dim
+
+    result = {}
+    for key, tensor in state_dict.items():
+        # Find the matching TP plan for this parameter
+        param_name = key.rsplit(".", 1)[0] if "." in key else key
+        param_type = key.rsplit(".", 1)[1] if "." in key else None
+        generic_param_name = re.sub(r"\d+", "*", param_name)
+        # Also check the full key for nn.Parameter (e.g., MoE experts without .weight suffix)
+        generic_full_key = re.sub(r"\d+", "*", key)
+
+        # Check if this parameter has a TP plan
+        current_plan = None
+        if generic_full_key in tp_plan:
+            # Full key match (e.g., "model.layers.*.mlp.experts.gate_up_proj" for MoE experts)
+            current_plan = tp_plan[generic_full_key]
+        elif generic_param_name in tp_plan:
+            current_plan = tp_plan[generic_param_name]
+        elif "." in generic_param_name:
+            parent_param_name = generic_param_name.rsplit(".", 1)[0]
+            if parent_param_name in tp_plan:
+                current_plan = tp_plan[parent_param_name]
+
+        if current_plan is None or current_plan not in plan_to_weight_dim:
+            # Not sharded, keep as-is
+            result[key] = tensor
+            continue
+
+        # Determine sharding dimension based on param type
+        if param_type == "bias":
+            shard_dim = plan_to_bias_dim.get(current_plan)
+        else:
+            shard_dim = plan_to_weight_dim.get(current_plan)
+
+        if shard_dim is None:
+            # Replicated, keep as-is
+            result[key] = tensor
+            continue
+
+        # Gather full tensor and handle packed weights repacking
+        full_tensor = gather_full_tensor(tensor, shard_dim, device_mesh)
+        if current_plan in ("packed_colwise", "packed_rowwise"):
+            full_tensor = repack_weights(full_tensor, shard_dim, tp_size, 2)
+        result[key] = full_tensor.contiguous()
+
+    return result
 
 
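
A checkpointing sketch built on the function above; it assumes a model set up by `distribute_model`, which stores `_device_mesh` and `_tp_size` on the model:

```python
import torch

full_sd = gather_state_dict_for_save(
    model.state_dict(),
    tp_plan=model.tp_plan,
    device_mesh=model._device_mesh,
    tp_size=model._tp_size,
)
if model._device_mesh.get_local_rank() == 0:
    torch.save(full_sd, "full_checkpoint.pt")  # every rank holds full tensors; save once
```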
 def add_tensor_parallel_hooks_to_module(
@@ -1207,7 +1266,7 @@ def add_tensor_parallel_hooks_to_module(
 
 def shard_and_distribute_module(
     model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
-):
+):
     r"""
     This function is called in `from_pretrained` when loading a model's checkpoints.
     It receives the pointer to the parameter (or the parameter itself) and takes care of "sharding".
@@ -1223,7 +1282,7 @@ def shard_and_distribute_module(
     """
     param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
     tp_plan = model.tp_plan or {}
-    module_to_tp = model.get_submodule(param_name)
+    module_to_tp = model.get_submodule(param_name)
     rank = int(rank)
     current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
 
@@ -1235,10 +1294,13 @@ def shard_and_distribute_module(
 
     if current_shard_plan is not None:
         try:
-            tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
-
-
-
+            tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
+            tp_layer.empty_param = empty_param
+            tp_layer.device_mesh = device_mesh
+            tp_layer.rank = rank
+            param = tp_layer.shard_tensor(param, tensor_idx=None, dtype=param_casting_dtype, device=rank)
+            if is_contiguous:
+                param = param.contiguous()
         except NotImplementedError as e:
             print(
                 f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
@@ -1251,7 +1313,6 @@ def shard_and_distribute_module(
     if not isinstance(param, torch.nn.Parameter):
         param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point())
     setattr(module_to_tp, param_type, param)
-    # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
     return param
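
The core of the try-block above, as a standalone sketch; all variables are stand-ins for what `from_pretrained` passes in:

```python
# Manually applying one TP style the way shard_and_distribute_module does.
import torch

tp_layer = ALL_PARALLEL_STYLES["colwise"]
tp_layer.empty_param = empty_param  # meta tensor carrying the full shape
tp_layer.device_mesh = device_mesh
tp_layer.rank = rank
local = tp_layer.shard_tensor(param, tensor_idx=None, dtype=torch.bfloat16, device=rank)
local = local.contiguous()
```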
 
 
@@ -1265,20 +1326,18 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
 
     generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
     unsharded_layers = set(generic_keys)
-    unused_rules = tp_plan
+    unused_rules = tp_plan.copy()
 
     for key in generic_keys:
         param_name = key.rsplit(".", 1)[0] if "." in key else key
         generic_param_name = re.sub(r"\d+", "*", param_name)
 
         if generic_param_name in tp_plan:
-            unused_rules.pop(generic_param_name)
+            unused_rules.pop(generic_param_name, None)
             unsharded_layers.discard(key)
         elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
-            unused_rules.pop(parent_param_name)
+            unused_rules.pop(parent_param_name, None)
             unsharded_layers.discard(key)
-        else:
-            pass  # we couldn't find the rule for this parameter, so it's not sharded
 
     if len(unused_rules) > 0:
         logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
@@ -1287,6 +1346,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
 
 
 def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
+    """Distribute a model according to the TP plan."""
     model._tp_size = tp_size
     model._device_mesh = device_mesh
     if distributed_config is not None:
@@ -1297,7 +1357,7 @@ def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
     if isinstance(tp_plan, dict):
         model.tp_plan = tp_plan
     model_plan = model.tp_plan
-    if model_plan is not None and
+    if model_plan is not None and _torch_distributed_available:
         for v in model_plan.values():
             if v not in ALL_PARALLEL_STYLES:
                 raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")