transformers-5.0.0rc0-py3-none-any.whl → transformers-5.0.0rc1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -35,7 +35,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer

@@ -50,7 +50,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
+from ...utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs, maybe_autocast
 from .configuration_qwen3_omni_moe import (
     Qwen3OmniMoeAudioEncoderConfig,
     Qwen3OmniMoeCode2WavConfig,

@@ -716,6 +716,7 @@ class Qwen3OmniMoeAudioEncoder(Qwen3OmniMoePreTrainedModel):
         input_features,
         feature_lens=None,
         aftercnn_lens=None,
+        **kwargs,
     ):
         r"""
         feature_lens (`torch.LongTensor` of shape `(batch_size,)`):

@@ -1290,7 +1291,7 @@ class Qwen3OmniMoeThinkerTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
             emb = torch.cat((freqs, freqs), dim=-1)

@@ -1442,6 +1443,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3OmniMoeThinkerTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -1467,7 +1469,6 @@ class Qwen3OmniMoeThinkerTextAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3OmniMoeThinkerTextRMSNorm(
             self.head_dim, eps=config.rms_norm_eps
         )  # unlike olmo, only on the head dim!

@@ -2165,11 +2166,8 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         audio_feature_lengths = None

         if attention_mask is not None and position_ids is None:
-            if (
-
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                 delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,

@@ -2184,7 +2182,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
                 self.rope_deltas = rope_deltas
             else:
                 batch_size, seq_length = input_ids.shape
-                delta =
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                 position_ids = torch.arange(seq_length, device=input_ids.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 position_ids = position_ids.add(delta)

@@ -2323,6 +2321,7 @@ class Qwen3OmniMoeRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3OmniMoeTalkerCodePredictorAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -2349,7 +2348,6 @@ class Qwen3OmniMoeTalkerCodePredictorAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3OmniMoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Qwen3OmniMoeRMSNorm(
             self.head_dim, eps=config.rms_norm_eps

@@ -2518,7 +2516,7 @@ class Qwen3OmniMoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling

@@ -3103,12 +3101,9 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrain
         if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
             generation_step = -1
         residual_codes = None
-        if
-            if (
-
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+        if position_ids is None:
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                 delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                 position_ids, rope_deltas = self.get_rope_index(
                     talker_input_ids,

@@ -3123,7 +3118,7 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrain
                 self.rope_deltas = rope_deltas
             else:
                 batch_size, seq_length = input_ids.shape
-                delta =
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                 position_ids = torch.arange(seq_length, device=input_ids.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 position_ids = position_ids.add(delta)

@@ -3224,7 +3219,10 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrain
         inputs = super().prepare_inputs_for_generation(
             input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
         )
-
+
+        # Qwen3-Omni will prepare position ids in forward with deltas
+        inputs["position_ids"] = None
+
         # TODO(raushan, gante): Refactor this part to a utility function
         if cache_position[0] != 0:
             input_ids = input_ids[:, -1:]

@@ -3352,6 +3350,7 @@ class Qwen3OmniMoeConvNeXtBlock(nn.Module):
         return hidden_states


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3OmniMoeCode2WavAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -3378,7 +3377,6 @@ class Qwen3OmniMoeCode2WavAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = nn.Identity()
         self.k_norm = nn.Identity()
         self.sliding_window = config.sliding_window

@@ -3718,7 +3716,7 @@ class Qwen3OmniMoeCode2WavDecoderBlock(Qwen3OmniMoePreTrainedModel):

         self.block = nn.ModuleList(block)

-    def forward(self, hidden):
+    def forward(self, hidden, **kwargs):
         for block in self.block:
             hidden = block(hidden)
         return hidden

@@ -3760,7 +3758,7 @@ class Qwen3OmniMoeCode2Wav(Qwen3OmniMoePreTrainedModel):

         self.post_init()

-    def forward(self, codes):
+    def forward(self, codes, **kwargs):
         if codes.shape[1] != self.config.num_quantizers:
             raise ValueError(f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}")
         hidden = self.code_embedding(codes + self.code_offset).mean(1)
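The hunks above repeatedly swap the autocast context in the rotary-embedding forward for a `maybe_autocast` helper imported from `...utils.generic`, called with `device_type=device_type, enabled=False` to keep the RoPE frequency math in float32. Below is a minimal sketch of what such a wrapper could look like, assuming it simply falls back to a no-op context when `torch.autocast` rejects the device type; the actual transformers implementation may differ.

```python
# Hypothetical sketch only; not the transformers implementation of maybe_autocast.
from contextlib import nullcontext
from typing import Optional

import torch


def maybe_autocast(device_type: str, enabled: bool = True, dtype: Optional[torch.dtype] = None):
    """Return a torch.autocast context when the device type supports it, else a no-op."""
    try:
        return torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled)
    except RuntimeError:
        # Unsupported device types raise at construction time; degrade to a null context.
        return nullcontext()


# Usage mirroring the hunks above: disable autocast so the frequency math stays in float32.
with maybe_autocast(device_type="cpu", enabled=False):
    freqs = torch.randn(2, 4).float() @ torch.randn(4, 3).float()
```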
@@ -1205,6 +1205,7 @@ class Qwen3OmniMoeAudioEncoder(Qwen2_5OmniAudioEncoder):
         input_features,
         feature_lens=None,
         aftercnn_lens=None,
+        **kwargs,
     ):
         aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
         chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()

@@ -1521,11 +1522,8 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen2_5OmniThinkerForCondition
         audio_feature_lengths = None

         if attention_mask is not None and position_ids is None:
-            if (
-
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                 delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,

@@ -1540,7 +1538,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen2_5OmniThinkerForCondition
                 self.rope_deltas = rope_deltas
             else:
                 batch_size, seq_length = input_ids.shape
-                delta =
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                 position_ids = torch.arange(seq_length, device=input_ids.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 position_ids = position_ids.add(delta)

@@ -1961,12 +1959,9 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3MoeForCausalLM):
         if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
             generation_step = -1
         residual_codes = None
-        if
-            if (
-
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+        if position_ids is None:
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                 delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                 position_ids, rope_deltas = self.get_rope_index(
                     talker_input_ids,

@@ -1981,7 +1976,7 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3MoeForCausalLM):
                 self.rope_deltas = rope_deltas
             else:
                 batch_size, seq_length = input_ids.shape
-                delta =
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                 position_ids = torch.arange(seq_length, device=input_ids.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 position_ids = position_ids.add(delta)

@@ -2044,7 +2039,10 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3MoeForCausalLM):
         inputs = super().prepare_inputs_for_generation(
             input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
         )
-
+
+        # Qwen3-Omni will prepare position ids in forward with deltas
+        inputs["position_ids"] = None
+
         # TODO(raushan, gante): Refactor this part to a utility function
         if cache_position[0] != 0:
             input_ids = input_ids[:, -1:]

@@ -2339,7 +2337,7 @@ class Qwen3OmniMoeCode2WavDecoderBlock(Qwen3OmniMoePreTrainedModel):

         self.block = nn.ModuleList(block)

-    def forward(self, hidden):
+    def forward(self, hidden, **kwargs):
         for block in self.block:
             hidden = block(hidden)
         return hidden

@@ -2381,7 +2379,7 @@ class Qwen3OmniMoeCode2Wav(Qwen3OmniMoePreTrainedModel):

         self.post_init()

-    def forward(self, codes):
+    def forward(self, codes, **kwargs):
         if codes.shape[1] != self.config.num_quantizers:
             raise ValueError(f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}")
         hidden = self.code_embedding(codes + self.code_offset).mean(1)
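In both the modeling and modular Qwen3-Omni files, the prefill check for multimodal RoPE is simplified: the code now reads `past_key_values.get_seq_length()` once and recomputes rope deltas only when that length is 0 (prefill) or no deltas are cached; later decode steps offset the cached deltas by the past length. The sketch below restates that position-id logic with illustrative names; it is not the model's actual method.

```python
# Standalone restatement of the simplified rope-delta logic; names are illustrative.
from typing import Optional

import torch


def build_position_ids(
    input_ids: torch.LongTensor,
    rope_deltas: Optional[torch.Tensor],
    past_key_values_length: int,
) -> torch.Tensor:
    batch_size, seq_length = input_ids.shape
    if past_key_values_length == 0 or rope_deltas is None:
        # Prefill: positions start at zero; this is where get_rope_index would recompute deltas.
        return torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
    # Decode: shift by what is already cached plus the stored per-sample delta.
    delta = past_key_values_length + rope_deltas.to(input_ids.device)
    position_ids = torch.arange(seq_length, device=input_ids.device)
    return position_ids.view(1, -1).expand(batch_size, -1) + delta


# A cached delta of 3 with 7 tokens already in the KV cache places the new token at position 10.
print(build_position_ids(torch.tensor([[42]]), torch.tensor([[3]]), past_key_values_length=7))
```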
@@ -30,7 +30,7 @@ import torch.nn.functional as F
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer

@@ -38,8 +38,8 @@ from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring,
-from ...utils.generic import check_model_inputs
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig


@@ -337,7 +337,7 @@ class Qwen3VLTextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
             emb = torch.cat((freqs, freqs), dim=-1)

@@ -413,6 +413,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3VLTextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -439,7 +440,6 @@ class Qwen3VLTextAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Qwen3VLTextRMSNorm(
             self.head_dim, eps=config.rms_norm_eps

@@ -1201,44 +1201,19 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
             deepstack_visual_embeds = deepstack_video_embeds

         if position_ids is None:
-
-
-            )
-            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
-                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
-                # Only apply conversion for floating point tensors (inverted masks)
-                if attention_mask_tensor.dtype.is_floating_point:
-                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
-                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()
-
-            # Calculate RoPE index once per generation in the pre-fill stage only.
-            # When compiling, we can't check tensor values thus we check only input length
-            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
-            # models currently cannot do asssisted decoding
-            prefill_compiled_stage = is_torchdynamo_compiling() and (
-                (input_ids is not None and input_ids.shape[1] != 1)
-                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
-            )
-            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
-                (cache_position is not None and cache_position[0] == 0)
-                or (past_key_values is None or past_key_values.get_seq_length() == 0)
-            )
-            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if self.rope_deltas is None or past_key_values_length == 0:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
                     video_grid_thw,
-                    attention_mask=
+                    attention_mask=attention_mask,
                 )
                 self.rope_deltas = rope_deltas
             # then use the prev pre-calculated rope-deltas to get the correct position ids
             else:
                 batch_size, seq_length, _ = inputs_embeds.shape
-                delta = (
-                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
-                    if cache_position is not None
-                    else 0
-                )
+                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                 position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 if cache_position is not None:  # otherwise `deltas` is an int `0`

@@ -1322,7 +1297,7 @@ class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
     def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
         return self.model.get_image_features(pixel_values, image_grid_thw)

-    @
+    @can_return_tuple
     def forward(
         self,
         input_ids: torch.LongTensor = None,

@@ -1414,6 +1389,8 @@ class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )

@@ -1449,8 +1426,33 @@ class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
             **kwargs,
         )

-        # Qwen3VL position_ids are
-
+        # Qwen3VL position_ids are prepared with rope_deltas
+        if position_ids is None:
+            # Calculate RoPE index once per generation in the pre-fill stage only.
+            # When compiling, we can't check tensor values thus we check only input length
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do asssisted decoding
+            if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
+                vision_positions, rope_deltas = self.model.get_rope_index(
+                    model_inputs.get("input_ids", None),
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    attention_mask=attention_mask,
+                )
+                self.model.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            elif "position_ids" in model_inputs:
+                batch_size, seq_length = model_inputs["position_ids"].shape
+                device = model_inputs["position_ids"].device
+                position_ids = torch.arange(seq_length, device=device)
+                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                delta = cache_position[0] + self.model.rope_deltas
+                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                vision_positions = position_ids + delta.expand_as(position_ids)
+
+            # Concatenate "text + vision" positions into [4, bs, seq-len]
+            text_positions = model_inputs["position_ids"][None, ...]
+            model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)

         if cache_position[0] != 0:
             model_inputs["pixel_values"] = None
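Another recurring change in these hunks: the attention classes stop assigning `self.rotary_fn = apply_rotary_pos_emb` in `__init__` and are instead wrapped with `@use_kernelized_func(apply_rotary_pos_emb)` from `...integrations`. The following is only a guess at the mechanism, registering the default function on the class so a hub kernel with the same signature could be substituted later; it is not the transformers implementation.

```python
# Speculative sketch of a use_kernelized_func-style class decorator; purely illustrative.
def use_kernelized_func(default_func):
    def decorator(cls):
        # A faster hub kernel with the same signature could be swapped in here at load time.
        cls.rotary_fn = staticmethod(default_func)
        return cls

    return decorator


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    return q, k  # placeholder body for the sketch


@use_kernelized_func(apply_rotary_pos_emb)
class ToyAttention:
    def forward(self, q, k, cos, sin):
        return self.rotary_fn(q, k, cos, sin)


assert ToyAttention.rotary_fn is apply_rotary_pos_emb
```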
@@ -34,8 +34,8 @@ from ...modeling_rope_utils import RopeParameters, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import ProcessingKwargs, Unpack
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import auto_docstring,
-from ...utils.generic import check_model_inputs
+from ...utils import auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ...video_utils import VideoInput
 from ..llama.modeling_llama import LlamaRotaryEmbedding
 from ..qwen2_5_vl.modeling_qwen2_5_vl import (

@@ -389,7 +389,7 @@ class Qwen3VLTextRotaryEmbedding(LlamaRotaryEmbedding):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
             emb = torch.cat((freqs, freqs), dim=-1)

@@ -1033,44 +1033,19 @@ class Qwen3VLModel(Qwen2_5_VLModel):
             deepstack_visual_embeds = deepstack_video_embeds

         if position_ids is None:
-
-
-            )
-            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
-                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
-                # Only apply conversion for floating point tensors (inverted masks)
-                if attention_mask_tensor.dtype.is_floating_point:
-                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
-                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()
-
-            # Calculate RoPE index once per generation in the pre-fill stage only.
-            # When compiling, we can't check tensor values thus we check only input length
-            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
-            # models currently cannot do asssisted decoding
-            prefill_compiled_stage = is_torchdynamo_compiling() and (
-                (input_ids is not None and input_ids.shape[1] != 1)
-                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
-            )
-            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
-                (cache_position is not None and cache_position[0] == 0)
-                or (past_key_values is None or past_key_values.get_seq_length() == 0)
-            )
-            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if self.rope_deltas is None or past_key_values_length == 0:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
                     video_grid_thw,
-                    attention_mask=
+                    attention_mask=attention_mask,
                 )
                 self.rope_deltas = rope_deltas
             # then use the prev pre-calculated rope-deltas to get the correct position ids
             else:
                 batch_size, seq_length, _ = inputs_embeds.shape
-                delta = (
-                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
-                    if cache_position is not None
-                    else 0
-                )
+                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                 position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 if cache_position is not None:  # otherwise `deltas` is an int `0`

@@ -1105,7 +1080,7 @@ class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
     config: Qwen3VLConfig
     _checkpoint_conversion_mapping = {}

-    @
+    @can_return_tuple
     def forward(
         self,
         input_ids: torch.LongTensor = None,

@@ -1197,6 +1172,8 @@ class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
             rope_deltas=outputs.rope_deltas,
         )

@@ -1232,8 +1209,33 @@ class Qwen3VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
             **kwargs,
         )

-        # Qwen3VL position_ids are
-
+        # Qwen3VL position_ids are prepared with rope_deltas
+        if position_ids is None:
+            # Calculate RoPE index once per generation in the pre-fill stage only.
+            # When compiling, we can't check tensor values thus we check only input length
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do asssisted decoding
+            if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
+                vision_positions, rope_deltas = self.model.get_rope_index(
+                    model_inputs.get("input_ids", None),
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    attention_mask=attention_mask,
+                )
+                self.model.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            elif "position_ids" in model_inputs:
+                batch_size, seq_length = model_inputs["position_ids"].shape
+                device = model_inputs["position_ids"].device
+                position_ids = torch.arange(seq_length, device=device)
+                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
+                delta = cache_position[0] + self.model.rope_deltas
+                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                vision_positions = position_ids + delta.expand_as(position_ids)
+
+            # Concatenate "text + vision" positions into [4, bs, seq-len]
+            text_positions = model_inputs["position_ids"][None, ...]
+            model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)

         if cache_position[0] != 0:
             model_inputs["pixel_values"] = None