transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/emu3/modeling_emu3.py

```diff
@@ -33,7 +33,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig


@@ -118,6 +118,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Emu3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -143,7 +144,6 @@ class Emu3Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
```
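The pair of hunks above moves the rotary helper from a per-instance attribute (`self.rotary_fn = apply_rotary_pos_emb`, removed) to a class-level `@use_kernelized_func(...)` binding. A minimal sketch of that decorator pattern, assuming the decorated class calls `self.rotary_fn(...)` internally; this is illustrative only, not the actual transformers implementation, which presumably also lets an optimized hub kernel be swapped in:

```python
# Hypothetical sketch of a "kernelized function" class decorator.
def use_kernelized_func(default_fn):
    def decorator(cls):
        # Binding the function on the class (rather than in __init__) gives a
        # single point where a faster kernel could later replace the default.
        cls.rotary_fn = staticmethod(default_fn)
        return cls
    return decorator


def apply_rotary_pos_emb_ref(q):  # stand-in for the real rotary helper
    return q


@use_kernelized_func(apply_rotary_pos_emb_ref)
class Attention:
    def rotate(self, q):
        return self.rotary_fn(q)  # resolves to the decorator-bound function
```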
```diff
@@ -958,6 +958,10 @@ class Emu3VQVAE(PreTrainedModel):
         elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
             init.constant_(module.weight, 1.0)
             init.constant_(module.bias, 0.0)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight)
             # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
```
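For context on the added lines: `running_mean`, `running_var`, and `num_batches_tracked` are buffers rather than parameters, so an init pass that only touches `weight` and `bias` would leave the running statistics stale. A quick check of which state that covers:

```python
import torch.nn as nn

bn = nn.BatchNorm2d(8)
# Learnable parameters vs. tracked statistics (buffers):
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
```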
```diff
@@ -1128,7 +1132,7 @@ class Emu3RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -1167,7 +1171,7 @@ class Emu3RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
```
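The first hunk turns `original_inv_freq` from a plain attribute assignment into a non-persistent registered buffer. A registered buffer follows the module through `.to()` and dtype casts, while `persistent=False` keeps it out of the state dict; a small illustration:

```python
import torch
import torch.nn as nn


class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.arange(4.0)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # As a buffer, the saved copy moves with the module on .to(device);
        # a plain tensor attribute would be left behind.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)


rope = Rope()
print("original_inv_freq" in rope.state_dict())  # False: persistent=False
```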
```diff
@@ -1615,6 +1619,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
         position_ids=None,
         use_cache=True,
         pixel_values=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1628,10 +1633,11 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
             position_ids=position_ids,
             pixel_values=pixel_values,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        if
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None

         return model_inputs
```
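The new `is_first_iteration` flag captures that, when a KV cache is in use, image features only need to be encoded on the prefill step; later decode steps reuse the cache, so `pixel_values` can be dropped. A toy version of the gating condition (hypothetical helper mirroring the hunk above):

```python
def gate_pixel_values(model_inputs, is_first_iteration, use_cache):
    # After the first (prefill) step the cached key/values already encode the
    # image context, so the vision inputs need not be forwarded again.
    if not is_first_iteration and use_cache:
        model_inputs["pixel_values"] = None
    return model_inputs


print(gate_pixel_values({"pixel_values": "img"}, True, True))   # kept
print(gate_pixel_values({"pixel_values": "img"}, False, True))  # dropped
```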
transformers/models/emu3/modular_emu3.py

```diff
@@ -706,6 +706,10 @@ class Emu3VQVAE(PreTrainedModel):
         elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
             init.constant_(module.weight, 1.0)
             init.constant_(module.bias, 0.0)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight)
             # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
@@ -1167,6 +1171,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
         position_ids=None,
         use_cache=True,
         pixel_values=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1180,10 +1185,11 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
             position_ids=position_ids,
             pixel_values=pixel_values,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        if
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None

         return model_inputs
```
@@ -474,6 +474,20 @@ class EncodecPreTrainedModel(PreTrainedAudioTokenizerBase):
                     init.xavier_uniform_(param)
                 elif "bias" in name:
                     init.constant_(param, 0.0)
+        elif isinstance(module, EncodecConv1d):
+            kernel_size = module.conv.kernel_size[0]
+            stride = torch.tensor(module.conv.stride[0], dtype=torch.int64)
+            dilation = module.conv.dilation[0]
+            # Effective kernel size with dilations.
+            kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64)
+            init.copy_(module.stride, stride)
+            init.copy_(module.kernel_size, kernel_size)
+            init.copy_(module.padding_total, kernel_size - stride)
+        elif isinstance(module, EncodecEuclideanCodebook):
+            init.copy_(module.inited, torch.Tensor([True]))
+            init.zeros_(module.cluster_size)
+            init.zeros_(module.embed)
+            init.zeros_(module.embed_avg)


 @auto_docstring(
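The conv geometry above follows the standard dilation arithmetic: a dilated kernel spans `(kernel_size - 1) * dilation + 1` input samples, and `padding_total` is that effective span minus the stride. A worked example (function name illustrative):

```python
def effective_conv1d_geometry(kernel_size: int, stride: int, dilation: int):
    # A dilated kernel covers (kernel_size - 1) * dilation + 1 input samples;
    # padding_total is the overhang of that effective kernel beyond one stride.
    effective_kernel = (kernel_size - 1) * dilation + 1
    padding_total = effective_kernel - stride
    return effective_kernel, padding_total

# kernel_size=7, stride=2, dilation=3 -> effective kernel 19, padding_total 17
print(effective_conv1d_geometry(7, 2, 3))
```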
@@ -815,7 +815,19 @@ class EomtImageProcessor(BaseImageProcessor):

         segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)

-        output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
+        if patch_offsets:
+            output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
+        else:
+            output_logits = []
+
+            for idx in range(len(segmentation_logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    segmentation_logits[idx].unsqueeze(dim=0),
+                    size=target_sizes[idx],
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                output_logits.append(resized_logits[0])

         preds = [logit.argmax(dim=0) for logit in output_logits]
         return preds
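A runnable toy version of the new `else` branch: each image's class logits are resized to that image's own target size before the per-pixel argmax (shapes and sizes below are made up):

```python
import torch
import torch.nn.functional as F

segmentation_logits = torch.randn(2, 5, 32, 32)  # batch of 2 images, 5 classes
target_sizes = [(64, 48), (100, 120)]            # per-image (height, width)

preds = []
for idx in range(len(segmentation_logits)):
    resized = F.interpolate(
        segmentation_logits[idx].unsqueeze(0),   # (1, C, H, W)
        size=target_sizes[idx],
        mode="bilinear",
        align_corners=False,
    )
    preds.append(resized[0].argmax(dim=0))       # (H', W') map of class ids
```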
@@ -44,12 +44,43 @@ from ...utils import (
 from .image_processing_eomt import (
     EomtImageProcessorKwargs,
     compute_segments,
-    convert_segmentation_map_to_binary_masks,
     get_size_with_aspect_ratio,
     remove_low_and_no_objects,
 )


+# Adapted from transformers.models.maskformer.image_processing_maskformer_fast.convert_segmentation_map_to_binary_masks_fast
+def convert_segmentation_map_to_binary_masks_fast(
+    segmentation_map: "torch.Tensor",
+    instance_id_to_semantic_id: Optional[dict[int, int]] = None,
+    ignore_index: Optional[int] = None,
+):
+    if ignore_index is not None:
+        segmentation_map = torch.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
+
+    all_labels = torch.unique(segmentation_map)
+
+    if ignore_index is not None:
+        all_labels = all_labels[all_labels != ignore_index]  # drop background label if applicable
+
+    binary_masks = [(segmentation_map == i) for i in all_labels]
+    if binary_masks:
+        binary_masks = torch.stack(binary_masks, dim=0)
+    else:
+        binary_masks = torch.zeros((0, *segmentation_map.shape), device=segmentation_map.device)
+
+    # Convert instance ids to class ids
+    if instance_id_to_semantic_id is not None:
+        labels = torch.zeros(all_labels.shape[0], device=segmentation_map.device)
+
+        for i, label in enumerate(all_labels):
+            class_id = instance_id_to_semantic_id[(label.item() + 1 if ignore_index is not None else label.item())]
+            labels[i] = class_id - 1 if ignore_index is not None else class_id
+    else:
+        labels = all_labels
+    return binary_masks.float(), labels.long()
+
+
 def get_target_size(size_dict: dict[str, int]) -> tuple[int, int]:
     """Returns the height and width from a size dict."""
     target_height = size_dict["shortest_edge"]
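Assuming the helper above is in scope, a quick check on a small map shows the expected shapes (the values below are illustrative):

```python
import torch

seg_map = torch.tensor([
    [0, 0, 1, 1],
    [0, 2, 2, 1],
    [3, 3, 2, 1],
    [3, 3, 0, 0],
])
# With ignore_index set, 0 is treated as background and dropped from the labels.
masks, labels = convert_segmentation_map_to_binary_masks_fast(seg_map, ignore_index=255)
print(masks.shape)  # torch.Size([3, 4, 4]) -- one float mask per remaining label
print(labels)       # tensor([0, 1, 2])
```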
@@ -162,8 +193,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
         )
         ignore_index = kwargs.pop("ignore_index", None)
         images_kwargs = kwargs.copy()
-
-        outputs = BatchFeature({"pixel_values": processed_images})
+        outputs = self._preprocess(images, **images_kwargs)

         if segmentation_maps is not None:
             processed_segmentation_maps = self._prepare_image_like_inputs(
@@ -183,9 +213,9 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
             }
         )

-        processed_segmentation_maps
+        processed_segmentation_maps = self._preprocess(
             images=processed_segmentation_maps, **segmentation_maps_kwargs
-        )
+        ).pixel_values
         processed_segmentation_maps = processed_segmentation_maps.squeeze(1).to(torch.int64)
         # Convert to list of binary masks and labels
         mask_labels, class_labels = [], []
@@ -195,21 +225,21 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
             else:
                 instance_id = instance_id_to_semantic_id
                 # Use instance2class_id mapping per image
-            masks, classes = convert_segmentation_map_to_binary_masks(
+            masks, classes = convert_segmentation_map_to_binary_masks_fast(
                 segmentation_map,
                 instance_id,
                 ignore_index=ignore_index,
             )

-            mask_labels.append(
-            class_labels.append(
+            mask_labels.append(masks)
+            class_labels.append(classes)

         # we cannot batch them since they don't share a common class size
         outputs["mask_labels"] = mask_labels
         outputs["class_labels"] = class_labels

-        if patch_offsets:
-            outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
+        if outputs.patch_offsets:
+            outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in outputs.patch_offsets]

         return outputs

@@ -239,7 +269,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
         for shape, stacked_images in grouped_images.items():
             if do_resize:
                 stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
-
+            resized_images_grouped[shape] = stacked_images
         images = reorder_images(resized_images_grouped, grouped_images_index)

         # Group images by size for batched resizing, Needed in case do_resize is False.
@@ -274,11 +304,13 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
                 stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
             )
             processed_images_grouped[shape] = stacked_images
-
-
-        processed_images = torch.stack(images, dim=0) if return_tensors else images
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)

-        return
+        return BatchFeature(
+            data={"pixel_values": processed_images, "patch_offsets": patch_offsets},
+            tensor_type=return_tensors,
+            skip_tensor_conversion=["patch_offsets"],
+        )

     def merge_image_patches(
         self,
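The `skip_tensor_conversion` argument (new on `BatchFeature` in this release, per the diff) presumably exists to keep `patch_offsets` as a plain list. A toy illustration of why it cannot be converted: per-image patch counts differ, so the offsets are ragged (the offset values below are made up):

```python
import torch

patch_offsets = [
    [(0, 0, 0), (0, 0, 128)],               # image 0 was cut into two patches
    [(1, 0, 0), (1, 0, 128), (1, 0, 256)],  # image 1 was cut into three
]
try:
    torch.stack([torch.tensor(o) for o in patch_offsets])
except RuntimeError as e:
    print("cannot batch:", e)  # shapes (2, 3) vs (3, 3) do not stack
```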
@@ -385,7 +417,19 @@ class EomtImageProcessorFast(BaseImageProcessorFast):

         segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)

-        output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
+        if patch_offsets:
+            output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
+        else:
+            output_logits = []
+
+            for idx in range(len(segmentation_logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    segmentation_logits[idx].unsqueeze(dim=0),
+                    size=target_sizes[idx],
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                output_logits.append(resized_logits[0])

         preds = [logit.argmax(dim=0) for logit in output_logits]
         return preds
@@ -1020,6 +1020,13 @@ class EomtPreTrainedModel(PreTrainedModel):
         elif isinstance(module, EomtEmbeddings):
             init.trunc_normal_(module.cls_token, mean=0.0, std=std)
             init.zeros_(module.register_tokens)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+        elif isinstance(module, EomtLoss):
+            empty_weight = torch.ones(module.num_labels + 1)
+            empty_weight[-1] = module.eos_coef
+            init.copy_(module.empty_weight, empty_weight)
+        elif isinstance(module, EomtForUniversalSegmentation):
+            init.ones_(module.attn_mask_probs)


 @auto_docstring(
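The `empty_weight` buffer is the usual DETR-style class-weight vector: one slot per label plus a final "no object" slot down-weighted by `eos_coef`. A standalone sketch (the 0.1 coefficient is illustrative, not from the diff):

```python
import torch

num_labels, eos_coef = 10, 0.1
empty_weight = torch.ones(num_labels + 1)   # one weight per class + "no object"
empty_weight[-1] = eos_coef                 # down-weight the no-object class
criterion = torch.nn.CrossEntropyLoss(weight=empty_weight)
```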
@@ -425,6 +425,13 @@ class EomtPreTrainedModel(PreTrainedModel):
         elif isinstance(module, EomtEmbeddings):
             init.trunc_normal_(module.cls_token, mean=0.0, std=std)
             init.zeros_(module.register_tokens)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+        elif isinstance(module, EomtLoss):
+            empty_weight = torch.ones(module.num_labels + 1)
+            empty_weight[-1] = module.eos_coef
+            init.copy_(module.empty_weight, empty_weight)
+        elif isinstance(module, EomtForUniversalSegmentation):
+            init.ones_(module.attn_mask_probs)


 @auto_docstring(
@@ -113,6 +113,9 @@ class ErnieEmbeddings(nn.Module):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        # .to is better than using _no_split_modules on ErnieEmbeddings as it's the first module and >1/2 the model size
+        inputs_embeds = inputs_embeds.to(token_type_embeddings.device)
         embeddings = inputs_embeds + token_type_embeddings

         position_embeddings = self.position_embeddings(position_ids)
@@ -553,6 +556,9 @@ class ErniePreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, ErnieLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, ErnieEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 @auto_docstring(
@@ -107,6 +107,9 @@ class ErnieEmbeddings(BertEmbeddings):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        # .to is better than using _no_split_modules on ErnieEmbeddings as it's the first module and >1/2 the model size
+        inputs_embeds = inputs_embeds.to(token_type_embeddings.device)
         embeddings = inputs_embeds + token_type_embeddings

         position_embeddings = self.position_embeddings(position_ids)
@@ -169,6 +172,9 @@ class ErniePreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, ErnieLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, ErnieEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 class ErnieModel(BertModel):
@@ -27,7 +27,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -35,7 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_ernie4_5 import Ernie4_5Config


@@ -56,7 +56,7 @@ class Ernie4_5RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
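Registering `original_inv_freq` as a buffer keeps it on the module's device across `.to(...)` moves, and the `.clone()` preserves a pristine copy that dynamic-RoPE updates can restore when the sequence length shrinks again. A minimal sketch of the pattern (the class and its defaults are illustrative):

```python
import torch
from torch import nn

class RotaryEmbedding(nn.Module):
    # Keep a mutable working copy and an untouched original so dynamic-scaling
    # code can rescale inv_freq in place and later restore the initial values.
    def __init__(self, dim: int = 64, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    def reset(self):
        self.inv_freq = self.original_inv_freq.clone()
```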
@@ -95,7 +95,7 @@ class Ernie4_5RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -203,6 +203,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed.to(original_dtype), k_embed.to(original_dtype)


+@use_kernelized_func(apply_rotary_pos_emb)
 class Ernie4_5Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -221,7 +222,6 @@ class Ernie4_5Attention(nn.Module):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
@@ -18,6 +18,7 @@ from torch import nn

 from ...modeling_rope_utils import dynamic_rope_update
 from ...utils import auto_docstring, can_return_tuple
+from ...utils.generic import maybe_autocast
 from ..glm.modeling_glm import rotate_half
 from ..llama.modeling_llama import (
     LlamaAttention,
@@ -36,7 +37,7 @@ class Ernie4_5RotaryEmbedding(OlmoRotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -29,15 +29,15 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig


@@ -96,7 +96,7 @@ class Ernie4_5_MoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -135,7 +135,7 @@ class Ernie4_5_MoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -226,6 +226,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Ernie4_5_MoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

|
@@ -244,7 +245,6 @@ class Ernie4_5_MoeAttention(nn.Module):
|
|
|
244
245
|
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
|
|
245
246
|
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
|
|
246
247
|
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
|
|
247
|
-
self.rotary_fn = apply_rotary_pos_emb
|
|
248
248
|
|
|
249
249
|
def forward(
|
|
250
250
|
self,
|
|
@@ -317,6 +317,7 @@ class Ernie4_5_MoeStatics(nn.Module):
         return hidden_states + self.e_score_correction_bias.squeeze()


+@use_experts_implementation
 class Ernie4_5_MoeExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -371,16 +372,16 @@ class Ernie4_5_MoeTopKRouter(nn.Module):
             else "cpu"
         )

-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            router_logits = F.linear(hidden_states.float(), self.weight)
-
-
-
-
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+            router_logits = F.linear(hidden_states.float(), self.weight.float())
+            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
+            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
+            routing_weights = routing_weights / torch.clamp(
+                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
             )
-
-
-        return router_logits, router_scores, router_indices
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        return router_logits, selected_experts, routing_weights


 class Ernie4_5_MoeSparseMoeBlock(nn.Module):
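The rewritten router is standard top-k routing: softmax over all experts, select the top-k (in the real code after the `moe_statics` score-correction bias, omitted here), gather the selected probabilities, and renormalize them to sum to one. A self-contained sketch (the `norm_min` value is illustrative):

```python
import torch
import torch.nn.functional as F

def topk_route(hidden_states, weight, top_k=2, norm_min=1e-12):
    # Softmax over experts, keep the top_k probabilities per token, then
    # renormalize so the kept weights sum to 1 (clamp avoids division by ~0).
    router_logits = F.linear(hidden_states.float(), weight.float())
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
    routing_weights = routing_weights / torch.clamp(
        routing_weights.sum(dim=-1, keepdim=True), min=norm_min
    )
    return router_logits, selected_experts, routing_weights.to(hidden_states.dtype)

tokens = torch.randn(4, 16)             # 4 tokens, hidden size 16
gate_weight = torch.randn(8, 16)        # 8 experts
logits, experts, weights = topk_route(tokens, gate_weight)
print(experts.shape, weights.sum(-1))   # torch.Size([4, 2]), all ones
```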
@@ -403,7 +404,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
         if self.shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)

-        _,
+        _, top_k_index, top_k_weights = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)

         if self.shared_experts is not None:
@@ -476,7 +477,9 @@ class Ernie4_5_MoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, index=0),
@@ -26,7 +26,7 @@ from ...modeling_outputs import MoeModelOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half  # noqa: F401
 from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm
 from ..mixtral.modeling_mixtral import (
@@ -104,32 +104,6 @@ class Ernie4_5_MoeExperts(MixtralExperts):
         self.num_experts = config.moe_num_experts
         self.intermediate_dim = config.moe_intermediate_size

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        top_k_index: torch.Tensor,
-        top_k_weights: torch.Tensor,
-    ) -> torch.Tensor:
-        final_hidden_states = torch.zeros_like(hidden_states)
-        with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
-            expert_mask = expert_mask.permute(2, 1, 0)
-            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-        for expert_idx in expert_hit:
-            expert_idx = expert_idx[0]
-            if expert_idx == self.num_experts:
-                continue
-            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
-            current_state = hidden_states[token_idx]
-            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
-            current_hidden_states = self.act_fn(gate) * up
-            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
-            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
-
-        return final_hidden_states
-

 class Ernie4_5_MoeTopKRouter(nn.Module):
     def __init__(self, config):
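The deleted override duplicated what the parent `MixtralExperts` forward already provides, so the subclass now simply inherits it. For reference, a toy sketch of the one-hot dispatch at its core, which maps each (top-k slot, token) pair to its expert so tokens can be processed expert-by-expert:

```python
import torch

num_experts, top_k, num_tokens = 4, 2, 3
top_k_index = torch.randint(num_experts, (num_tokens, top_k))
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts)
expert_mask = expert_mask.permute(2, 1, 0)  # (experts, top_k, tokens)
# Visit only experts that received at least one token:
for expert_idx in expert_mask.sum(dim=(-1, -2)).nonzero().flatten():
    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
    print(f"expert {expert_idx.item()} handles tokens {token_idx.tolist()}")
```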
@@ -146,16 +120,16 @@ class Ernie4_5_MoeTopKRouter(nn.Module):
             else "cpu"
         )

-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            router_logits = F.linear(hidden_states.float(), self.weight)
-
-
-
-
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+            router_logits = F.linear(hidden_states.float(), self.weight.float())
+            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
+            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
+            routing_weights = routing_weights / torch.clamp(
+                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
             )
-
-
-        return router_logits, router_scores, router_indices
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        return router_logits, selected_experts, routing_weights


 class Ernie4_5_MoeSparseMoeBlock(nn.Module):
@@ -178,7 +152,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
         if self.shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)

-        _,
+        _, top_k_index, top_k_weights = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)

         if self.shared_experts is not None:
@@ -0,0 +1,31 @@
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_ernie4_5_vl_moe import *
+    from .image_processing_ernie4_5_vl_moe import *
+    from .image_processing_ernie4_5_vl_moe_fast import *
+    from .modeling_ernie4_5_vl_moe import *
+    from .processing_ernie4_5_vl_moe import *
+    from .video_processing_ernie4_5_vl_moe import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)