transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/maskformer/modeling_maskformer_swin.py

```diff
@@ -19,7 +19,7 @@ states before downsampling, which is different from the default Swin Transformer
 import collections.abc
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch import Tensor, nn
@@ -331,18 +331,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
             torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
         )

-
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
-        coords_flatten = torch.flatten(coords, 1)
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
-        relative_coords[:, :, 0] += self.window_size[0] - 1
-        relative_coords[:, :, 1] += self.window_size[1] - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
-        relative_position_index = relative_coords.sum(-1)
-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", self.create_relative_position_index())

         self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
         self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
@@ -401,6 +390,20 @@ class MaskFormerSwinSelfAttention(nn.Module):

         return outputs

+    def create_relative_position_index(self):
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += self.window_size[0] - 1
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)
+        return relative_position_index
+

 # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
 class MaskFormerSwinSelfOutput(nn.Module):
@@ -656,7 +659,7 @@ class MaskFormerSwinEncoder(nn.Module):
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
-    ):
+    ) -> Union[tuple, MaskFormerSwinBaseModelOutput]:
         all_hidden_states = () if output_hidden_states else None
         all_input_dimensions = ()
         all_self_attentions = () if output_attentions else None
@@ -711,6 +714,7 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
             init.zeros_(module.position_embeddings)
         elif isinstance(module, MaskFormerSwinSelfAttention):
             init.zeros_(module.relative_position_bias_table)
+            init.copy_(module.relative_position_index, module.create_relative_position_index())


 class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
@@ -739,7 +743,7 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
         interpolate_pos_encoding=False,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, MaskFormerSwinModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
```
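
As a side note on the MaskFormerSwin refactor above: the relative-position-index logic it moves into `create_relative_position_index()` is self-contained tensor math. A minimal standalone sketch (the free function and the 7x7 window below are illustrative only, not part of the released API):

```python
import torch


def create_relative_position_index(window_size: tuple) -> torch.Tensor:
    # Pair-wise relative position index for every pair of tokens inside an (h, w) window,
    # mirroring the method factored out in the hunk above.
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # (2, h, w)
    coords_flatten = torch.flatten(coords, 1)                                # (2, h*w)
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, h*w, h*w)
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()            # (h*w, h*w, 2)
    relative_coords[:, :, 0] += window_size[0] - 1      # shift row offsets so they start at 0
    relative_coords[:, :, 1] += window_size[1] - 1      # shift column offsets so they start at 0
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1  # row-major flattening of (row, col)
    return relative_coords.sum(-1)  # (h*w, h*w), values in [0, (2h-1)*(2w-1))


# e.g. a 7x7 window gives a 49x49 table indexing into a (2*7-1)*(2*7-1) = 169-entry bias table
index = create_relative_position_index((7, 7))
assert index.shape == (49, 49) and int(index.max()) == 168
```
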
transformers/models/mbart/configuration_mbart.py

```diff
@@ -147,6 +147,7 @@ class MBartConfig(PreTrainedConfig):
         self.use_cache = use_cache
         self.num_hidden_layers = encoder_layers
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
```
transformers/models/mbart/modeling_mbart.py

```diff
@@ -22,6 +22,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -478,6 +479,11 @@ class MBartPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, MBartForConditionalGeneration):
+            init.zeros_(module.final_logits_bias)
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -1442,6 +1448,7 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.decoder = MBartDecoder(config)
+        self.post_init()

     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
```
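
The `_init_weights` override added here is one instance of a pattern that recurs throughout this release (MegatronBert, MetaClip2, Mimi and MiniMax below): defer to `super()._init_weights(module)` for the generic cases, then deterministically reset model-specific buffers through the `transformers.initialization` helpers. A rough, self-contained sketch of that shape, using made-up `Tiny*` classes and plain `torch.no_grad()` writes in place of the real `init.zeros_`/`init.copy_` helpers:

```python
import torch
from torch import nn


class TinyEmbeddings(nn.Module):
    # hypothetical module with a non-persistent buffer, analogous to *Embeddings.position_ids
    def __init__(self, max_positions: int = 8):
        super().__init__()
        self.register_buffer("position_ids", torch.zeros(1, max_positions, dtype=torch.long), persistent=False)


class TinyPreTrainedModel(nn.Module):
    # hypothetical, mirroring the shape of the _init_weights overrides in these diffs
    def __init__(self):
        super().__init__()
        self.embeddings = TinyEmbeddings()
        self.register_buffer("final_logits_bias", torch.ones(1, 16))

    def _init_weights(self, module):
        # a real model would first call super()._init_weights(module) for Linear/Embedding defaults
        if module is self:
            with torch.no_grad():  # init.zeros_(...) plays this role in transformers
                self.final_logits_bias.zero_()
        elif isinstance(module, TinyEmbeddings):
            with torch.no_grad():  # init.copy_(...) re-materializes the buffer deterministically
                module.position_ids.copy_(torch.arange(module.position_ids.shape[-1]).expand((1, -1)))

    def init_weights(self):
        self.apply(self._init_weights)


model = TinyPreTrainedModel()
model.init_weights()
assert int(model.final_logits_bias.abs().sum()) == 0
assert model.embeddings.position_ids.tolist() == [list(range(8))]
```
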
transformers/models/megatron_bert/modeling_megatron_bert.py

```diff
@@ -528,6 +528,8 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, MegatronBertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, MegatronBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))


 @dataclass
```
transformers/models/metaclip_2/modeling_metaclip_2.py

```diff
@@ -306,11 +306,13 @@ class MetaClip2PreTrainedModel(PreTrainedModel):
         if isinstance(module, MetaClip2TextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2VisionEmbeddings):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2Attention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
```
transformers/models/metaclip_2/modular_metaclip_2.py

```diff
@@ -225,11 +225,13 @@ class MetaClip2PreTrainedModel(CLIPPreTrainedModel):
         if isinstance(module, MetaClip2TextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2VisionEmbeddings):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MetaClip2Attention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
```
transformers/models/mimi/modeling_mimi.py

```diff
@@ -521,7 +521,7 @@ class MimiRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -814,8 +814,8 @@ class MimiFlashAttention2(MimiAttention):
                     else torch.get_autocast_gpu_dtype()
                 )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "
-                target_dtype = self.config.
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.q_proj.weight.dtype

@@ -1380,7 +1380,7 @@ class MimiPreTrainedModel(PreTrainedModel):
     main_input_name = "input_values"
     input_modalities = "audio"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["
+    _no_split_modules = ["MimiResidualVectorQuantizer", "MimiTransformerLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn = True
     _supports_sdpa = True
@@ -1404,6 +1404,27 @@ class MimiPreTrainedModel(PreTrainedModel):
             init.uniform_(module.bias, a=-k, b=k)
         elif isinstance(module, MimiLayerScale):
             init.constant_(module.scale, self.config.layer_scale_initial_scale)
+        elif isinstance(module, MimiConv1d):
+            kernel_size = module.conv.kernel_size[0]
+            stride = module.conv.stride[0]
+            dilation = module.conv.dilation[0]
+            kernel_size = (kernel_size - 1) * dilation + 1
+            init.constant_(module.stride, stride)
+            init.constant_(module.kernel_size, kernel_size)
+            init.constant_(module.padding_total, kernel_size - stride)
+        elif isinstance(module, MimiEuclideanCodebook):
+            init.ones_(module.initialized)
+            init.ones_(module.cluster_usage)
+            init.zeros_(module.embed_sum)
+        elif isinstance(module, MimiRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)


 @auto_docstring(
```
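
The `MimiRotaryEmbedding` hunk above (and the identical MiniMax/Ministral/Ministral3/Mistral hunks below) stops storing `original_inv_freq` as a plain attribute and registers it as a non-persistent buffer cloned from `inv_freq`. A small sketch of why the distinction matters, using a made-up module and the usual RoPE inverse-frequency formula as an assumption:

```python
import torch
from torch import nn


class RotarySketch(nn.Module):
    # hypothetical minimal rotary embedding; only the buffer handling mirrors the diff above
    def __init__(self, dim: int = 8, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # rc2 style: a registered buffer instead of a plain Python attribute
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)


rope = RotarySketch()
# registered buffers follow .to()/.cuda()/.float() together with the rest of the module ...
rope = rope.to(torch.float64)
assert rope.original_inv_freq.dtype == torch.float64
# ... while persistent=False keeps them out of the state dict, so checkpoints are unaffected
assert "original_inv_freq" not in rope.state_dict()
```
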
transformers/models/minimax/modeling_minimax.py

```diff
@@ -31,7 +31,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -271,7 +276,7 @@ class MiniMaxRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -473,6 +478,7 @@ class MiniMaxTopKRouter(nn.Module):
         return router_logits, router_scores, router_indices


+@use_experts_implementation
 class MiniMaxExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -596,7 +602,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = False  # uses a non-compilable custom cache class MiniMaxCache
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0),
@@ -613,6 +619,13 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
             init.normal_(module.down_proj, mean=0.0, std=std)
         elif isinstance(module, MiniMaxTopKRouter):
             init.normal_(module.weight, mean=0.0, std=std)
+        if isinstance(module, MiniMaxLightningAttention):
+            slope_rate = module.get_slope_rate()
+            query_decay, key_decay, diagonal_decay = module.decay_factors(slope_rate)
+            init.copy_(module.slope_rate, slope_rate)
+            init.copy_(module.query_decay, query_decay)
+            init.copy_(module.key_decay, key_decay)
+            init.copy_(module.diagonal_decay, diagonal_decay)


 @auto_docstring
```
transformers/models/minimax/modular_minimax.py

```diff
@@ -21,6 +21,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
@@ -520,13 +521,23 @@ class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):


 class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
-    _can_compile_fullgraph = False
+    _can_compile_fullgraph = False  # uses a non-compilable custom cache class MiniMaxCache
     _can_record_outputs = {
         "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0),
         "hidden_states": MiniMaxDecoderLayer,
         "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
     }

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, MiniMaxLightningAttention):
+            slope_rate = module.get_slope_rate()
+            query_decay, key_decay, diagonal_decay = module.decay_factors(slope_rate)
+            init.copy_(module.slope_rate, slope_rate)
+            init.copy_(module.query_decay, query_decay)
+            init.copy_(module.key_decay, key_decay)
+            init.copy_(module.diagonal_decay, diagonal_decay)
+

 class MiniMaxModel(MixtralModel):
     @check_model_inputs
```
transformers/models/ministral/modeling_ministral.py

```diff
@@ -289,7 +289,7 @@ class MinistralRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
```
transformers/models/ministral3/modeling_ministral3.py

```diff
@@ -295,7 +295,7 @@ class Ministral3RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
```
@@ -285,7 +285,7 @@ class MistralRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -252,7 +252,9 @@ class Mistral3Model(Mistral3PreTrainedModel):
 
         image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
         downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
-        split_sizes =
+        split_sizes = (
+            (torch.as_tensor(image_sizes, device=image_features.device) // downsample_ratio).prod(dim=-1).tolist()
+        )
         image_features = torch.split(image_features.squeeze(0), split_sizes)
         return image_features
 
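
A worked example of the `split_sizes` expression above, with made-up image sizes and an illustrative patch size / spatial-merge factor (the real values come from the vision tower and the Mistral3 config):

    import torch

    downsample_ratio = 14 * 2                 # patch_size * spatial_merge_size (illustrative)
    image_sizes = [(448, 336), (224, 224)]    # (height, width) per image

    split_sizes = (
        (torch.as_tensor(image_sizes) // downsample_ratio).prod(dim=-1).tolist()
    )
    print(split_sizes)                        # [192, 64] tokens per image

    # The flat (1, total_tokens, hidden) projector output can then be split per image:
    image_features = torch.randn(1, sum(split_sizes), 16)
    per_image = torch.split(image_features.squeeze(0), split_sizes)
    print([tuple(t.shape) for t in per_image])  # [(192, 16), (64, 16)]
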
@@ -489,6 +491,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -500,12 +503,15 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if
-        #
-        #
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration if available
+            # In subsquent iterations, they are already merged with text and cached
+            # NOTE: first iteration doesn't have to be prefill, it can be the first
+            # iteration with a question and cached system prompt (continue generate from cache)
             model_inputs["pixel_values"] = pixel_values
 
         return model_inputs
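
The gating added above only forwards `pixel_values` on the first generation step (or when no KV cache is used). Stripped of the model specifics, the control flow looks like this sketch:

    def prepare_inputs(model_inputs: dict, pixel_values, is_first_iteration: bool = False, use_cache: bool = True) -> dict:
        # First iteration (or cacheless decoding): image features still need to be computed.
        # Later iterations: they are already merged with the text tokens inside the cache.
        if is_first_iteration or not use_cache:
            model_inputs["pixel_values"] = pixel_values
        return model_inputs


    print(prepare_inputs({}, "PIXELS", is_first_iteration=True))   # {'pixel_values': 'PIXELS'}
    print(prepare_inputs({}, "PIXELS", is_first_iteration=False))  # {}
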
@@ -157,7 +157,9 @@ class Mistral3Model(LlavaModel):
 
         image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
         downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
-        split_sizes =
+        split_sizes = (
+            (torch.as_tensor(image_sizes, device=image_features.device) // downsample_ratio).prod(dim=-1).tolist()
+        )
         image_features = torch.split(image_features.squeeze(0), split_sizes)
         return image_features
 
@@ -37,7 +37,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -50,11 +55,12 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import OutputRecorder, maybe_autocast
 from .configuration_mixtral import MixtralConfig
 
 
+@use_experts_implementation
 class MixtralExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
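
`MixtralExperts` now carries a `@use_experts_implementation` decorator (imported from `...integrations` above). The decorator's internals are not part of this diff; the sketch below only illustrates the general shape of a class decorator that can reroute a module's forward to an alternative experts kernel, under that assumption.

    from torch import nn


    def use_experts_implementation(cls):
        # Illustrative stand-in for the real transformers.integrations decorator, which may
        # instead dispatch at load/config time rather than wrapping forward like this.
        original_forward = cls.forward

        def forward(self, *args, **kwargs):
            impl = getattr(self, "_experts_impl", None)
            if impl is not None:
                return impl(self, *args, **kwargs)      # optional fused/grouped kernel
            return original_forward(self, *args, **kwargs)  # default eager path

        cls.forward = forward
        return cls


    @use_experts_implementation
    class ToyExperts(nn.Module):
        def forward(self, x):
            return x * 2


    print(ToyExperts()(3))  # 6, via the default path
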
@@ -169,7 +175,7 @@ class MixtralRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -403,7 +409,9 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(MixtralTopKRouter, index=0),
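
`_can_compile_fullgraph` is now computed from `is_grouped_mm_available()` instead of being hard-coded, so full-graph `torch.compile` support is only advertised when the grouped-mm kernel used by the experts layer exists. A sketch of the pattern, with a hypothetical capability probe standing in for the real helper in `transformers.utils`:

    import torch


    def grouped_mm_available() -> bool:
        # Hypothetical probe; the real is_grouped_mm_available() may check a torch
        # version or a different operator entirely.
        return hasattr(torch, "_grouped_mm")


    class ToyMoEPreTrainedModel:
        # Evaluated once at class-definition time, like the class attribute above.
        _can_compile_fullgraph = grouped_mm_available()


    print(ToyMoEPreTrainedModel._can_compile_fullgraph)
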
@@ -28,12 +28,13 @@ from torch import nn
 from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
+from ...integrations import use_experts_implementation
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, logging
+from ...utils import TransformersKwargs, is_grouped_mm_available, logging
 from ...utils.generic import OutputRecorder
 from ..mistral.modeling_mistral import (
     MistralAttention,
@@ -134,6 +135,7 @@ def load_balancing_loss_func(
     return overall_loss * num_experts
 
 
+@use_experts_implementation
 class MixtralExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -263,7 +265,9 @@ class MixtralDecoderLayer(GradientCheckpointingLayer):
 
 
 class MixtralPreTrainedModel(MistralPreTrainedModel):
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _can_record_outputs = {
         "router_logits": OutputRecorder(MixtralTopKRouter, index=0),
         "hidden_states": MixtralDecoderLayer,
@@ -55,6 +55,8 @@ class MLCDRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
+        self.dim = dim
+        self.theta = theta
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
@@ -424,6 +426,7 @@ class MLCDPreTrainedModel(PreTrainedModel):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MLCDAttention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
@@ -447,6 +450,9 @@ class MLCDPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, nn.Linear) and module.bias is not None:
             init.zeros_(module.bias)
+        elif isinstance(module, MLCDRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
 
 
 class MLCDVisionTransformer(nn.Module):
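
The MLCD changes store `dim` and `theta` on the rotary module and then recompute `inv_freq` inside `_init_weights`, so the buffer can be rebuilt after the module is materialized (for example from the meta device) rather than only in `__init__`. A self-contained sketch of that round trip, using toy classes rather than the actual MLCD modules:

    import torch
    from torch import nn


    class ToyMLCDRotary(nn.Module):
        def __init__(self, dim: int = 32, theta: float = 10000.0) -> None:
            super().__init__()
            self.dim = dim
            self.theta = theta
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)


    def reinit_inv_freq(module: ToyMLCDRotary) -> None:
        # Same formula as __init__, reconstructed from the stored hyperparameters.
        inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
        with torch.no_grad():
            module.inv_freq.copy_(inv_freq)


    rope = ToyMLCDRotary()
    rope.inv_freq.zero_()      # simulate an uninitialized buffer
    reinit_inv_freq(rope)
    print(torch.allclose(rope.inv_freq, ToyMLCDRotary().inv_freq))  # True
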
@@ -363,6 +363,7 @@ class MLCDPreTrainedModel(PreTrainedModel):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, MLCDAttention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
@@ -386,6 +387,9 @@ class MLCDPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, nn.Linear) and module.bias is not None:
             init.zeros_(module.bias)
+        elif isinstance(module, MLCDRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
 
 
 class MLCDVisionTransformer(CLIPVisionTransformer):
@@ -741,7 +741,7 @@ class MllamaRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -847,6 +847,15 @@ class MllamaPreTrainedModel(PreTrainedModel):
         elif isinstance(module, MllamaPrecomputedAspectRatioEmbedding):
             if module.is_gated:
                 init.zeros_(module.gate)
+        elif isinstance(module, MllamaRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)
 
     # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
     def _update_causal_mask(
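
The Mllama hunk picks the rotary init function from `ROPE_INIT_FUNCTIONS` keyed by `rope_type`, falling back to the module's own default computation. The sketch below uses a stub registry and stub init function to show the dispatch; the real registry in `transformers.modeling_rope_utils` has more entries and different signatures.

    import torch


    def default_rope(config=None):
        # Stub for compute_default_rope_parameters: returns (inv_freq, attention_scaling).
        dim, theta = 8, 10000.0
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        return inv_freq, 1.0


    ROPE_INIT_FUNCTIONS = {"linear": default_rope}  # placeholder registry

    rope_type = "default"
    rope_fn = ROPE_INIT_FUNCTIONS[rope_type] if rope_type != "default" else default_rope
    inv_freq, attention_scaling = rope_fn(config=None)
    print(inv_freq.shape, attention_scaling)        # torch.Size([4]) 1.0
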
@@ -1721,6 +1730,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         use_cache=False,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1738,12 +1748,13 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
             cross_attention_mask=cross_attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
         # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
         # to compute image hidden states, otherwise they are cached within each cross attn layer
-        if
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["aspect_ratio_ids"] = None
             model_inputs["aspect_ratio_mask"] = None
@@ -38,7 +38,7 @@ class MMGroundingDinoConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`PreTrainedConfig
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `SwinConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -280,7 +280,6 @@ class MMGroundingDinoConfig(PreTrainedConfig):
         self.layer_norm_eps = layer_norm_eps
 
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 __all__ = ["MMGroundingDinoConfig"]
@@ -552,7 +552,7 @@ class MMGroundingDinoPreTrainedModel(PreTrainedModel):
         elif isinstance(module, MMGroundingDinoFusionLayer):
             init.constant_(module.vision_param, 1e-4)
             init.constant_(module.text_param, 1e-4)
-        elif isinstance(module, (nn.Linear, nn.Conv2d
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
@@ -1181,7 +1181,7 @@ class MMGroundingDinoEncoder(MMGroundingDinoPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, MMGroundingDinoEncoderOutput]:
         r"""
         Args:
             vision_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -1478,7 +1478,7 @@ class MMGroundingDinoDecoder(MMGroundingDinoPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, MMGroundingDinoDecoderOutput]:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
@@ -1954,7 +1954,7 @@ class MMGroundingDinoModel(MMGroundingDinoPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
         **kwargs,
-    ):
+    ) -> Union[tuple, MMGroundingDinoModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -51,7 +51,7 @@ class MMGroundingDinoConfig(PreTrainedConfig):
     documentation from [`PreTrainedConfig`] for more information.
 
     Args:
-        backbone_config (`PreTrainedConfig
+        backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `SwinConfig()`):
             The configuration of the backbone model.
         backbone (`str`, *optional*):
             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
@@ -293,7 +293,6 @@ class MMGroundingDinoConfig(PreTrainedConfig):
         self.layer_norm_eps = layer_norm_eps
 
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
-        self.tie_encoder_decoder = True
 
 
 class MMGroundingDinoContrastiveEmbedding(GroundingDinoContrastiveEmbedding):
@@ -556,6 +556,8 @@ class MobileBertPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, MobileBertLMPredictionHead):
             init.zeros_(module.bias)
+        elif isinstance(module, MobileBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 @dataclass
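
MobileBert's `position_ids` buffer is now refilled in `_init_weights` from its own shape, the same `torch.arange(...).expand((1, -1))` pattern used for the MLCD position ids earlier in this diff. A minimal sketch with an illustrative toy embeddings module:

    import torch
    from torch import nn


    class ToyEmbeddings(nn.Module):
        def __init__(self, max_position_embeddings: int = 8):
            super().__init__()
            # Placeholder buffer; weight init below fills it with absolute positions.
            self.register_buffer(
                "position_ids", torch.zeros(1, max_position_embeddings, dtype=torch.long), persistent=False
            )


    emb = ToyEmbeddings()
    with torch.no_grad():
        # Same expression as in the hunk above: rebuild the row of positions from the buffer's length.
        emb.position_ids.copy_(torch.arange(emb.position_ids.shape[-1]).expand((1, -1)))
    print(emb.position_ids)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
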
@@ -180,7 +180,6 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
 
         # Stack all processed images if return_tensors is specified
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 