transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- transformers/__init__.py +49 -3
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/cli/serve.py +47 -17
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +83 -7
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +374 -147
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +2 -3
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +55 -24
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +165 -124
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +228 -136
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +3 -14
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +16 -2
- transformers/integrations/accelerate.py +58 -113
- transformers/integrations/aqlm.py +36 -66
- transformers/integrations/awq.py +46 -515
- transformers/integrations/bitnet.py +47 -105
- transformers/integrations/bitsandbytes.py +91 -202
- transformers/integrations/deepspeed.py +18 -2
- transformers/integrations/eetq.py +84 -81
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +241 -208
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +37 -62
- transformers/integrations/hub_kernels.py +65 -8
- transformers/integrations/integration_utils.py +45 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +28 -74
- transformers/integrations/peft.py +12 -29
- transformers/integrations/quanto.py +77 -56
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +42 -90
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +40 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +74 -19
- transformers/modeling_rope_utils.py +107 -86
- transformers/modeling_utils.py +611 -527
- transformers/models/__init__.py +22 -0
- transformers/models/afmoe/modeling_afmoe.py +10 -19
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +14 -6
- transformers/models/altclip/modeling_altclip.py +11 -3
- transformers/models/apertus/modeling_apertus.py +8 -6
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +5 -5
- transformers/models/aria/modeling_aria.py +12 -8
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +38 -0
- transformers/models/auto/feature_extraction_auto.py +9 -3
- transformers/models/auto/image_processing_auto.py +5 -2
- transformers/models/auto/modeling_auto.py +37 -0
- transformers/models/auto/processing_auto.py +22 -10
- transformers/models/auto/tokenization_auto.py +147 -566
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +21 -21
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +11 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +14 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +9 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +16 -3
- transformers/models/bitnet/modeling_bitnet.py +5 -5
- transformers/models/blenderbot/modeling_blenderbot.py +12 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +10 -0
- transformers/models/blip_2/modeling_blip_2.py +4 -1
- transformers/models/bloom/modeling_bloom.py +17 -44
- transformers/models/blt/modeling_blt.py +164 -4
- transformers/models/blt/modular_blt.py +170 -5
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +11 -1
- transformers/models/bros/modeling_bros.py +12 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +11 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +11 -5
- transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +30 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +9 -0
- transformers/models/clvp/modeling_clvp.py +19 -3
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +5 -4
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +8 -7
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
- transformers/models/convbert/modeling_convbert.py +9 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +7 -4
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +15 -2
- transformers/models/cvt/modeling_cvt.py +7 -1
- transformers/models/cwm/modeling_cwm.py +5 -5
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +48 -39
- transformers/models/d_fine/modular_d_fine.py +16 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +5 -1
- transformers/models/dac/modeling_dac.py +6 -6
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +3 -3
- transformers/models/deberta/modeling_deberta.py +7 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
- transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +13 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +16 -4
- transformers/models/dia/modular_dia.py +11 -1
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +5 -5
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +3 -4
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +18 -12
- transformers/models/dots1/modeling_dots1.py +23 -11
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +6 -3
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +7 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +12 -6
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +60 -16
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +11 -5
- transformers/models/evolla/modeling_evolla.py +13 -5
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +3 -3
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +9 -4
- transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
- transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
- transformers/models/fast_vlm/__init__.py +27 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +21 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +10 -2
- transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
- transformers/models/florence2/modeling_florence2.py +22 -4
- transformers/models/florence2/modular_florence2.py +15 -1
- transformers/models/fnet/modeling_fnet.py +14 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +19 -3
- transformers/models/gemma/modeling_gemma.py +14 -16
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +5 -5
- transformers/models/gemma2/modular_gemma2.py +3 -2
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +42 -91
- transformers/models/gemma3/modular_gemma3.py +38 -87
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +65 -218
- transformers/models/gemma3n/modular_gemma3n.py +68 -68
- transformers/models/git/modeling_git.py +183 -126
- transformers/models/glm/modeling_glm.py +5 -5
- transformers/models/glm4/modeling_glm4.py +5 -5
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +18 -8
- transformers/models/glm4v/modular_glm4v.py +17 -7
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
- transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +13 -6
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
- transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
- transformers/models/gptj/modeling_gptj.py +18 -6
- transformers/models/granite/modeling_granite.py +5 -5
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +6 -9
- transformers/models/granitemoe/modular_granitemoe.py +1 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
- transformers/models/groupvit/modeling_groupvit.py +9 -1
- transformers/models/helium/modeling_helium.py +5 -4
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +7 -0
- transformers/models/hubert/modular_hubert.py +5 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +22 -0
- transformers/models/idefics/modeling_idefics.py +15 -21
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +11 -3
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +13 -12
- transformers/models/internvl/modular_internvl.py +7 -13
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +25 -20
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +16 -7
- transformers/models/janus/modular_janus.py +17 -7
- transformers/models/jetmoe/modeling_jetmoe.py +4 -4
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +15 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +248 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +730 -0
- transformers/models/lasr/modular_lasr.py +576 -0
- transformers/models/lasr/processing_lasr.py +94 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +10 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +12 -0
- transformers/models/levit/modeling_levit.py +21 -0
- transformers/models/lfm2/modeling_lfm2.py +5 -6
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +23 -15
- transformers/models/llama/modeling_llama.py +5 -5
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +11 -6
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
- transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -4
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +14 -0
- transformers/models/mamba/modeling_mamba.py +16 -23
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +8 -0
- transformers/models/markuplm/modeling_markuplm.py +9 -8
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +11 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +11 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +14 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +28 -5
- transformers/models/minimax/modeling_minimax.py +19 -6
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +5 -5
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +5 -4
- transformers/models/mistral/modeling_mistral.py +5 -4
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +15 -7
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +15 -4
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +7 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
- transformers/models/modernbert/modeling_modernbert.py +16 -2
- transformers/models/modernbert/modular_modernbert.py +14 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
- transformers/models/moonshine/modeling_moonshine.py +5 -3
- transformers/models/moshi/modeling_moshi.py +26 -53
- transformers/models/mpnet/modeling_mpnet.py +7 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +10 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +7 -10
- transformers/models/musicgen/modeling_musicgen.py +7 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
- transformers/models/mvp/modeling_mvp.py +14 -0
- transformers/models/nanochat/modeling_nanochat.py +5 -5
- transformers/models/nemotron/modeling_nemotron.py +7 -5
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +15 -68
- transformers/models/nystromformer/modeling_nystromformer.py +13 -0
- transformers/models/olmo/modeling_olmo.py +5 -5
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +5 -6
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +5 -5
- transformers/models/olmoe/modeling_olmoe.py +15 -7
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +11 -39
- transformers/models/openai/modeling_openai.py +15 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +11 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +11 -3
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +14 -6
- transformers/models/parakeet/modular_parakeet.py +7 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
- transformers/models/patchtst/modeling_patchtst.py +25 -6
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +8 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +13 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +3 -2
- transformers/models/phi/modeling_phi.py +5 -6
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +3 -2
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +15 -7
- transformers/models/phimoe/modular_phimoe.py +3 -3
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +3 -2
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +13 -0
- transformers/models/plbart/modular_plbart.py +8 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +13 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +5 -1
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +5 -5
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
- transformers/models/qwen3/modeling_qwen3.py +5 -5
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
- transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
- transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +8 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
- transformers/models/reformer/modeling_reformer.py +13 -1
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +10 -1
- transformers/models/rembert/modeling_rembert.py +13 -1
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +19 -5
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +6 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +2 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +7 -3
- transformers/models/sam2/modular_sam2.py +7 -3
- transformers/models/sam2_video/modeling_sam2_video.py +52 -43
- transformers/models/sam2_video/modular_sam2_video.py +32 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +100 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +4 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
- transformers/models/seed_oss/modeling_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +6 -3
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +67 -41
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +5 -5
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
- transformers/models/speecht5/modeling_speecht5.py +41 -1
- transformers/models/splinter/modeling_splinter.py +12 -3
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +8 -0
- transformers/models/stablelm/modeling_stablelm.py +4 -2
- transformers/models/starcoder2/modeling_starcoder2.py +5 -4
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +6 -0
- transformers/models/swin/modeling_swin.py +20 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +51 -33
- transformers/models/swinv2/modeling_swinv2.py +45 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +8 -7
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +6 -6
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +5 -1
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +14 -0
- transformers/models/timesfm/modular_timesfm.py +14 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
- transformers/models/trocr/modeling_trocr.py +3 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +6 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +7 -7
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +7 -6
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +13 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +8 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +5 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +11 -3
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +5 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +11 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +18 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +10 -1
- transformers/models/zamba/modeling_zamba.py +4 -1
- transformers/models/zamba2/modeling_zamba2.py +7 -4
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +8 -0
- transformers/pipelines/__init__.py +11 -9
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +2 -10
- transformers/pipelines/document_question_answering.py +4 -2
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +133 -50
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +44 -174
- transformers/quantizers/quantizer_aqlm.py +2 -23
- transformers/quantizers/quantizer_auto_round.py +2 -12
- transformers/quantizers/quantizer_awq.py +20 -89
- transformers/quantizers/quantizer_bitnet.py +4 -14
- transformers/quantizers/quantizer_bnb_4bit.py +18 -155
- transformers/quantizers/quantizer_bnb_8bit.py +24 -110
- transformers/quantizers/quantizer_compressed_tensors.py +2 -9
- transformers/quantizers/quantizer_eetq.py +16 -74
- transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
- transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
- transformers/quantizers/quantizer_fp_quant.py +52 -82
- transformers/quantizers/quantizer_gptq.py +8 -28
- transformers/quantizers/quantizer_higgs.py +42 -60
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +14 -194
- transformers/quantizers/quantizer_quanto.py +35 -79
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +4 -12
- transformers/quantizers/quantizer_torchao.py +50 -325
- transformers/quantizers/quantizer_vptq.py +4 -27
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +324 -47
- transformers/tokenization_mistral_common.py +7 -2
- transformers/tokenization_utils_base.py +116 -224
- transformers/tokenization_utils_tokenizers.py +190 -106
- transformers/trainer.py +51 -32
- transformers/trainer_callback.py +8 -0
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +74 -38
- transformers/utils/__init__.py +7 -4
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +35 -25
- transformers/utils/generic.py +47 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +112 -25
- transformers/utils/kernel_config.py +74 -19
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +78 -245
- transformers/video_processing_utils.py +17 -14
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/maskformer/modeling_maskformer_swin.py

@@ -19,7 +19,7 @@ states before downsampling, which is different from the default Swin Transformer
 import collections.abc
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch import Tensor, nn
@@ -331,18 +331,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
             torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
         )

-
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
-        coords_flatten = torch.flatten(coords, 1)
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
-        relative_coords[:, :, 0] += self.window_size[0] - 1
-        relative_coords[:, :, 1] += self.window_size[1] - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
-        relative_position_index = relative_coords.sum(-1)
-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", self.create_relative_position_index())

         self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
         self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
@@ -401,6 +390,20 @@ class MaskFormerSwinSelfAttention(nn.Module):

         return outputs

+    def create_relative_position_index(self):
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += self.window_size[0] - 1
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)
+        return relative_position_index
+

 # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
 class MaskFormerSwinSelfOutput(nn.Module):
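The two refactor hunks above move the pair-wise relative-position-index computation out of `MaskFormerSwinSelfAttention.__init__` into a `create_relative_position_index` helper that can be re-run later. For intuition, here is a standalone, runnable copy of that helper (calling `torch.meshgrid` directly instead of the library's internal `meshgrid` wrapper), with the table it produces for a 2×2 window:

```python
import torch

def create_relative_position_index(window_size):
    # Pair-wise relative position index for each token inside the window,
    # flattened to a single integer per (query, key) pair.
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
    coords_flatten = torch.flatten(coords, 1)
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += window_size[0] - 1      # shift row offsets to start at 0
    relative_coords[:, :, 1] += window_size[1] - 1      # shift col offsets to start at 0
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1  # row-major flattening of (dh, dw)
    return relative_coords.sum(-1)

print(create_relative_position_index((2, 2)))
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
```

Each entry indexes into the `(2 * Wh - 1) * (2 * Ww - 1)`-row `relative_position_bias_table`, so the learned bias depends only on the relative offset between two window positions, never on their absolute locations.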
transformers/models/maskformer/modeling_maskformer_swin.py (continued)

@@ -656,7 +659,7 @@ class MaskFormerSwinEncoder(nn.Module):
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
-    ):
+    ) -> Union[tuple, MaskFormerSwinBaseModelOutput]:
         all_hidden_states = () if output_hidden_states else None
         all_input_dimensions = ()
         all_self_attentions = () if output_attentions else None
@@ -711,6 +714,7 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
             init.zeros_(module.position_embeddings)
         elif isinstance(module, MaskFormerSwinSelfAttention):
             init.zeros_(module.relative_position_bias_table)
+            init.copy_(module.relative_position_index, module.create_relative_position_index())


 class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
@@ -738,7 +742,8 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
         output_hidden_states=None,
         interpolate_pos_encoding=False,
         return_dict=None,
-    ):
+        **kwargs,
+    ) -> Union[tuple, MaskFormerSwinModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -815,6 +820,7 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> BackboneOutput:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         output_hidden_states = (
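The `init.copy_(module.relative_position_index, ...)` line added to `_init_weights` re-populates the helper's buffer during weight initialization. This matters because computed buffers do not survive meta-device loading; a minimal plain-PyTorch illustration of the failure mode (assuming PyTorch ≥ 2.0 for the meta-device context manager):

```python
import torch
from torch import nn

# Constructing under the meta device records shapes but allocates no data.
with torch.device("meta"):
    layer = nn.Linear(4, 4)

layer = layer.to_empty(device="cpu")  # real storage now exists, but values are garbage
nn.init.zeros_(layer.weight)          # so every tensor must be re-initialized,
nn.init.zeros_(layer.bias)            # including any buffer that is computed rather than loaded
```

Because the index table is deterministic, it can always be rebuilt, which is presumably why the computation was factored into a reusable helper rather than persisted.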
transformers/models/mbart/configuration_mbart.py

@@ -147,6 +147,7 @@ class MBartConfig(PreTrainedConfig):
         self.use_cache = use_cache
         self.num_hidden_layers = encoder_layers
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
transformers/models/mbart/modeling_mbart.py

@@ -22,6 +22,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -478,6 +479,11 @@ class MBartPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, MBartForConditionalGeneration):
+            init.zeros_(module.final_logits_bias)
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
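The new `_init_weights` override zeroes `final_logits_bias`, which in the MBart family is a buffer added to the LM head output rather than a trained parameter, so the generic layer initialization never touches it. A self-contained sketch of that buffer pattern (toy module, not the real MBart head):

```python
import torch
from torch import nn

class ToyLMHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # A buffer: saved/loaded with the state dict but never updated by the optimizer.
        self.register_buffer("final_logits_bias", torch.zeros(1, vocab_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states) + self.final_logits_bias

head = ToyLMHead(8, 100)
logits = head(torch.randn(2, 5, 8))  # shape (2, 5, 100)
```

Since such buffers fall outside the per-layer init rules, each model-specific one needs an explicit line like `init.zeros_(module.final_logits_bias)` when initializing from scratch.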
@@ -540,6 +546,7 @@ class MBartEncoder(MBartPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         Args:
@@ -691,6 +698,7 @@ class MBartDecoder(MBartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:
@@ -919,6 +927,7 @@ class MBartModel(MBartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Seq2SeqModelOutput, tuple[torch.FloatTensor]]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1052,6 +1061,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1205,6 +1215,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1338,6 +1349,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1436,6 +1448,7 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.decoder = MBartDecoder(config)
+        self.post_init()

     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
@@ -1480,6 +1493,7 @@ class MBartForCausalLM(MBartPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
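The MBart hunks above each add a catch-all `**kwargs` to a `forward` signature, so newer call sites can pass extra arguments (attention-backend flags, output-recorder kwargs, and so on) without raising a `TypeError` on older models. The mechanics in miniature:

```python
def forward_old(input_ids, attention_mask=None):
    return input_ids

def forward_new(input_ids, attention_mask=None, **kwargs):
    # unrecognized extras are tolerated and can be forwarded down the stack
    return input_ids

forward_new([1, 2, 3], output_router_logits=False)  # accepted
try:
    forward_old([1, 2, 3], output_router_logits=False)
except TypeError as e:
    print(e)  # unexpected keyword argument 'output_router_logits'
```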
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Tokenizer, decoders, pre_tokenizers, processors
 from tokenizers.models import Unigram
@@ -58,13 +58,14 @@ class MBartTokenizer(TokenizersBackend):

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = Unigram

     prefix_tokens: list[int] = []
     suffix_tokens: list[int] = []

     def __init__(
         self,
+        vocab: Optional[Union[str, dict, list]] = None,
         bos_token="<s>",
         eos_token="</s>",
         sep_token="</s>",
@@ -75,9 +76,6 @@ class MBartTokenizer(TokenizersBackend):
         src_lang=None,
         tgt_lang=None,
         additional_special_tokens=None,
-        vocab=None,
-        merges=None,  # Ignored for Unigram
-        vocab_file=None,
         **kwargs,
     ):
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
@@ -88,56 +86,20 @@ class MBartTokenizer(TokenizersBackend):
             [t for t in additional_special_tokens if t not in _additional_special_tokens]
         )

-
-
-        # Handle different vocab formats (dict, list of tokens, or list of tuples)
-        # SentencePieceExtractor returns list[tuple[str, float]] which is the expected format
-        if isinstance(vocab, dict):
-            vocab = [(token, 0.0) for token in vocab.keys()]
-        elif isinstance(vocab, list) and len(vocab) > 0:
-            if not isinstance(vocab[0], tuple):
-                vocab = [(token, 0.0) for token in vocab]
-            else:
-                # Ensure tuples are (str, float) format
-                vocab = [(str(item[0]), float(item[1])) for item in vocab]
-
-            # Reorder to fairseq: <s>, <pad>, </s>, <unk>, ... (rest of vocab from SPM[3:])
-            vocab_list = []
-            vocab_list.append((str(bos_token), 0.0))
-            vocab_list.append((str(pad_token), 0.0))
-            vocab_list.append((str(eos_token), 0.0))
-            vocab_list.append((str(unk_token), 0.0))
-
-            # Add the rest of the SentencePiece vocab (skipping first 3: <unk>, <s>, </s>)
-            vocab_list.extend(vocab[4:])
-
-            # Add language codes
-            for lang_code in FAIRSEQ_LANGUAGE_CODES:
-                vocab_list.append((str(lang_code), 0.0))
-
-            # Add mask token
-            vocab_list.append((str(mask_token), 0.0))
-
-            self._vocab_scores = vocab_list
-        else:
-            self._vocab_scores = [
+        if vocab is None:
+            vocab = [
                 (str(bos_token), 0.0),
                 (str(pad_token), 0.0),
                 (str(eos_token), 0.0),
                 (str(unk_token), 0.0),
-                ("▁", -2.0),
             ]
+        vocab += [("▁", -2.0)]
         for lang_code in FAIRSEQ_LANGUAGE_CODES:
-            self._vocab_scores.append((str(lang_code), 0.0))
-        self._vocab_scores.append((str(mask_token), 0.0))
-
-        self._tokenizer = Tokenizer(
-            Unigram(
-                self._vocab_scores,
-                unk_id=3,
-                byte_fallback=False,
-            )
-        )
+            vocab.append((lang_code, 0.0))
+        vocab.append((str(mask_token), 0.0))
+
+        self._vocab = vocab
+        self._tokenizer = Tokenizer(Unigram(self._vocab, unk_id=3, byte_fallback=False))

         self._tokenizer.normalizer = None

@@ -150,10 +112,7 @@ class MBartTokenizer(TokenizersBackend):

         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

-        tokenizer_object = self._tokenizer
-
         super().__init__(
-            tokenizer_object=tokenizer_object,
             bos_token=bos_token,
             eos_token=eos_token,
             sep_token=sep_token,
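The rewritten constructor above builds the Unigram backend directly from `(token, score)` pairs laid out in fairseq order (`<s>=0, <pad>=1, </s>=2, <unk>=3`, hence `unk_id=3`). A self-contained sketch with a made-up two-word vocabulary; the toy entries are assumptions, while the API calls mirror the diff:

```python
from tokenizers import Tokenizer, decoders, pre_tokenizers
from tokenizers.models import Unigram

vocab = [
    ("<s>", 0.0), ("<pad>", 0.0), ("</s>", 0.0), ("<unk>", 0.0),
    ("▁", -2.0),
    ("▁hello", -1.0), ("▁world", -1.5),  # toy entries
]
tokenizer = Tokenizer(Unigram(vocab, unk_id=3, byte_fallback=False))
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)

print(tokenizer.encode("hello world").tokens)  # ['▁hello', '▁world']
```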
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import Optional, Union

 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import Unigram
@@ -79,13 +79,14 @@ class MBart50Tokenizer(TokenizersBackend):

     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]
-
+    model = Unigram

     prefix_tokens: list[int] = []
     suffix_tokens: list[int] = []

     def __init__(
         self,
+        vocab: Optional[Union[str, dict, list]] = None,
         src_lang=None,
         tgt_lang=None,
         eos_token="</s>",
@@ -94,21 +95,16 @@ class MBart50Tokenizer(TokenizersBackend):
         unk_token="<unk>",
         pad_token="<pad>",
         mask_token="<mask>",
-        vocab=None,
-        merges=None,  # Ignored for Unigram
-        vocab_file=None,
         **kwargs,
     ):
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

-        self.vocab_file = vocab_file
-
         # Do not pass language codes via extra_special_tokens to super().__init__.
         # We will mark them as special AFTER backend construction to avoid re-adding tokens
         # when loading from pretrained files.

         # Always construct a tokenizer_object without referencing external tokenizer files
-        if vocab is not None:
+        if isinstance(vocab, list):
             # MBart50 uses fairseq vocab alignment matching MBart50Converter:
             # <s>=0, <pad>=1, </s>=2, <unk>=3, then tokens, lang codes, <mask>

@@ -180,9 +176,9 @@
         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)

         self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
-
+        additional_special_tokens = kwargs.pop("additional_special_tokens", []) or []
+        additional_special_tokens.extend(FAIRSEQ_LANGUAGE_CODES)
         super().__init__(
-            tokenizer_object=self._tokenizer,
             src_lang=src_lang,
             tgt_lang=tgt_lang,
             eos_token=eos_token,
@@ -191,6 +187,7 @@ class MBart50Tokenizer(TokenizersBackend):
             unk_token=unk_token,
             pad_token=pad_token,
             mask_token=mask_token,
+            additional_special_tokens=additional_special_tokens,
             **kwargs,
         )

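For MBart50, the change above merges any user-supplied `additional_special_tokens` with the language codes before calling the parent constructor, so the codes are marked special exactly once; the `or []` guards an explicit `None`. In isolation, with a truncated toy code list:

```python
FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ"]  # toy subset of the real list

def build(**kwargs):
    additional_special_tokens = kwargs.pop("additional_special_tokens", []) or []
    additional_special_tokens.extend(FAIRSEQ_LANGUAGE_CODES)
    return additional_special_tokens

print(build())                                       # ['ar_AR', 'cs_CZ']
print(build(additional_special_tokens=["<extra>"]))  # ['<extra>', 'ar_AR', 'cs_CZ']
print(build(additional_special_tokens=None))         # ['ar_AR', 'cs_CZ']
```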
@@ -528,6 +528,8 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, MegatronBertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, MegatronBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 @dataclass
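The `position_ids` re-initialization above (and in the MetaCLIP 2 hunks below) is the standard absolute-position index rebuilt in place; on its own:

```python
import torch

position_ids = torch.empty(1, 6, dtype=torch.long)  # buffer as allocated
with torch.no_grad():
    position_ids.copy_(torch.arange(position_ids.shape[-1]).expand((1, -1)))
print(position_ids)  # tensor([[0, 1, 2, 3, 4, 5]])
```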
@@ -608,6 +610,7 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -735,6 +738,7 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MegatronBertForPreTrainingOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -955,6 +959,7 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1140,6 +1145,7 @@ class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1223,6 +1229,7 @@ class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -1326,6 +1333,7 @@ class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1391,6 +1399,7 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -306,11 +306,13 @@ class MetaClip2PreTrainedModel(PreTrainedModel):
         if isinstance(module, MetaClip2TextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2VisionEmbeddings):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2Attention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
@@ -225,11 +225,13 @@ class MetaClip2PreTrainedModel(CLIPPreTrainedModel):
         if isinstance(module, MetaClip2TextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2VisionEmbeddings):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2Attention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
@@ -322,6 +322,7 @@ class MgpstrModel(MgpstrPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -385,6 +386,7 @@ class MgpstrForSceneTextRecognition(MgpstrPreTrainedModel):
         output_a3_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], MgpstrModelOutput]:
         r"""
         output_a3_attentions (`bool`, *optional*):
@@ -32,6 +32,7 @@ from ...modeling_outputs import BaseModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
+from ...utils.generic import maybe_autocast
 from .configuration_mimi import MimiConfig


@@ -520,7 +521,7 @@ class MimiRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -559,7 +560,7 @@ class MimiRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
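The two rotary hunks above swap a bare `torch.autocast` for a `maybe_autocast` helper and promote `original_inv_freq` from a plain attribute to a real non-persistent buffer. A plausible stand-in for the helper with the same call shape; the actual `transformers.utils.generic.maybe_autocast` may be implemented differently:

```python
import contextlib
import torch

def maybe_autocast(device_type: str, enabled: bool = True):
    # Stand-in: use torch.autocast when the device type supports it,
    # otherwise fall back to a no-op context.
    try:
        return torch.autocast(device_type=device_type, enabled=enabled)
    except RuntimeError:
        return contextlib.nullcontext()

inv_freq = torch.rand(1, 4, 1)
position_ids = torch.arange(8).float()[None, None, :]
with maybe_autocast(device_type="cpu", enabled=False):  # force float32 math
    freqs = (inv_freq @ position_ids).transpose(1, 2)
print(freqs.shape, freqs.dtype)  # torch.Size([1, 8, 4]) torch.float32
```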
@@ -813,8 +814,8 @@ class MimiFlashAttention2(MimiAttention):
                 else torch.get_autocast_gpu_dtype()
             )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype

@@ -1379,7 +1380,7 @@ class MimiPreTrainedModel(PreTrainedModel):
     main_input_name = "input_values"
     input_modalities = "audio"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["MimiDecoderLayer"]
+    _no_split_modules = ["MimiResidualVectorQuantizer", "MimiTransformerLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn = True
     _supports_sdpa = True
@@ -1403,6 +1404,27 @@ class MimiPreTrainedModel(PreTrainedModel):
             init.uniform_(module.bias, a=-k, b=k)
         elif isinstance(module, MimiLayerScale):
             init.constant_(module.scale, self.config.layer_scale_initial_scale)
+        elif isinstance(module, MimiConv1d):
+            kernel_size = module.conv.kernel_size[0]
+            stride = module.conv.stride[0]
+            dilation = module.conv.dilation[0]
+            kernel_size = (kernel_size - 1) * dilation + 1
+            init.constant_(module.stride, stride)
+            init.constant_(module.kernel_size, kernel_size)
+            init.constant_(module.padding_total, kernel_size - stride)
+        elif isinstance(module, MimiEuclideanCodebook):
+            init.ones_(module.initialized)
+            init.ones_(module.cluster_usage)
+            init.zeros_(module.embed_sum)
+        elif isinstance(module, MimiRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)


 @auto_docstring(
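The large `_init_weights` addition above recomputes deterministic buffers (conv geometry, codebook statistics, rotary tables) instead of expecting them in the checkpoint, selecting the rope function from a registry keyed by `rope_type`. The selection-and-recompute step, sketched with a toy registry and config:

```python
import torch

def compute_default_rope_parameters(config, device=None):
    # toy default rope: inverse frequencies over half the head dim
    dim, base = config["dim"], config["rope_theta"]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    return inv_freq, 1.0  # (buffer value, attention scaling)

ROPE_INIT_FUNCTIONS = {"linear": compute_default_rope_parameters}  # toy registry

rope_type = "default"
rope_fn = (
    ROPE_INIT_FUNCTIONS[rope_type]
    if rope_type != "default"
    else compute_default_rope_parameters
)
buffer_value, _ = rope_fn({"dim": 8, "rope_theta": 10000.0})
print(buffer_value)  # 4 inverse frequencies, recomputed deterministically
```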
@@ -1685,6 +1707,7 @@ class MimiModel(MimiPreTrainedModel):
         encoder_past_key_values: Optional[Cache] = None,
         decoder_past_key_values: Optional[Cache] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor, torch.Tensor], MimiOutput]:
         r"""
         input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
@@ -31,7 +31,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -45,7 +50,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_minimax import MiniMaxConfig


@@ -271,7 +276,7 @@ class MiniMaxRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -310,7 +315,7 @@
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -392,6 +397,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class MiniMaxAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -408,7 +414,6 @@ class MiniMaxAttention(nn.Module):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,
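Above, the hand-wired `self.rotary_fn = apply_rotary_pos_emb` assignment is replaced by a `@use_kernelized_func(...)` class decorator, which presumably lets a hub kernel be substituted for the reference function in one place. A toy decorator with the same surface; this is an assumption about its behavior, not the real implementation:

```python
def use_kernelized_func(func):
    # Toy stand-in: attach the (possibly kernel-substituted) function to the
    # decorated class so __init__ no longer wires it up by hand.
    def decorator(cls):
        cls.rotary_fn = staticmethod(func)
        return cls
    return decorator

def apply_rotary_pos_emb(q, k, cos, sin):
    # placeholder rotary math
    return q, k

@use_kernelized_func(apply_rotary_pos_emb)
class ToyAttention:
    pass

print(ToyAttention.rotary_fn is apply_rotary_pos_emb)  # True
```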
@@ -473,6 +478,7 @@ class MiniMaxTopKRouter(nn.Module):
         return router_logits, router_scores, router_indices


+@use_experts_implementation
 class MiniMaxExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -596,7 +602,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = False  # uses a non-compilable custom cache class MiniMaxCache
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0),
@@ -613,6 +619,13 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
             init.normal_(module.down_proj, mean=0.0, std=std)
         elif isinstance(module, MiniMaxTopKRouter):
             init.normal_(module.weight, mean=0.0, std=std)
+        if isinstance(module, MiniMaxLightningAttention):
+            slope_rate = module.get_slope_rate()
+            query_decay, key_decay, diagonal_decay = module.decay_factors(slope_rate)
+            init.copy_(module.slope_rate, slope_rate)
+            init.copy_(module.query_decay, query_decay)
+            init.copy_(module.key_decay, key_decay)
+            init.copy_(module.diagonal_decay, diagonal_decay)


 @auto_docstring
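The MiniMax `_init_weights` addition above recomputes the lightning-attention slope rates and decay factors rather than storing them in checkpoints: they are pure functions of the head count and positions. A sketch using the common ALiBi-style geometric slope schedule; the actual `get_slope_rate`/`decay_factors` formulas may differ:

```python
import torch

def get_slope_rate(num_heads: int) -> torch.Tensor:
    # Hypothetical stand-in: geometric (ALiBi-style) per-head slope schedule.
    base = 1.0 / (2 ** (8 / num_heads))
    return torch.tensor([base ** (i + 1) for i in range(num_heads)])

slope_rate = get_slope_rate(8)
positions = torch.arange(16).float()
query_decay = torch.exp(-slope_rate[:, None] * positions[None, :])  # toy decay factor
print(slope_rate.shape, query_decay.shape)  # torch.Size([8]) torch.Size([8, 16])
```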
@@ -21,6 +21,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
@@ -520,13 +521,23 @@ class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):


 class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = False  # uses a non-compilable custom cache class MiniMaxCache
     _can_record_outputs = {
         "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0),
         "hidden_states": MiniMaxDecoderLayer,
         "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
     }

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, MiniMaxLightningAttention):
+            slope_rate = module.get_slope_rate()
+            query_decay, key_decay, diagonal_decay = module.decay_factors(slope_rate)
+            init.copy_(module.slope_rate, slope_rate)
+            init.copy_(module.query_decay, query_decay)
+            init.copy_(module.key_decay, key_decay)
+            init.copy_(module.diagonal_decay, diagonal_decay)
+

 class MiniMaxModel(MixtralModel):
     @check_model_inputs