transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +30 -3
- transformers/cli/serve.py +47 -17
- transformers/conversion_mapping.py +15 -2
- transformers/convert_slow_tokenizer.py +225 -10
- transformers/core_model_loading.py +196 -135
- transformers/data/data_collator.py +12 -4
- transformers/dependency_versions_table.py +1 -2
- transformers/dynamic_module_utils.py +1 -2
- transformers/feature_extraction_utils.py +1 -2
- transformers/file_utils.py +0 -1
- transformers/generation/__init__.py +11 -1
- transformers/generation/configuration_utils.py +3 -2
- transformers/generation/continuous_batching/__init__.py +4 -0
- transformers/generation/continuous_batching/continuous_api.py +134 -79
- transformers/image_processing_base.py +1 -2
- transformers/integrations/__init__.py +4 -2
- transformers/integrations/accelerate.py +15 -3
- transformers/integrations/aqlm.py +38 -66
- transformers/integrations/awq.py +48 -514
- transformers/integrations/bitnet.py +45 -100
- transformers/integrations/bitsandbytes.py +79 -191
- transformers/integrations/deepspeed.py +1 -0
- transformers/integrations/eetq.py +84 -79
- transformers/integrations/fbgemm_fp8.py +191 -145
- transformers/integrations/finegrained_fp8.py +236 -193
- transformers/integrations/fp_quant.py +92 -0
- transformers/integrations/ggml.py +11 -1
- transformers/integrations/higgs.py +40 -62
- transformers/integrations/hub_kernels.py +42 -3
- transformers/integrations/integration_utils.py +10 -0
- transformers/integrations/mxfp4.py +25 -65
- transformers/integrations/peft.py +7 -29
- transformers/integrations/quanto.py +73 -55
- transformers/integrations/quark.py +55 -0
- transformers/integrations/spqr.py +44 -90
- transformers/integrations/torchao.py +32 -38
- transformers/integrations/vptq.py +42 -59
- transformers/modelcard.py +1 -2
- transformers/modeling_gguf_pytorch_utils.py +8 -0
- transformers/modeling_rope_utils.py +30 -6
- transformers/modeling_utils.py +116 -112
- transformers/models/__init__.py +3 -0
- transformers/models/afmoe/modeling_afmoe.py +4 -4
- transformers/models/albert/tokenization_albert.py +6 -12
- transformers/models/align/modeling_align.py +2 -0
- transformers/models/altclip/modeling_altclip.py +4 -0
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/modeling_aria.py +4 -4
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
- transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
- transformers/models/auto/configuration_auto.py +11 -0
- transformers/models/auto/feature_extraction_auto.py +2 -0
- transformers/models/auto/image_processing_auto.py +1 -0
- transformers/models/auto/modeling_auto.py +6 -0
- transformers/models/auto/processing_auto.py +18 -10
- transformers/models/auto/tokenization_auto.py +74 -472
- transformers/models/autoformer/modeling_autoformer.py +4 -0
- transformers/models/bamba/modeling_bamba.py +4 -3
- transformers/models/bark/modeling_bark.py +2 -0
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/barthez/tokenization_barthez.py +5 -10
- transformers/models/beit/modeling_beit.py +6 -1
- transformers/models/bert/tokenization_bert.py +8 -21
- transformers/models/big_bird/modeling_big_bird.py +6 -0
- transformers/models/big_bird/tokenization_big_bird.py +18 -42
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
- transformers/models/biogpt/modeling_biogpt.py +2 -0
- transformers/models/biogpt/modular_biogpt.py +2 -0
- transformers/models/bit/modeling_bit.py +11 -2
- transformers/models/bitnet/modeling_bitnet.py +4 -4
- transformers/models/blenderbot/modeling_blenderbot.py +5 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
- transformers/models/blip/modeling_blip_text.py +2 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -1
- transformers/models/bloom/modeling_bloom.py +4 -0
- transformers/models/blt/modeling_blt.py +2 -2
- transformers/models/blt/modular_blt.py +2 -2
- transformers/models/bridgetower/modeling_bridgetower.py +5 -1
- transformers/models/bros/modeling_bros.py +4 -0
- transformers/models/camembert/tokenization_camembert.py +8 -12
- transformers/models/canine/modeling_canine.py +5 -0
- transformers/models/chameleon/modeling_chameleon.py +2 -1
- transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
- transformers/models/clap/modeling_clap.py +5 -0
- transformers/models/clip/tokenization_clip.py +22 -44
- transformers/models/clipseg/modeling_clipseg.py +5 -0
- transformers/models/clvp/modeling_clvp.py +5 -0
- transformers/models/clvp/tokenization_clvp.py +1 -63
- transformers/models/code_llama/tokenization_code_llama.py +20 -43
- transformers/models/codegen/tokenization_codegen.py +14 -43
- transformers/models/cohere/modeling_cohere.py +4 -3
- transformers/models/cohere/modular_cohere.py +2 -1
- transformers/models/cohere/tokenization_cohere.py +12 -42
- transformers/models/cohere2/modeling_cohere2.py +7 -6
- transformers/models/cohere2/modular_cohere2.py +5 -5
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
- transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
- transformers/models/colqwen2/modeling_colqwen2.py +1 -0
- transformers/models/colqwen2/modular_colqwen2.py +1 -0
- transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
- transformers/models/convbert/modeling_convbert.py +6 -0
- transformers/models/convnext/modeling_convnext.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +2 -4
- transformers/models/csm/modeling_csm.py +4 -3
- transformers/models/ctrl/modeling_ctrl.py +1 -0
- transformers/models/cvt/modeling_cvt.py +2 -0
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/modeling_d_fine.py +2 -0
- transformers/models/d_fine/modular_d_fine.py +1 -0
- transformers/models/dab_detr/modeling_dab_detr.py +4 -0
- transformers/models/dac/modeling_dac.py +2 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
- transformers/models/dbrx/modeling_dbrx.py +2 -2
- transformers/models/deberta/modeling_deberta.py +5 -0
- transformers/models/deberta/tokenization_deberta.py +11 -20
- transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
- transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
- transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
- transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
- transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
- transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
- transformers/models/depth_anything/modeling_depth_anything.py +1 -0
- transformers/models/depth_pro/modeling_depth_pro.py +2 -0
- transformers/models/detr/modeling_detr.py +5 -0
- transformers/models/dia/modeling_dia.py +4 -3
- transformers/models/dia/modular_dia.py +0 -1
- transformers/models/diffllama/modeling_diffllama.py +2 -2
- transformers/models/dinat/modeling_dinat.py +3 -0
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
- transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
- transformers/models/distilbert/tokenization_distilbert.py +13 -0
- transformers/models/doge/modeling_doge.py +2 -3
- transformers/models/doge/modular_doge.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +2 -0
- transformers/models/dots1/modeling_dots1.py +10 -7
- transformers/models/dots1/modular_dots1.py +5 -3
- transformers/models/dpr/modeling_dpr.py +5 -0
- transformers/models/dpr/tokenization_dpr.py +12 -0
- transformers/models/edgetam/modeling_edgetam.py +1 -1
- transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
- transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
- transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
- transformers/models/efficientnet/modeling_efficientnet.py +2 -0
- transformers/models/emu3/modeling_emu3.py +4 -4
- transformers/models/eomt/image_processing_eomt.py +13 -1
- transformers/models/eomt/image_processing_eomt_fast.py +14 -2
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
- transformers/models/esm/modeling_esmfold.py +5 -4
- transformers/models/evolla/modeling_evolla.py +4 -4
- transformers/models/exaone4/modeling_exaone4.py +2 -2
- transformers/models/exaone4/modular_exaone4.py +0 -1
- transformers/models/falcon/modeling_falcon.py +6 -1
- transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
- transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
- transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
- transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
- transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
- transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
- transformers/models/flaubert/modeling_flaubert.py +7 -0
- transformers/models/flava/modeling_flava.py +6 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
- transformers/models/florence2/modeling_florence2.py +2 -1
- transformers/models/florence2/modular_florence2.py +2 -1
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/focalnet/modeling_focalnet.py +4 -0
- transformers/models/fsmt/modeling_fsmt.py +2 -0
- transformers/models/funnel/modeling_funnel.py +8 -0
- transformers/models/funnel/tokenization_funnel.py +17 -24
- transformers/models/fuyu/processing_fuyu.py +3 -3
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/tokenization_gemma.py +10 -27
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +2 -1
- transformers/models/gemma3/modeling_gemma3.py +14 -84
- transformers/models/gemma3/modular_gemma3.py +12 -81
- transformers/models/gemma3n/modeling_gemma3n.py +18 -209
- transformers/models/gemma3n/modular_gemma3n.py +17 -59
- transformers/models/git/modeling_git.py +2 -0
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
- transformers/models/glm4v/configuration_glm4v.py +3 -1
- transformers/models/glm4v/modeling_glm4v.py +3 -3
- transformers/models/glm4v/modular_glm4v.py +6 -4
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
- transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
- transformers/models/glpn/modeling_glpn.py +2 -0
- transformers/models/gpt2/modeling_gpt2.py +5 -1
- transformers/models/gpt2/tokenization_gpt2.py +16 -44
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
- transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
- transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
- transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
- transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
- transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
- transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
- transformers/models/gptj/modeling_gptj.py +3 -0
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granitemoe/modeling_granitemoe.py +4 -6
- transformers/models/granitemoe/modular_granitemoe.py +0 -2
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
- transformers/models/groupvit/modeling_groupvit.py +3 -0
- transformers/models/helium/modeling_helium.py +4 -3
- transformers/models/herbert/tokenization_herbert.py +9 -25
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
- transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
- transformers/models/hiera/modeling_hiera.py +4 -0
- transformers/models/hubert/modeling_hubert.py +3 -0
- transformers/models/hubert/modular_hubert.py +1 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
- transformers/models/ibert/modeling_ibert.py +6 -0
- transformers/models/idefics/modeling_idefics.py +5 -21
- transformers/models/imagegpt/modeling_imagegpt.py +2 -1
- transformers/models/informer/modeling_informer.py +4 -0
- transformers/models/informer/modular_informer.py +1 -0
- transformers/models/internvl/modeling_internvl.py +2 -4
- transformers/models/internvl/modular_internvl.py +2 -4
- transformers/models/jamba/modeling_jamba.py +2 -2
- transformers/models/janus/modeling_janus.py +1 -0
- transformers/models/janus/modular_janus.py +1 -0
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/kosmos2/modeling_kosmos2.py +1 -0
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
- transformers/models/lasr/__init__.py +29 -0
- transformers/models/lasr/configuration_lasr.py +244 -0
- transformers/models/lasr/feature_extraction_lasr.py +277 -0
- transformers/models/lasr/modeling_lasr.py +729 -0
- transformers/models/lasr/modular_lasr.py +569 -0
- transformers/models/lasr/processing_lasr.py +96 -0
- transformers/models/lasr/tokenization_lasr.py +186 -0
- transformers/models/layoutlm/modeling_layoutlm.py +5 -0
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
- transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
- transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +3 -0
- transformers/models/lfm2/modeling_lfm2.py +4 -5
- transformers/models/lfm2/modular_lfm2.py +0 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
- transformers/models/lightglue/modeling_lightglue.py +3 -1
- transformers/models/lightglue/modular_lightglue.py +1 -0
- transformers/models/lilt/modeling_lilt.py +4 -0
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama/tokenization_llama.py +15 -43
- transformers/models/llama4/modeling_llama4.py +3 -2
- transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
- transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
- transformers/models/longformer/modeling_longformer.py +6 -0
- transformers/models/longt5/modeling_longt5.py +4 -0
- transformers/models/luke/modeling_luke.py +9 -0
- transformers/models/luke/tokenization_luke.py +11 -38
- transformers/models/lxmert/modeling_lxmert.py +2 -0
- transformers/models/m2m_100/modeling_m2m_100.py +4 -0
- transformers/models/mamba/modeling_mamba.py +14 -22
- transformers/models/marian/modeling_marian.py +5 -0
- transformers/models/markuplm/modeling_markuplm.py +4 -0
- transformers/models/markuplm/tokenization_markuplm.py +28 -61
- transformers/models/mask2former/modeling_mask2former.py +2 -0
- transformers/models/maskformer/modeling_maskformer.py +2 -0
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/mbart/tokenization_mbart.py +11 -52
- transformers/models/mbart50/tokenization_mbart50.py +7 -10
- transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
- transformers/models/mgp_str/modeling_mgp_str.py +2 -0
- transformers/models/mimi/modeling_mimi.py +3 -1
- transformers/models/minimax/modeling_minimax.py +4 -4
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral3/configuration_ministral3.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +4 -3
- transformers/models/mistral/modeling_mistral.py +4 -3
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mllama/modeling_mllama.py +2 -2
- transformers/models/mluke/tokenization_mluke.py +6 -6
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
- transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
- transformers/models/mobilevit/modeling_mobilevit.py +3 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
- transformers/models/modernbert/modeling_modernbert.py +4 -1
- transformers/models/modernbert/modular_modernbert.py +2 -0
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
- transformers/models/moonshine/modeling_moonshine.py +4 -2
- transformers/models/moshi/modeling_moshi.py +5 -2
- transformers/models/mpnet/modeling_mpnet.py +5 -0
- transformers/models/mpnet/tokenization_mpnet.py +5 -13
- transformers/models/mpt/modeling_mpt.py +2 -0
- transformers/models/mra/modeling_mra.py +6 -0
- transformers/models/mt5/modeling_mt5.py +7 -0
- transformers/models/musicgen/modeling_musicgen.py +2 -0
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nemotron/modeling_nemotron.py +4 -2
- transformers/models/nllb/tokenization_nllb.py +8 -22
- transformers/models/nougat/tokenization_nougat.py +11 -59
- transformers/models/nystromformer/modeling_nystromformer.py +6 -0
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +2 -2
- transformers/models/olmo2/modeling_olmo2.py +4 -5
- transformers/models/olmo2/modular_olmo2.py +0 -1
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
- transformers/models/oneformer/modeling_oneformer.py +4 -1
- transformers/models/openai/modeling_openai.py +3 -0
- transformers/models/openai/tokenization_openai.py +10 -46
- transformers/models/opt/modeling_opt.py +2 -0
- transformers/models/owlv2/modeling_owlv2.py +4 -0
- transformers/models/owlvit/modeling_owlvit.py +4 -0
- transformers/models/paddleocr_vl/__init__.py +32 -0
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
- transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
- transformers/models/parakeet/configuration_parakeet.py +4 -6
- transformers/models/parakeet/modeling_parakeet.py +9 -6
- transformers/models/parakeet/modular_parakeet.py +2 -2
- transformers/models/parakeet/processing_parakeet.py +1 -0
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
- transformers/models/patchtst/modeling_patchtst.py +20 -2
- transformers/models/pegasus/modeling_pegasus.py +5 -0
- transformers/models/pegasus/tokenization_pegasus.py +17 -44
- transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
- transformers/models/perceiver/modeling_perceiver.py +8 -0
- transformers/models/persimmon/modeling_persimmon.py +2 -1
- transformers/models/phi/modeling_phi.py +4 -5
- transformers/models/phi/modular_phi.py +0 -1
- transformers/models/phi3/modeling_phi3.py +2 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
- transformers/models/phimoe/modeling_phimoe.py +4 -4
- transformers/models/phimoe/modular_phimoe.py +2 -2
- transformers/models/pix2struct/modeling_pix2struct.py +2 -0
- transformers/models/pixtral/modeling_pixtral.py +2 -1
- transformers/models/plbart/modeling_plbart.py +6 -0
- transformers/models/plbart/modular_plbart.py +2 -0
- transformers/models/plbart/tokenization_plbart.py +0 -2
- transformers/models/poolformer/modeling_poolformer.py +2 -0
- transformers/models/pop2piano/modeling_pop2piano.py +2 -0
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
- transformers/models/prophetnet/modeling_prophetnet.py +3 -0
- transformers/models/pvt/modeling_pvt.py +2 -0
- transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/tokenization_qwen2.py +14 -18
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
- transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
- transformers/models/rag/modeling_rag.py +1 -0
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
- transformers/models/reformer/modeling_reformer.py +4 -0
- transformers/models/reformer/tokenization_reformer.py +11 -28
- transformers/models/regnet/modeling_regnet.py +6 -1
- transformers/models/rembert/modeling_rembert.py +6 -0
- transformers/models/rembert/tokenization_rembert.py +3 -10
- transformers/models/resnet/modeling_resnet.py +11 -2
- transformers/models/roberta/tokenization_roberta.py +18 -27
- transformers/models/roformer/modeling_roformer.py +6 -0
- transformers/models/roformer/tokenization_roformer.py +77 -412
- transformers/models/rt_detr/modeling_rt_detr.py +2 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
- transformers/models/rwkv/modeling_rwkv.py +1 -0
- transformers/models/sam2/modeling_sam2.py +2 -2
- transformers/models/sam2/modular_sam2.py +2 -2
- transformers/models/sam2_video/modeling_sam2_video.py +1 -0
- transformers/models/sam2_video/modular_sam2_video.py +1 -0
- transformers/models/sam3/modeling_sam3.py +77 -80
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
- transformers/models/sam3_video/modeling_sam3_video.py +1 -0
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
- transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
- transformers/models/seed_oss/modeling_seed_oss.py +2 -2
- transformers/models/segformer/modeling_segformer.py +4 -1
- transformers/models/seggpt/modeling_seggpt.py +2 -0
- transformers/models/sew/modeling_sew.py +3 -0
- transformers/models/sew/modular_sew.py +1 -0
- transformers/models/sew_d/modeling_sew_d.py +3 -0
- transformers/models/siglip2/modeling_siglip2.py +4 -0
- transformers/models/siglip2/modular_siglip2.py +4 -0
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smolvlm/processing_smolvlm.py +0 -7
- transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
- transformers/models/speecht5/modeling_speecht5.py +13 -1
- transformers/models/splinter/modeling_splinter.py +3 -0
- transformers/models/splinter/tokenization_splinter.py +9 -28
- transformers/models/squeezebert/modeling_squeezebert.py +6 -0
- transformers/models/stablelm/modeling_stablelm.py +3 -1
- transformers/models/starcoder2/modeling_starcoder2.py +4 -3
- transformers/models/superglue/modeling_superglue.py +1 -0
- transformers/models/superpoint/modeling_superpoint.py +1 -0
- transformers/models/swiftformer/modeling_swiftformer.py +2 -0
- transformers/models/swin/modeling_swin.py +4 -0
- transformers/models/swin2sr/modeling_swin2sr.py +2 -0
- transformers/models/swinv2/modeling_swinv2.py +4 -0
- transformers/models/t5/modeling_t5.py +7 -0
- transformers/models/t5/tokenization_t5.py +4 -8
- transformers/models/t5gemma/modeling_t5gemma.py +5 -5
- transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
- transformers/models/table_transformer/modeling_table_transformer.py +4 -0
- transformers/models/tapas/modeling_tapas.py +3 -0
- transformers/models/textnet/modeling_textnet.py +11 -2
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
- transformers/models/timesfm/modeling_timesfm.py +2 -0
- transformers/models/timesfm/modular_timesfm.py +2 -0
- transformers/models/timesformer/modeling_timesformer.py +2 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
- transformers/models/trocr/modeling_trocr.py +2 -0
- transformers/models/tvp/modeling_tvp.py +2 -0
- transformers/models/udop/modeling_udop.py +4 -0
- transformers/models/udop/tokenization_udop.py +5 -13
- transformers/models/umt5/modeling_umt5.py +7 -0
- transformers/models/unispeech/modeling_unispeech.py +4 -0
- transformers/models/unispeech/modular_unispeech.py +2 -0
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
- transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
- transformers/models/univnet/modeling_univnet.py +1 -0
- transformers/models/upernet/modeling_upernet.py +1 -0
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/vilt/modeling_vilt.py +6 -0
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
- transformers/models/visual_bert/modeling_visual_bert.py +6 -0
- transformers/models/vitdet/modeling_vitdet.py +2 -0
- transformers/models/vitmatte/modeling_vitmatte.py +1 -0
- transformers/models/vits/modeling_vits.py +1 -0
- transformers/models/vjepa2/modeling_vjepa2.py +1 -0
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
- transformers/models/wavlm/modeling_wavlm.py +5 -0
- transformers/models/whisper/modeling_whisper.py +6 -0
- transformers/models/whisper/tokenization_whisper.py +4 -15
- transformers/models/x_clip/modeling_x_clip.py +3 -0
- transformers/models/xglm/modeling_xglm.py +1 -0
- transformers/models/xglm/tokenization_xglm.py +4 -9
- transformers/models/xlm/modeling_xlm.py +5 -0
- transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
- transformers/models/xlnet/tokenization_xlnet.py +3 -7
- transformers/models/yoso/modeling_yoso.py +6 -0
- transformers/models/zamba/modeling_zamba.py +2 -0
- transformers/models/zamba2/modeling_zamba2.py +4 -2
- transformers/models/zamba2/modular_zamba2.py +1 -1
- transformers/models/zoedepth/modeling_zoedepth.py +1 -0
- transformers/pipelines/__init__.py +2 -3
- transformers/pipelines/base.py +1 -9
- transformers/pipelines/document_question_answering.py +3 -1
- transformers/pipelines/text_generation.py +1 -1
- transformers/processing_utils.py +23 -11
- transformers/quantizers/base.py +35 -110
- transformers/quantizers/quantizer_aqlm.py +1 -5
- transformers/quantizers/quantizer_auto_round.py +1 -2
- transformers/quantizers/quantizer_awq.py +17 -81
- transformers/quantizers/quantizer_bitnet.py +3 -8
- transformers/quantizers/quantizer_bnb_4bit.py +13 -110
- transformers/quantizers/quantizer_bnb_8bit.py +16 -92
- transformers/quantizers/quantizer_compressed_tensors.py +1 -5
- transformers/quantizers/quantizer_eetq.py +14 -62
- transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
- transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
- transformers/quantizers/quantizer_fp_quant.py +48 -78
- transformers/quantizers/quantizer_gptq.py +7 -24
- transformers/quantizers/quantizer_higgs.py +40 -54
- transformers/quantizers/quantizer_hqq.py +144 -153
- transformers/quantizers/quantizer_mxfp4.py +13 -167
- transformers/quantizers/quantizer_quanto.py +20 -64
- transformers/quantizers/quantizer_quark.py +36 -17
- transformers/quantizers/quantizer_spqr.py +1 -4
- transformers/quantizers/quantizer_torchao.py +23 -202
- transformers/quantizers/quantizer_vptq.py +8 -22
- transformers/quantizers/quantizers_utils.py +20 -0
- transformers/testing_utils.py +297 -36
- transformers/tokenization_mistral_common.py +4 -0
- transformers/tokenization_utils_base.py +113 -222
- transformers/tokenization_utils_tokenizers.py +168 -107
- transformers/trainer.py +28 -31
- transformers/trainer_jit_checkpoint.py +126 -0
- transformers/trainer_utils.py +1 -1
- transformers/training_args.py +66 -28
- transformers/utils/__init__.py +3 -4
- transformers/utils/auto_docstring.py +1 -0
- transformers/utils/generic.py +27 -1
- transformers/utils/hub.py +5 -15
- transformers/utils/import_utils.py +61 -16
- transformers/utils/kernel_config.py +4 -2
- transformers/utils/loading_report.py +19 -10
- transformers/utils/quantization_config.py +75 -242
- transformers/video_processing_utils.py +1 -2
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
- transformers/kernels/__init__.py +0 -0
- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
- transformers/models/roformer/tokenization_roformer_fast.py +0 -160
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
```diff
--- a/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -43,6 +43,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging
 from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import maybe_autocast
 from ...utils.hub import cached_file
 from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
 from .configuration_qwen2_5_omni import (
@@ -1291,7 +1292,7 @@ class Qwen2_5OmniRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -1958,11 +1959,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
             audio_feature_lengths = None

         if attention_mask is not None and position_ids is None:
-            if (
-                …
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                 delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
@@ -1977,7 +1975,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
                 self.rope_deltas = rope_deltas
             else:
                 batch_size, seq_length = input_ids.shape
-                delta = (cache_position[0] + self.rope_deltas).to(input_ids.device)
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                 position_ids = torch.arange(seq_length, device=input_ids.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 position_ids = position_ids.add(delta)
@@ -2317,6 +2315,7 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]:
         r"""
         thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -2366,11 +2365,8 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if attention_mask is not None and position_ids is None:
-            if (
-                …
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_text_ids,
                     image_grid_thw,
@@ -2390,8 +2386,8 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
                 self.rope_deltas = rope_deltas

             else:
-                batch_size, seq_length = input_ids.shape
-                delta = (cache_position[0] + self.rope_deltas).to(input_ids.device)
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                 position_ids = torch.arange(seq_length, device=input_ids.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 position_ids = position_ids.add(delta)
@@ -2564,7 +2560,7 @@ class Qwen2_5OmniDiTRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -3459,7 +3455,7 @@ class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel):
         decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20
         return self.normalize_spectrogram(decibel_spectrum, 1, -115)

-    def forward(self, mel_spectrogram):
+    def forward(self, mel_spectrogram, **kwargs):
         processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram)
         hidden_representation = self.conv_pre(processed_spectrogram)

@@ -3592,6 +3588,7 @@ class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel):
         drop_audio_conditioning=False,
         drop_code=False,
         apply_cfg=True,
+        **kwargs,
     ):
         batch_size = hidden_states.shape[0]
         if time_step.ndim == 0:
```
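The recurring change in this file (and in the qwen2_vl and qwen3 families below) swaps the direct `torch.autocast(...)` context inside the rotary embeddings for a `maybe_autocast` helper imported from `transformers.utils.generic` (whose own diff is listed above as transformers/utils/generic.py +27 -1). The helper's body is not shown in this diff; a minimal sketch, assuming it simply degrades to a no-op context where autocast is unsupported, might look like:

```python
# Hypothetical sketch only -- the real helper lives in transformers/utils/generic.py
# and may differ. It illustrates why every call site stays a drop-in replacement.
from contextlib import nullcontext

import torch


def maybe_autocast(device_type: str = "cpu", enabled: bool = True, **kwargs):
    """Return torch.autocast when the backend accepts it, else a no-op context."""
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        # torch.autocast raises at construction for unsupported device types,
        # which is also why the surrounding code maps "mps" to "cpu" first.
        return nullcontext()
```

Call sites are unchanged apart from the name: `with maybe_autocast(device_type=device_type, enabled=False):` still forces the float32 matmul regardless of any outer autocast context.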
```diff
--- a/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
+++ b/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
@@ -399,7 +399,7 @@ class Qwen2_5OmniTextConfig(PreTrainedConfig):
         self.rope_parameters = rope_parameters
         super().__init__(
             tie_word_embeddings=tie_word_embeddings,
-            ignore_keys_at_rope_validation={"…
+            ignore_keys_at_rope_validation={"mrope_section"},
             **kwargs,
         )

@@ -747,7 +747,9 @@ class Qwen2_5OmniTalkerConfig(PreTrainedConfig):
         layer_type_validation(self.layer_types, self.num_hidden_layers)

         self.rope_parameters = rope_parameters
-        super().__init__(…
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+        )


 class Qwen2_5OmniDiTConfig(PreTrainedConfig):
```
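Both configuration hunks pass `ignore_keys_at_rope_validation={"mrope_section"}` to `super().__init__`: `mrope_section` is stored inside `rope_parameters` but belongs to multimodal RoPE rather than the generic RoPE schema, so the validation in transformers/modeling_rope_utils.py (+30 -6 above) has to skip it. For illustration, a `rope_parameters` dict of the shape being exempted (values are hypothetical, borrowed from typical Qwen2-VL configs):

```python
# Hypothetical example values; "mrope_section" is the key exempted above.
rope_parameters = {
    "rope_type": "default",
    "rope_theta": 1000000.0,
    # Splits the rotary dimensions across (temporal, height, width) position ids.
    # The generic RoPE validator does not know this key, hence the exemption.
    "mrope_section": [16, 24, 24],
}
```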
```diff
--- a/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -2306,11 +2308,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
             audio_feature_lengths = None

         if attention_mask is not None and position_ids is None:
-            if (
-                …
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                 delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
@@ -2325,7 +2324,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
                 self.rope_deltas = rope_deltas
             else:
                 batch_size, seq_length = input_ids.shape
-                delta = (cache_position[0] + self.rope_deltas).to(input_ids.device)
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                 position_ids = torch.arange(seq_length, device=input_ids.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 position_ids = position_ids.add(delta)
@@ -2518,6 +2517,7 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]:
         r"""
         thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -2567,11 +2567,8 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if attention_mask is not None and position_ids is None:
-            if (
-                …
-                or (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
-            ):
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if past_key_values_length == 0 or self.rope_deltas is None:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_text_ids,
                     image_grid_thw,
@@ -2591,8 +2588,8 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
                 self.rope_deltas = rope_deltas

             else:
-                batch_size, seq_length = input_ids.shape
-                delta = (cache_position[0] + self.rope_deltas).to(input_ids.device)
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
                 position_ids = torch.arange(seq_length, device=input_ids.device)
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 position_ids = position_ids.add(delta)
@@ -3617,7 +3614,7 @@ class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel):
         decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20
         return self.normalize_spectrogram(decibel_spectrum, 1, -115)

-    def forward(self, mel_spectrogram):
+    def forward(self, mel_spectrogram, **kwargs):
         processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram)
         hidden_representation = self.conv_pre(processed_spectrogram)

@@ -3750,6 +3747,7 @@ class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel):
         drop_audio_conditioning=False,
         drop_code=False,
         apply_cfg=True,
+        **kwargs,
     ):
         batch_size = hidden_states.shape[0]
         if time_step.ndim == 0:
```
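The rewrite repeated in the Thinker and Talker above (and in the VL models below) drops prefill detection via `cache_position` in favor of the cache's own length: `past_key_values.get_seq_length() == 0` holds exactly when nothing has been cached yet, i.e. in prefill, which is when the multimodal RoPE indices must be computed once via `get_rope_index`; on later decode steps the stored `rope_deltas` are just shifted by the number of tokens already cached. A small illustration of that invariant (assumes torch and transformers are installed; `DynamicCache` used for demonstration only):

```python
# The cache length is 0 exactly in prefill, and afterwards equals the offset of
# the next token -- the role cache_position[0] used to play in these models.
import torch
from transformers import DynamicCache

cache = DynamicCache()
print(cache.get_seq_length())  # 0 -> prefill: compute position ids via get_rope_index

# Simulate caching a 5-token prompt for layer 0 of a toy model.
k = v = torch.zeros(1, 2, 5, 8)  # (batch, num_heads, seq_len, head_dim)
cache.update(k, v, layer_idx=0)
print(cache.get_seq_length())  # 5 -> decode: position_ids = arange(seq_len) + 5 + rope_deltas
```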
```diff
--- a/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
+++ b/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
@@ -230,7 +230,7 @@ class Qwen2_5_VLTextConfig(PreTrainedConfig):
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
             pad_token_id=pad_token_id,
-            ignore_keys_at_rope_validation={"…
+            ignore_keys_at_rope_validation={"mrope_section"},
             **kwargs,
         )

```
```diff
--- a/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -43,6 +43,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import maybe_autocast
 from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
 from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig

@@ -547,7 +548,7 @@ class Qwen2_5_VLRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -1290,7 +1291,8 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
             inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

         if position_ids is None:
-            …
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if self.rope_deltas is None or past_key_values_length == 0:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
@@ -1303,10 +1305,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
                 batch_size, seq_length, _ = inputs_embeds.shape
                 position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                 position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
-                if cache_position is not None:
-                    delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
-                else:
-                    delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
+                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                 delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
                 position_ids = position_ids + delta.to(position_ids.device)

```
```diff
--- a/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
+++ b/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -595,7 +595,8 @@ class Qwen2_5_VLModel(Qwen2VLModel):
             inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

         if position_ids is None:
-            …
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if self.rope_deltas is None or past_key_values_length == 0:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
@@ -608,10 +609,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
                 batch_size, seq_length, _ = inputs_embeds.shape
                 position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                 position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
-                if cache_position is not None:
-                    delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
-                else:
-                    delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
+                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                 delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
                 position_ids = position_ids + delta.to(position_ids.device)

```
```diff
--- a/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -323,6 +323,7 @@ class Qwen2AudioEncoder(Qwen2AudioPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         Args:
@@ -685,6 +686,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, Qwen2AudioCausalLMOutputWithPast]:
         r"""
         feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
```
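The two qwen2_audio hunks follow the same pattern as the Talker and Token2Wav changes above: forward signatures gain a trailing `**kwargs` so shared generation and dispatch code can pass extra keyword arguments (for example the `TransformersKwargs` plumbed through sibling models) without a `TypeError`. In isolation:

```python
# Why the trailing **kwargs matters: callers that fan one kwargs dict out to
# every submodule no longer need to know each forward's exact signature.
def forward_old(mel_spectrogram):
    return mel_spectrogram

def forward_new(mel_spectrogram, **kwargs):
    return mel_spectrogram  # unknown kwargs are accepted and ignored

common = {"output_attentions": False}
forward_new("mel", **common)  # ok
try:
    forward_old("mel", **common)
except TypeError as e:
    print(e)  # got an unexpected keyword argument 'output_attentions'
```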
```diff
--- a/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -35,7 +35,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import (
     GenericForQuestionAnswering,
@@ -48,7 +48,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_qwen2_moe import Qwen2MoeConfig


@@ -129,7 +129,7 @@ class Qwen2MoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -227,6 +227,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen2MoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -244,7 +245,6 @@ class Qwen2MoeAttention(nn.Module):
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-        self.rotary_fn = apply_rotary_pos_emb
         if self.config.layer_types[layer_idx] == "sliding_attention":
             self.sliding_window = config.sliding_window

```
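The qwen2_moe hunks above (and the matching qwen3, qwen3_moe, and qwen3_next hunks below) trade the per-instance assignment `self.rotary_fn = apply_rotary_pos_emb` for a class decorator `@use_kernelized_func(apply_rotary_pos_emb)` imported from `transformers.integrations` (see transformers/integrations/hub_kernels.py, +42 -3 above), which allows the rotary function to be replaced by a hub-provided kernel. The decorator's real body is not part of this diff; a sketch of just the class-level wiring it subsumes, under that assumption:

```python
# Hypothetical sketch -- the real use_kernelized_func lives in
# transformers/integrations/hub_kernels.py and may additionally swap in a
# downloaded kernel. This only shows why attention code can keep calling
# self.rotary_fn(...) after the instance attribute was removed.
def use_kernelized_func(func):
    def decorator(cls):
        # staticmethod so self.rotary_fn(q, k, cos, sin) does not bind `self`
        cls.rotary_fn = staticmethod(func)
        return cls
    return decorator


@use_kernelized_func(lambda q, k, cos, sin: (q, k))  # stand-in for apply_rotary_pos_emb
class ToyAttention:
    def forward(self, q, k, cos, sin):
        return self.rotary_fn(q, k, cos, sin)
```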
```diff
--- a/transformers/models/qwen2_vl/configuration_qwen2_vl.py
+++ b/transformers/models/qwen2_vl/configuration_qwen2_vl.py
@@ -218,7 +218,7 @@ class Qwen2VLTextConfig(PreTrainedConfig):
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
             pad_token_id=pad_token_id,
-            ignore_keys_at_rope_validation={"…
+            ignore_keys_at_rope_validation={"mrope_section"},
             **kwargs,
         )

```
```diff
--- a/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -42,9 +42,9 @@ from ...utils import (
     TransformersKwargs,
     auto_docstring,
     can_return_tuple,
-    is_torchdynamo_compiling,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from ..qwen2.modeling_qwen2 import (
     Qwen2RMSNorm,
 )
@@ -165,7 +165,7 @@ class Qwen2VLRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -1222,7 +1222,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
             inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

         if position_ids is None:
-            …
+            past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
+            if self.rope_deltas is None or past_key_values_length == 0:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids, image_grid_thw, video_grid_thw, attention_mask
                 )
@@ -1232,10 +1233,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
                 batch_size, seq_length, _ = inputs_embeds.shape
                 position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                 position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
-                if cache_position is not None:
-                    delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
-                else:
-                    delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
+                delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
                 delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                 position_ids = position_ids + delta.to(position_ids.device)

@@ -1443,15 +1441,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         # When compiling, we can't check tensor values thus we check only input length
         # It is safe to assume that `length!=1` means we're in pre-fill because compiled
         # models currently cannot do asssisted decoding
-        prefill_compiled_stage = is_torchdynamo_compiling() and (
-            (input_ids is not None and input_ids.shape[1] != 1)
-            or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
-        )
-        prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
-            (cache_position is not None and cache_position[0] == 0)
-            or (past_key_values is None or past_key_values.get_seq_length() == 0)
-        )
-        if (prefill_compiled_stage or prefill_noncompiled_stage) or self.model.rope_deltas is None:
+        if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
             vision_positions, rope_deltas = self.model.get_rope_index(
                 model_inputs.get("input_ids", None),
                 image_grid_thw=image_grid_thw,
```
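The final qwen2_vl hunk above collapses the nine-line compiled/eager prefill heuristic in `prepare_inputs_for_generation` into a single check: `model_inputs["cache_position"]` is always populated by that point, and its first element is 0 exactly on the prefill step, so both `is_torchdynamo_compiling()` branches become redundant (which is why the `is_torchdynamo_compiling` import is dropped at the top of the file). Schematically:

```python
# The equivalence the simplification relies on: cache_position is
# arange(past_len, past_len + new_tokens), so cache_position[0] == 0
# if and only if nothing is cached yet, i.e. the prefill step.
import torch

def cache_position_for(past_len: int, new_tokens: int) -> torch.Tensor:
    return torch.arange(past_len, past_len + new_tokens)

print(cache_position_for(0, 7)[0].item())  # 0 -> prefill
print(cache_position_for(7, 1)[0].item())  # 7 -> decode
```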
```diff
--- a/transformers/models/qwen3/modeling_qwen3.py
+++ b/transformers/models/qwen3/modeling_qwen3.py
@@ -28,7 +28,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_qwen3 import Qwen3Config


@@ -139,7 +139,7 @@ class Qwen3RotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -221,6 +221,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -247,7 +248,6 @@ class Qwen3Attention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
```
```diff
--- a/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -30,7 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -44,7 +44,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_qwen3_moe import Qwen3MoeConfig


@@ -121,6 +121,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3MoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -146,7 +147,6 @@ class Qwen3MoeAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
         self.sliding_window = getattr(config, "sliding_window", None)
@@ -440,7 +440,7 @@ class Qwen3MoeRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
```
```diff
--- a/transformers/models/qwen3_next/modeling_qwen3_next.py
+++ b/transformers/models/qwen3_next/modeling_qwen3_next.py
@@ -30,6 +30,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -43,7 +44,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from ...utils.import_utils import (
     is_causal_conv1d_available,
     is_flash_linear_attention_available,
@@ -232,7 +233,7 @@ class Qwen3NextRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
@@ -347,6 +348,7 @@ def eager_attention_forward(
     return attn_output, attn_weights


+@use_kernelized_func(apply_rotary_pos_emb)
 class Qwen3NextAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -371,7 +373,6 @@ class Qwen3NextAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
         self.k_norm = Qwen3NextRMSNorm(
             self.head_dim, eps=config.rms_norm_eps
```