transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0

transformers/models/qwen2_moe/modeling_qwen2_moe.py:

```diff
@@ -35,7 +35,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import (
     GenericForQuestionAnswering,
@@ -47,7 +52,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_qwen2_moe import Qwen2MoeConfig
 
@@ -90,7 +95,7 @@ class Qwen2MoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -292,6 +297,7 @@ class Qwen2MoeAttention(nn.Module):
         return attn_output, attn_weights
 
 
+@use_experts_implementation
 class Qwen2MoeExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -432,7 +438,9 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(Qwen2MoeTopKRouter, index=0),
```
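The rotary-embedding hunk above (and the matching hunks for Qwen2-VL, Qwen3, Qwen3-MoE, and Qwen3-Next further down) turns `original_inv_freq` from a plain attribute assignment into a registered, non-persistent buffer cloned from the initial inverse frequencies. A minimal sketch of the pattern, with illustrative names and without the RoPE-scaling machinery of the real classes:

```python
import torch
import torch.nn as nn


class MinimalRotaryEmbedding(nn.Module):
    """Illustrative only: keep a pristine, non-persistent copy of the initial inverse frequencies."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        # Working buffer that dynamic RoPE scaling may overwrite at long sequence lengths.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Pristine copy; as a buffer (rather than a plain attribute) it follows the module
        # through .to(device) and dtype casts.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    def reset_frequencies(self) -> None:
        # Hypothetical helper showing why the copy is kept: restore the defaults in place.
        self.inv_freq.copy_(self.original_inv_freq)
```

Because both buffers are registered with `persistent=False`, neither appears in the state dict, so saved checkpoints are unaffected by the change.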
transformers/models/qwen2_vl/image_processing_qwen2_vl.py:

```diff
@@ -159,8 +159,9 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if size is not None
-
+        if size is not None:
+            if "shortest_edge" not in size or "longest_edge" not in size:
+                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
         else:
             size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
         # backward compatibility: override size with min_pixels and max_pixels if they are provided
```
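With the added check, a partially specified `size` dict now fails fast instead of being accepted silently. A hedged usage sketch (all other constructor arguments left at their defaults):

```python
from transformers import Qwen2VLImageProcessor

# Accepted: both keys present (size=None also works and falls back to the defaults shown above).
ok = Qwen2VLImageProcessor(size={"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280})

# Raises ValueError after this change: "size must contain 'shortest_edge' and 'longest_edge' keys."
broken = Qwen2VLImageProcessor(size={"shortest_edge": 56 * 56})
```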
transformers/models/qwen2_vl/modeling_qwen2_vl.py:

```diff
@@ -28,6 +28,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import LayerNorm
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
@@ -125,7 +126,7 @@ class Qwen2VLRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -246,6 +247,8 @@ class VisionRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
+        self.dim = dim
+        self.theta = theta
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
@@ -384,8 +387,8 @@ class VisionAttention(nn.Module):
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-            if self.config._attn_implementation
-                # Flash Attention
+            if "flash" in self.config._attn_implementation:
+                # Flash Attention: Use cu_seqlens for variable length attention
                 max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
                 attn_output, _ = attention_interface(
                     self,
@@ -665,6 +668,12 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = True
     _supports_attention_backend = True
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, VisionRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
 
 @auto_docstring
 class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
```
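Storing `dim` and `theta` on `VisionRotaryEmbedding` gives the new `_init_weights` branch everything it needs to rebuild `inv_freq` from the module itself. A self-contained sketch of that recomputation; the plain in-place copy stands in for `init.copy_` from `transformers.initialization`:

```python
import torch
import torch.nn as nn


class VisionRotary(nn.Module):
    """Illustrative stand-in for VisionRotaryEmbedding: keeps dim/theta for later re-initialization."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)


def reinit_vision_rotary(module: VisionRotary) -> None:
    # Mirrors the _init_weights branch above: rebuild inv_freq from the stored attributes.
    inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
    with torch.no_grad():
        module.inv_freq.copy_(inv_freq)
```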
transformers/models/qwen2_vl/modeling_qwen2_vl.py (continued):

```diff
@@ -693,6 +702,8 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         )
         self.gradient_checkpointing = False
 
+        self.post_init()
+
     def get_dtype(self) -> torch.dtype:
         return self.blocks[0].mlp.fc2.weight.dtype
 
@@ -1416,6 +1427,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         pixel_values_videos=None,
         image_grid_thw=None,
         video_grid_thw=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1432,6 +1444,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             image_grid_thw=image_grid_thw,
             video_grid_thw=video_grid_thw,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
@@ -1463,7 +1476,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             text_positions = model_inputs["position_ids"][None, ...]
             model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
 
-        if
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None
 
```
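The `is_first_iteration` flag threaded through the generation hunks above makes the handling of pixel inputs explicit: they are forwarded on the first (prefill) step and set to `None` on later cached decode steps. A rough sketch of just that control flow, with the rest of `prepare_inputs_for_generation` simplified away:

```python
def prepare_vision_inputs(model_inputs: dict, is_first_iteration: bool, use_cache: bool) -> dict:
    # Simplified: once the prompt (and its images/videos) has been encoded, cached decode
    # steps do not need the pixel tensors again.
    if not is_first_iteration and use_cache:
        model_inputs["pixel_values"] = None
        model_inputs["pixel_values_videos"] = None
    return model_inputs
```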
transformers/models/qwen3/modeling_qwen3.py:

```diff
@@ -100,7 +100,7 @@ class Qwen3RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
```
transformers/models/qwen3_moe/modeling_qwen3_moe.py:

```diff
@@ -30,7 +30,12 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import
+from ...integrations import (
+    use_experts_implementation,
+    use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernelized_func,
+)
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -43,7 +48,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_qwen3_moe import Qwen3MoeConfig
 
@@ -212,6 +217,7 @@ class Qwen3MoeMLP(nn.Module):
         return down_proj
 
 
+@use_experts_implementation
 class Qwen3MoeExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -365,7 +371,9 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.gate", index=0),
@@ -401,7 +409,7 @@ class Qwen3MoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
```
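These Qwen3-MoE hunks (like the Qwen2-MoE ones earlier and the Qwen3-Next ones below) apply `@use_experts_implementation` to an `Experts` module whose per-expert weights are stored as shared 3D tensors, and gate full-graph `torch.compile` on `is_grouped_mm_available()`. A minimal, loop-based sketch of such a module; the decorated classes can dispatch to a grouped-GEMM kernel instead of the Python loop:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MinimalExperts(nn.Module):
    """Illustrative: every expert's weights live in shared 3D tensors, not a ModuleList of Linears."""

    def __init__(self, num_experts: int, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_up_proj = nn.Parameter(torch.randn(num_experts, hidden_size, 2 * intermediate_size) * 0.02)
        self.down_proj = nn.Parameter(torch.randn(num_experts, intermediate_size, hidden_size) * 0.02)

    def forward(self, hidden_states: torch.Tensor, expert_index: torch.Tensor) -> torch.Tensor:
        # hidden_states: (num_tokens, hidden_size); expert_index: (num_tokens,) of expert ids.
        out = torch.empty_like(hidden_states)
        for i, e in enumerate(expert_index.tolist()):
            gate_up = hidden_states[i] @ self.gate_up_proj[e]
            gate, up = gate_up.chunk(2, dim=-1)
            out[i] = (F.silu(gate) * up) @ self.down_proj[e]
        return out
```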
transformers/models/qwen3_next/modeling_qwen3_next.py:

```diff
@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
-from ...integrations import use_kernelized_func
+from ...integrations import use_experts_implementation, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -45,10 +45,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
-from ...utils.import_utils import
-    is_causal_conv1d_available,
-    is_flash_linear_attention_available,
-)
+from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
 from .configuration_qwen3_next import Qwen3NextConfig
 
 
@@ -192,7 +189,7 @@ class Qwen3NextRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -822,6 +819,7 @@ class Qwen3NextMLP(nn.Module):
         return down_proj
 
 
+@use_experts_implementation
 class Qwen3NextExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
```
transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py:

```diff
@@ -907,6 +907,7 @@ class Qwen3OmniMoeTalkerConfig(PreTrainedConfig):
         self.audio_start_token_id = audio_start_token_id
         self.vision_start_token_id = vision_start_token_id
         self.speaker_id = speaker_id
+        self.initializer_range = self.text_config.initializer_range
         super().__init__(**kwargs)
 
 
@@ -997,6 +998,7 @@ class Qwen3OmniMoeCode2WavConfig(PreTrainedConfig):
         upsampling_ratios=(2, 2),
         decoder_dim=1536,
         attention_dropout=0.0,
+        initializer_range=0.02,
         **kwargs,
     ):
         self.codebook_size = codebook_size
@@ -1016,6 +1018,7 @@ class Qwen3OmniMoeCode2WavConfig(PreTrainedConfig):
         self.upsampling_ratios = upsampling_ratios
         self.decoder_dim = decoder_dim
         self.attention_dropout = attention_dropout
+        self.initializer_range = initializer_range
         self.rope_parameters = rope_parameters
 
         super().__init__(**kwargs)
@@ -1104,6 +1107,7 @@ class Qwen3OmniMoeConfig(PreTrainedConfig):
         self.thinker_config = Qwen3OmniMoeThinkerConfig(**thinker_config)
         self.talker_config = Qwen3OmniMoeTalkerConfig(**talker_config)
         self.code2wav_config = Qwen3OmniMoeCode2WavConfig(**code2wav_config)
+        self.initializer_range = self.thinker_config.initializer_range
         self.enable_audio_output = enable_audio_output
         self.im_start_token_id = im_start_token_id
         self.im_end_token_id = im_end_token_id
```
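The configuration hunks above surface `initializer_range` on both the sub-configs and the composite config, so weight initialization has one value to read regardless of which config object it is handed. A toy illustration of the propagation (not the actual `PreTrainedConfig` machinery):

```python
from typing import Optional


class TextConfig:
    def __init__(self, initializer_range: float = 0.02):
        self.initializer_range = initializer_range


class TalkerConfig:
    def __init__(self, text_config: Optional[TextConfig] = None):
        self.text_config = text_config or TextConfig()
        # Mirror the sub-config value so callers can read it from the top-level config.
        self.initializer_range = self.text_config.initializer_range
```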
@@ -35,7 +35,12 @@ from ... import initialization as init
|
|
|
35
35
|
from ...activations import ACT2FN
|
|
36
36
|
from ...cache_utils import Cache, DynamicCache
|
|
37
37
|
from ...generation import GenerationMixin
|
|
38
|
-
from ...integrations import
|
|
38
|
+
from ...integrations import (
|
|
39
|
+
use_experts_implementation,
|
|
40
|
+
use_kernel_forward_from_hub,
|
|
41
|
+
use_kernel_func_from_hub,
|
|
42
|
+
use_kernelized_func,
|
|
43
|
+
)
|
|
39
44
|
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
|
40
45
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
41
46
|
from ...modeling_layers import GradientCheckpointingLayer
|
|
@@ -49,7 +54,7 @@ from ...modeling_outputs import (
|
|
|
49
54
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
50
55
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
51
56
|
from ...processing_utils import Unpack
|
|
52
|
-
from ...utils import auto_docstring, can_return_tuple
|
|
57
|
+
from ...utils import auto_docstring, can_return_tuple, is_grouped_mm_available
|
|
53
58
|
from ...utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs, maybe_autocast
|
|
54
59
|
from .configuration_qwen3_omni_moe import (
|
|
55
60
|
Qwen3OmniMoeAudioEncoderConfig,
|
|
@@ -64,6 +69,27 @@ from .configuration_qwen3_omni_moe import (
 )


+class SinusoidsPositionEmbedding(nn.Module):
+    def __init__(self, length, channels, max_timescale=10000):
+        super().__init__()
+        self.length = length
+        self.channels = channels
+        self.max_timescale = max_timescale
+        if channels % 2 != 0:
+            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
+        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
+        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+        self.register_buffer(
+            "positional_embedding",
+            torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
+            persistent=False,
+        )
+
+    def forward(self, seqlen: int):
+        return self.positional_embedding[:seqlen, :]
+
+
 @auto_docstring
 class Qwen3OmniMoePreTrainedModel(PreTrainedModel):
     config: Qwen3OmniMoeConfig
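`SinusoidsPositionEmbedding` is moved up so it is defined before `Qwen3OmniMoePreTrainedModel` (whose `_init_weights`, in the next hunk, references it), and it now keeps `length`, `channels`, and `max_timescale` on the module so the table can be rebuilt later. A self-contained sketch of the same sin/cos construction, with hypothetical sizes:

    import numpy as np
    import torch

    length, channels, max_timescale = 8, 16, 10000
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
    table = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
    print(table.shape)  # torch.Size([8, 16]); forward(seqlen) simply slices the first seqlen rows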
@@ -85,6 +111,19 @@ class Qwen3OmniMoePreTrainedModel(PreTrainedModel):
             init.normal_(module.experts.gate_up_proj, mean=0.0, std=std)
             init.normal_(module.experts.down_proj, mean=0.0, std=std)
             init.normal_(module.gate.weight, mean=0.0, std=std)
+        elif isinstance(module, Qwen3OmniMoeCode2Wav):
+            init.copy_(
+                module.code_offset,
+                torch.arange(module.config.num_quantizers).view(1, -1, 1) * module.config.codebook_size,
+            )
+        elif isinstance(module, SinusoidsPositionEmbedding):
+            log_timescale_increment = np.log(module.max_timescale) / (module.channels // 2 - 1)
+            inv_timescales = torch.exp(-log_timescale_increment * torch.arange(module.channels // 2).float())
+            scaled_time = torch.arange(module.length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+            init.copy_(module.positional_embedding, torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1))
+        elif isinstance(module, Qwen3OmniMoeVisionRotaryEmbedding):
+            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
+            init.copy_(module.inv_freq, inv_freq)


 def _get_feat_extract_output_lengths(input_lengths):
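The new `elif` branches above recompute buffers that were registered with `persistent=False` and are therefore absent from checkpoints; when weights are materialized (e.g. from the meta device), `_init_weights` has to rebuild them deterministically. A toy version of that pattern, with a hypothetical module and a plain `copy_` standing in for the library's `init.copy_` helper:

    import torch
    from torch import nn

    class ConstantBuffer(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.dim = dim
            # Non-persistent: excluded from state_dict, so it must be recomputable.
            self.register_buffer("inv_freq", torch.empty(dim // 2), persistent=False)

    def init_weights(module: nn.Module) -> None:
        # Deterministic re-initialization, mirroring the branches above.
        if isinstance(module, ConstantBuffer):
            inv_freq = 1.0 / (10000.0 ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
            with torch.no_grad():
                module.inv_freq.copy_(inv_freq)

    ConstantBuffer(64).apply(init_weights)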
@@ -620,24 +659,6 @@ class Qwen3OmniMoeAudioEncoderLayer(GradientCheckpointingLayer):
         return outputs


-class SinusoidsPositionEmbedding(nn.Module):
-    def __init__(self, length, channels, max_timescale=10000):
-        super().__init__()
-        if channels % 2 != 0:
-            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
-        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
-        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
-        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
-        self.register_buffer(
-            "positional_embedding",
-            torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
-            persistent=False,
-        )
-
-    def forward(self, seqlen: int):
-        return self.positional_embedding[:seqlen, :]
-
-
 @auto_docstring(
     custom_intro="""
     Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
@@ -891,8 +912,8 @@ class Qwen3OmniMoeVisionAttention(nn.Module):
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

-        if self.config._attn_implementation
-            # Flash Attention
+        if "flash" in self.config._attn_implementation:
+            # Flash Attention: Use cu_seqlens for variable length attention
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
             attn_output, _ = attention_interface(
                 self,
@@ -960,6 +981,22 @@ class Qwen3OmniMoeVisionPatchMerger(nn.Module):
         return hidden


+class Qwen3OmniMoeVisionRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs
+
+
 class Qwen3OmniMoeVisionMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
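The relocated vision rotary embedding only returns the raw angle table `freqs`; callers turn it into cos/sin pairs before applying it to queries and keys. A small sketch of that relationship, with hypothetical dimensions rather than the exact call sites in this file:

    import torch

    dim, theta, seqlen = 32, 10000.0, 6
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    freqs = torch.outer(torch.arange(seqlen, dtype=inv_freq.dtype), inv_freq)  # (seqlen, dim // 2)
    # A typical rotary application duplicates the table and takes its cos/sin.
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()
    print(cos.shape, sin.shape)  # torch.Size([6, 32]) torch.Size([6, 32])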
@@ -993,20 +1030,6 @@ class Qwen3OmniMoeVisionPatchEmbed(nn.Module):
         return hidden_states


-class Qwen3OmniMoeVisionRotaryEmbedding(nn.Module):
-    inv_freq: torch.Tensor  # fix linting for `register_buffer`
-
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.outer(seq, self.inv_freq)
-        return freqs
-
-
 class Qwen3OmniMoeVisionBlock(GradientCheckpointingLayer):
     def __init__(self, config, attn_implementation: str = "sdpa") -> None:
         super().__init__()
@@ -1073,6 +1096,8 @@ class Qwen3OmniMoeVisionEncoder(Qwen3OmniMoePreTrainedModel):

         self.gradient_checkpointing = False

+        self.post_init()
+
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         merge_size = self.spatial_merge_size

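`post_init()` is the standard `PreTrainedModel` hook that, among other things, walks the submodule tree and applies `_init_weights`; calling it here (and in the `Qwen3OmniMoeCode2WavDecoderBlock` hunk at the end of this diff) means the re-initialization branches added above also run when these submodules are built on their own. A toy version of the pattern, not the transformers implementation:

    import torch
    from torch import nn

    class Base(nn.Module):
        def _init_weights(self, module):
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)

        def post_init(self):
            # Apply the init hook once construction is finished.
            self.apply(self._init_weights)

    class Encoder(Base):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)
            self.post_init()  # mirrors the call added in the hunk above

    print(Encoder().proj.weight.std())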
@@ -1246,7 +1271,7 @@ class Qwen3OmniMoeThinkerTextRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

         self.mrope_section = config.rope_parameters.get("mrope_section", [24, 20, 20])

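Registering `original_inv_freq` as a buffer rather than assigning a plain tensor attribute means it follows the module through `.to()`/device moves, while `persistent=False` still keeps it out of checkpoints, matching `inv_freq` itself. A toy illustration of those two properties:

    import torch
    from torch import nn

    class Rope(nn.Module):
        def __init__(self):
            super().__init__()
            inv_freq = torch.arange(4, dtype=torch.float)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    rope = Rope()
    print(list(rope.state_dict()))                      # [] -> neither buffer is serialized
    print([name for name, _ in rope.named_buffers()])   # ['inv_freq', 'original_inv_freq']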
@@ -1318,6 +1343,7 @@ class Qwen3OmniMoeThinkerTextRotaryEmbedding(nn.Module):
         return freqs_t


+@use_experts_implementation
 class Qwen3OmniMoeThinkerTextExperts(nn.Module):
     """
     ModuleList of experts.
@@ -1596,7 +1622,9 @@ class Qwen3OmniMoeThinkerTextPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph =
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(Qwen3OmniMoeThinkerTextTopKRouter, layer_name="mlp.gate", index=0),
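`_can_compile_fullgraph` is now derived from `is_grouped_mm_available()`, i.e. full-graph `torch.compile` support for this model is tied to whether grouped matrix multiplication is available in the installed torch build (the experts path referenced by the linked docs). A hedged sketch of how such a class-level flag can gate compilation; the availability probe below is hypothetical, the library's own check is `is_grouped_mm_available()`:

    import torch
    from torch import nn

    class TinyMoE(nn.Module):
        # Hypothetical probe standing in for is_grouped_mm_available().
        _can_compile_fullgraph = hasattr(torch, "_grouped_mm")

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    model = TinyMoE()
    # Compilation is lazy: the graph is only traced on the first real call.
    compiled = torch.compile(model, fullgraph=TinyMoE._can_compile_fullgraph)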
@@ -2248,6 +2276,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         feature_attention_mask=None,
         use_audio_in_video=False,
         video_second_per_grid=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         model_inputs = super().prepare_inputs_for_generation(
@@ -2266,12 +2295,13 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
             feature_attention_mask=feature_attention_mask,
             use_audio_in_video=use_audio_in_video,
             video_second_per_grid=video_second_per_grid,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

         model_inputs["position_ids"] = None

-        if
+        if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None
             model_inputs["input_features"] = None
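The `is_first_iteration` flag threaded through `prepare_inputs_for_generation` above makes the gating explicit: multimodal features are only fed on the prefill step, while cached decode steps null them out and rely on the KV cache. A toy sketch of the same idea, with hypothetical field names rather than the real generation loop:

    def gate_multimodal_inputs(model_inputs, is_first_iteration, use_cache=True):
        # After prefill the encoder features are already reflected in the KV cache,
        # so later steps only need the newest token.
        if not is_first_iteration and use_cache:
            for key in ("pixel_values", "pixel_values_videos", "input_features"):
                model_inputs[key] = None
        return model_inputs

    prefill = gate_multimodal_inputs({"pixel_values": "features"}, is_first_iteration=True)
    decode = gate_multimodal_inputs({"pixel_values": "features"}, is_first_iteration=False)
    print(prefill["pixel_values"], decode["pixel_values"])  # features None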
@@ -2477,7 +2507,7 @@ class Qwen3OmniMoeRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq =
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -2745,6 +2775,7 @@ class Qwen3OmniMoeTalkerTextMLP(nn.Module):
         return down_proj


+@use_experts_implementation
 class Qwen3OmniMoeTalkerTextExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""

@@ -3020,9 +3051,9 @@ class Qwen3OmniMoeTalkerModel(Qwen3OmniMoePreTrainedModel):

 @auto_docstring
 class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = {"
-    _tp_plan = {"
-    _pp_plan = {"
+    _tied_weights_keys = {"codec_head": "model.codec_embedding.weight"}
+    _tp_plan = {"codec_head": "colwise_rep"}
+    _pp_plan = {"codec_head": (["hidden_states"], ["logits"])}
     config_class = Qwen3OmniMoeTalkerConfig
     base_model_prefix = "talker"
     _no_split_modules = ["Qwen3OmniMoeTalkerCodePredictorModelForConditionalGeneration"]
@@ -3213,18 +3244,31 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrain
         return model_kwargs

     def prepare_inputs_for_generation(
-        self,
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        is_first_iteration=False,
+        **kwargs,
     ):
         hidden_states = kwargs.pop("hidden_states", None)
         inputs = super().prepare_inputs_for_generation(
-            input_ids,
+            input_ids,
+            past_key_values,
+            attention_mask,
+            inputs_embeds,
+            cache_position,
+            is_first_iteration=is_first_iteration,
+            **kwargs,
         )

         # Qwen3-Omni will prepare position ids in forward with deltas
         inputs["position_ids"] = None

         # TODO(raushan, gante): Refactor this part to a utility function
-        if
+        if not is_first_iteration and kwargs.get("use_cache", True):
             input_ids = input_ids[:, -1:]
             generation_step = kwargs.get("generation_step")
             trailing_text_hidden = kwargs.get("trailing_text_hidden")
@@ -3716,6 +3760,8 @@ class Qwen3OmniMoeCode2WavDecoderBlock(Qwen3OmniMoePreTrainedModel):

         self.block = nn.ModuleList(block)

+        self.post_init()
+
     def forward(self, hidden, **kwargs):
         for block in self.block:
             hidden = block(hidden)