transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers/__init__.py +20 -1
- transformers/activations.py +1 -1
- transformers/audio_utils.py +0 -1
- transformers/cache_utils.py +17 -15
- transformers/configuration_utils.py +114 -70
- transformers/conversion_mapping.py +68 -5
- transformers/core_model_loading.py +201 -35
- transformers/dependency_versions_table.py +1 -1
- transformers/feature_extraction_utils.py +54 -22
- transformers/generation/candidate_generator.py +79 -31
- transformers/generation/configuration_utils.py +162 -122
- transformers/generation/continuous_batching/cache.py +47 -18
- transformers/generation/continuous_batching/cache_manager.py +131 -34
- transformers/generation/continuous_batching/continuous_api.py +101 -64
- transformers/generation/continuous_batching/requests.py +28 -1
- transformers/generation/continuous_batching/scheduler.py +11 -4
- transformers/generation/stopping_criteria.py +1 -1
- transformers/generation/utils.py +108 -110
- transformers/generation/watermarking.py +8 -5
- transformers/image_processing_base.py +2 -12
- transformers/image_processing_utils_fast.py +15 -4
- transformers/initialization.py +37 -0
- transformers/integrations/__init__.py +12 -0
- transformers/integrations/accelerate.py +44 -111
- transformers/integrations/aqlm.py +3 -5
- transformers/integrations/awq.py +2 -5
- transformers/integrations/bitnet.py +5 -8
- transformers/integrations/bitsandbytes.py +16 -15
- transformers/integrations/deepspeed.py +18 -3
- transformers/integrations/eetq.py +3 -5
- transformers/integrations/fbgemm_fp8.py +1 -1
- transformers/integrations/finegrained_fp8.py +6 -16
- transformers/integrations/flash_attention.py +2 -2
- transformers/integrations/higgs.py +2 -5
- transformers/integrations/hub_kernels.py +23 -5
- transformers/integrations/integration_utils.py +35 -0
- transformers/integrations/mistral.py +12 -0
- transformers/integrations/moe.py +240 -0
- transformers/integrations/mxfp4.py +4 -10
- transformers/integrations/peft.py +5 -0
- transformers/integrations/quanto.py +5 -2
- transformers/integrations/spqr.py +3 -5
- transformers/integrations/tensor_parallel.py +167 -221
- transformers/integrations/vptq.py +3 -5
- transformers/modeling_gguf_pytorch_utils.py +66 -19
- transformers/modeling_rope_utils.py +78 -81
- transformers/modeling_utils.py +583 -503
- transformers/models/__init__.py +19 -0
- transformers/models/afmoe/modeling_afmoe.py +7 -16
- transformers/models/afmoe/modular_afmoe.py +5 -13
- transformers/models/aimv2/modeling_aimv2.py +4 -0
- transformers/models/aimv2/modular_aimv2.py +4 -0
- transformers/models/albert/modeling_albert.py +3 -0
- transformers/models/align/modeling_align.py +12 -6
- transformers/models/altclip/modeling_altclip.py +7 -3
- transformers/models/apertus/modeling_apertus.py +4 -2
- transformers/models/apertus/modular_apertus.py +4 -1
- transformers/models/arcee/modeling_arcee.py +1 -1
- transformers/models/aria/modeling_aria.py +8 -4
- transformers/models/aria/modular_aria.py +7 -3
- transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
- transformers/models/auto/auto_factory.py +1 -1
- transformers/models/auto/configuration_auto.py +27 -0
- transformers/models/auto/feature_extraction_auto.py +7 -3
- transformers/models/auto/image_processing_auto.py +4 -2
- transformers/models/auto/modeling_auto.py +31 -0
- transformers/models/auto/processing_auto.py +4 -0
- transformers/models/auto/tokenization_auto.py +132 -153
- transformers/models/auto/video_processing_auto.py +5 -2
- transformers/models/aya_vision/modeling_aya_vision.py +7 -3
- transformers/models/bamba/modeling_bamba.py +18 -19
- transformers/models/bamba/modular_bamba.py +17 -16
- transformers/models/bark/modeling_bark.py +9 -0
- transformers/models/bart/configuration_bart.py +0 -1
- transformers/models/bart/modeling_bart.py +7 -0
- transformers/models/beit/image_processing_beit_fast.py +0 -1
- transformers/models/bert/modeling_bert.py +3 -0
- transformers/models/bert_generation/modeling_bert_generation.py +2 -0
- transformers/models/big_bird/modeling_big_bird.py +3 -0
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
- transformers/models/bit/modeling_bit.py +5 -1
- transformers/models/bitnet/modeling_bitnet.py +1 -1
- transformers/models/blenderbot/modeling_blenderbot.py +7 -0
- transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
- transformers/models/blip/modeling_blip.py +2 -0
- transformers/models/blip/modeling_blip_text.py +8 -0
- transformers/models/blip_2/modeling_blip_2.py +2 -0
- transformers/models/bloom/modeling_bloom.py +13 -44
- transformers/models/blt/modeling_blt.py +162 -2
- transformers/models/blt/modular_blt.py +168 -3
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
- transformers/models/bridgetower/modeling_bridgetower.py +6 -0
- transformers/models/bros/modeling_bros.py +8 -0
- transformers/models/camembert/modeling_camembert.py +109 -106
- transformers/models/canine/modeling_canine.py +6 -0
- transformers/models/canine/tokenization_canine.py +2 -0
- transformers/models/chameleon/modeling_chameleon.py +9 -4
- transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
- transformers/models/clap/feature_extraction_clap.py +2 -2
- transformers/models/clap/modeling_clap.py +25 -15
- transformers/models/clip/modeling_clip.py +2 -0
- transformers/models/clipseg/modeling_clipseg.py +4 -0
- transformers/models/clvp/modeling_clvp.py +14 -3
- transformers/models/code_llama/tokenization_code_llama.py +1 -1
- transformers/models/codegen/modeling_codegen.py +13 -4
- transformers/models/cohere/modeling_cohere.py +1 -1
- transformers/models/cohere2/modeling_cohere2.py +1 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
- transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
- transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
- transformers/models/convbert/modeling_convbert.py +3 -0
- transformers/models/convnext/image_processing_convnext.py +2 -2
- transformers/models/convnext/image_processing_convnext_fast.py +9 -13
- transformers/models/csm/generation_csm.py +19 -22
- transformers/models/csm/modeling_csm.py +3 -1
- transformers/models/csm/modular_csm.py +2 -0
- transformers/models/ctrl/modeling_ctrl.py +14 -2
- transformers/models/cvt/modeling_cvt.py +5 -1
- transformers/models/cwm/modeling_cwm.py +1 -1
- transformers/models/d_fine/configuration_d_fine.py +3 -4
- transformers/models/d_fine/modeling_d_fine.py +46 -39
- transformers/models/d_fine/modular_d_fine.py +15 -4
- transformers/models/dab_detr/configuration_dab_detr.py +2 -2
- transformers/models/dab_detr/modeling_dab_detr.py +1 -1
- transformers/models/dac/modeling_dac.py +4 -4
- transformers/models/data2vec/modeling_data2vec_text.py +7 -0
- transformers/models/data2vec/modular_data2vec_text.py +7 -0
- transformers/models/dbrx/configuration_dbrx.py +9 -1
- transformers/models/dbrx/modeling_dbrx.py +1 -1
- transformers/models/deberta/modeling_deberta.py +2 -0
- transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
- transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
- transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
- transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
- transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
- transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
- transformers/models/depth_anything/configuration_depth_anything.py +2 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
- transformers/models/detr/configuration_detr.py +1 -1
- transformers/models/detr/modeling_detr.py +8 -1
- transformers/models/dia/generation_dia.py +3 -10
- transformers/models/dia/modeling_dia.py +12 -1
- transformers/models/dia/modular_dia.py +11 -0
- transformers/models/dia/processing_dia.py +1 -1
- transformers/models/diffllama/modeling_diffllama.py +3 -3
- transformers/models/diffllama/modular_diffllama.py +2 -2
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
- transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
- transformers/models/distilbert/modeling_distilbert.py +11 -9
- transformers/models/doge/modeling_doge.py +1 -1
- transformers/models/donut/image_processing_donut_fast.py +0 -1
- transformers/models/donut/modeling_donut_swin.py +16 -12
- transformers/models/dots1/modeling_dots1.py +14 -5
- transformers/models/dpt/configuration_dpt.py +1 -1
- transformers/models/dpt/image_processing_dpt_fast.py +1 -2
- transformers/models/dpt/modular_dpt.py +1 -2
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +5 -2
- transformers/models/edgetam/modular_edgetam.py +15 -14
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
- transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
- transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
- transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
- transformers/models/efficientnet/modeling_efficientnet.py +5 -1
- transformers/models/electra/modeling_electra.py +7 -0
- transformers/models/emu3/modeling_emu3.py +8 -2
- transformers/models/emu3/modular_emu3.py +7 -1
- transformers/models/encodec/modeling_encodec.py +14 -0
- transformers/models/eomt/image_processing_eomt_fast.py +46 -14
- transformers/models/eomt/modeling_eomt.py +7 -0
- transformers/models/eomt/modular_eomt.py +7 -0
- transformers/models/ernie/modeling_ernie.py +6 -0
- transformers/models/ernie/modular_ernie.py +6 -0
- transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
- transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
- transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
- transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
- transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
- transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
- transformers/models/esm/modeling_esm.py +6 -0
- transformers/models/esm/modeling_esmfold.py +6 -1
- transformers/models/evolla/modeling_evolla.py +9 -1
- transformers/models/evolla/modular_evolla.py +8 -0
- transformers/models/exaone4/modeling_exaone4.py +1 -1
- transformers/models/falcon/modeling_falcon.py +3 -3
- transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
- transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
- transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
- transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
- transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
- transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
- transformers/models/flaubert/modeling_flaubert.py +14 -15
- transformers/models/flava/image_processing_flava_fast.py +0 -2
- transformers/models/flava/modeling_flava.py +4 -1
- transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
- transformers/models/florence2/modeling_florence2.py +20 -3
- transformers/models/florence2/modular_florence2.py +13 -0
- transformers/models/fnet/modeling_fnet.py +7 -0
- transformers/models/fuyu/image_processing_fuyu.py +1 -1
- transformers/models/fuyu/modeling_fuyu.py +3 -1
- transformers/models/fuyu/processing_fuyu.py +16 -0
- transformers/models/gemma/modeling_gemma.py +10 -12
- transformers/models/gemma/modular_gemma.py +9 -11
- transformers/models/gemma2/modeling_gemma2.py +1 -1
- transformers/models/gemma2/modular_gemma2.py +1 -1
- transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
- transformers/models/gemma3/modeling_gemma3.py +28 -7
- transformers/models/gemma3/modular_gemma3.py +26 -6
- transformers/models/gemma3n/configuration_gemma3n.py +3 -0
- transformers/models/gemma3n/modeling_gemma3n.py +47 -9
- transformers/models/gemma3n/modular_gemma3n.py +51 -9
- transformers/models/git/modeling_git.py +181 -126
- transformers/models/glm/modeling_glm.py +1 -1
- transformers/models/glm4/modeling_glm4.py +1 -1
- transformers/models/glm46v/image_processing_glm46v.py +0 -4
- transformers/models/glm46v/modeling_glm46v.py +3 -1
- transformers/models/glm46v/modular_glm46v.py +3 -0
- transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
- transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
- transformers/models/glm4v/image_processing_glm4v.py +0 -4
- transformers/models/glm4v/modeling_glm4v.py +15 -5
- transformers/models/glm4v/modular_glm4v.py +11 -3
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
- transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
- transformers/models/glmasr/__init__.py +30 -0
- transformers/models/glmasr/configuration_glmasr.py +197 -0
- transformers/models/glmasr/modeling_glmasr.py +512 -0
- transformers/models/glmasr/modular_glmasr.py +433 -0
- transformers/models/glmasr/processing_glmasr.py +332 -0
- transformers/models/glpn/image_processing_glpn_fast.py +0 -1
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
- transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
- transformers/models/gpt2/modeling_gpt2.py +8 -5
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
- transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
- transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
- transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
- transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
- transformers/models/gptj/modeling_gptj.py +15 -6
- transformers/models/granite/modeling_granite.py +1 -1
- transformers/models/granite_speech/modeling_granite_speech.py +15 -1
- transformers/models/granitemoe/modeling_granitemoe.py +2 -3
- transformers/models/granitemoe/modular_granitemoe.py +1 -2
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
- transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
- transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
- transformers/models/groupvit/modeling_groupvit.py +6 -1
- transformers/models/helium/modeling_helium.py +1 -1
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
- transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
- transformers/models/hubert/modeling_hubert.py +4 -0
- transformers/models/hubert/modular_hubert.py +4 -0
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
- transformers/models/hunyuan_v1_moe/__init__.py +1 -1
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
- transformers/models/ibert/modeling_ibert.py +16 -0
- transformers/models/idefics/modeling_idefics.py +10 -0
- transformers/models/idefics2/modeling_idefics2.py +7 -1
- transformers/models/idefics3/modeling_idefics3.py +5 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
- transformers/models/imagegpt/modeling_imagegpt.py +9 -2
- transformers/models/instructblip/modeling_instructblip.py +2 -0
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
- transformers/models/internvl/modeling_internvl.py +11 -8
- transformers/models/internvl/modular_internvl.py +5 -9
- transformers/models/internvl/video_processing_internvl.py +0 -1
- transformers/models/jais2/__init__.py +27 -0
- transformers/models/jais2/configuration_jais2.py +152 -0
- transformers/models/jais2/modeling_jais2.py +486 -0
- transformers/models/jais2/modular_jais2.py +196 -0
- transformers/models/jamba/modeling_jamba.py +24 -19
- transformers/models/jamba/modular_jamba.py +17 -17
- transformers/models/janus/image_processing_janus_fast.py +0 -1
- transformers/models/janus/modeling_janus.py +15 -7
- transformers/models/janus/modular_janus.py +16 -7
- transformers/models/jetmoe/modeling_jetmoe.py +2 -2
- transformers/models/jetmoe/modular_jetmoe.py +1 -0
- transformers/models/kosmos2/modeling_kosmos2.py +14 -2
- transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
- transformers/models/lasr/configuration_lasr.py +4 -0
- transformers/models/lasr/modeling_lasr.py +3 -2
- transformers/models/lasr/modular_lasr.py +8 -1
- transformers/models/lasr/processing_lasr.py +0 -2
- transformers/models/layoutlm/modeling_layoutlm.py +5 -3
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
- transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
- transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
- transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
- transformers/models/led/modeling_led.py +6 -0
- transformers/models/levit/modeling_levit.py +18 -0
- transformers/models/lfm2/modeling_lfm2.py +1 -1
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
- transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
- transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
- transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
- transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
- transformers/models/lilt/modeling_lilt.py +19 -15
- transformers/models/llama/modeling_llama.py +1 -1
- transformers/models/llama4/image_processing_llama4_fast.py +1 -2
- transformers/models/llama4/modeling_llama4.py +8 -4
- transformers/models/llava/image_processing_llava_fast.py +0 -1
- transformers/models/llava/modeling_llava.py +12 -7
- transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
- transformers/models/llava_next/modeling_llava_next.py +7 -3
- transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
- transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
- transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
- transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
- transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
- transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
- transformers/models/longt5/modeling_longt5.py +0 -4
- transformers/models/m2m_100/modeling_m2m_100.py +10 -0
- transformers/models/mamba/modeling_mamba.py +2 -1
- transformers/models/mamba2/modeling_mamba2.py +24 -23
- transformers/models/marian/configuration_marian.py +1 -1
- transformers/models/marian/modeling_marian.py +3 -0
- transformers/models/markuplm/modeling_markuplm.py +5 -8
- transformers/models/mask2former/configuration_mask2former.py +3 -3
- transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
- transformers/models/mask2former/modeling_mask2former.py +9 -0
- transformers/models/maskformer/configuration_maskformer.py +3 -3
- transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
- transformers/models/maskformer/modeling_maskformer.py +9 -1
- transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
- transformers/models/mbart/configuration_mbart.py +1 -0
- transformers/models/mbart/modeling_mbart.py +7 -0
- transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
- transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
- transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
- transformers/models/mimi/modeling_mimi.py +25 -4
- transformers/models/minimax/modeling_minimax.py +16 -3
- transformers/models/minimax/modular_minimax.py +12 -1
- transformers/models/ministral/modeling_ministral.py +1 -1
- transformers/models/ministral3/modeling_ministral3.py +1 -1
- transformers/models/mistral/modeling_mistral.py +1 -1
- transformers/models/mistral3/modeling_mistral3.py +10 -4
- transformers/models/mistral3/modular_mistral3.py +3 -1
- transformers/models/mixtral/modeling_mixtral.py +12 -4
- transformers/models/mixtral/modular_mixtral.py +6 -2
- transformers/models/mlcd/modeling_mlcd.py +6 -0
- transformers/models/mlcd/modular_mlcd.py +4 -0
- transformers/models/mllama/modeling_mllama.py +13 -2
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
- transformers/models/mobilebert/modeling_mobilebert.py +2 -0
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
- transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
- transformers/models/mobilevit/modeling_mobilevit.py +4 -0
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
- transformers/models/modernbert/modeling_modernbert.py +12 -1
- transformers/models/modernbert/modular_modernbert.py +12 -1
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
- transformers/models/moonshine/modeling_moonshine.py +1 -1
- transformers/models/moshi/modeling_moshi.py +21 -51
- transformers/models/mpnet/modeling_mpnet.py +2 -0
- transformers/models/mra/modeling_mra.py +4 -1
- transformers/models/mt5/configuration_mt5.py +2 -3
- transformers/models/mt5/modeling_mt5.py +0 -10
- transformers/models/musicgen/modeling_musicgen.py +5 -9
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
- transformers/models/mvp/modeling_mvp.py +7 -0
- transformers/models/nanochat/modeling_nanochat.py +1 -1
- transformers/models/nemotron/modeling_nemotron.py +3 -3
- transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
- transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
- transformers/models/nougat/image_processing_nougat_fast.py +0 -1
- transformers/models/nougat/tokenization_nougat.py +11 -16
- transformers/models/nystromformer/modeling_nystromformer.py +7 -0
- transformers/models/olmo/modeling_olmo.py +1 -1
- transformers/models/olmo2/modeling_olmo2.py +1 -1
- transformers/models/olmo3/modeling_olmo3.py +1 -1
- transformers/models/olmoe/modeling_olmoe.py +12 -4
- transformers/models/olmoe/modular_olmoe.py +4 -2
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
- transformers/models/oneformer/configuration_oneformer.py +3 -3
- transformers/models/oneformer/modeling_oneformer.py +7 -38
- transformers/models/openai/modeling_openai.py +12 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
- transformers/models/ovis2/modeling_ovis2.py +15 -3
- transformers/models/ovis2/modular_ovis2.py +8 -0
- transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
- transformers/models/owlv2/modeling_owlv2.py +7 -3
- transformers/models/owlv2/modular_owlv2.py +0 -2
- transformers/models/owlvit/modeling_owlvit.py +7 -3
- transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
- transformers/models/paligemma/modeling_paligemma.py +25 -17
- transformers/models/parakeet/modeling_parakeet.py +5 -0
- transformers/models/parakeet/modular_parakeet.py +5 -0
- transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
- transformers/models/patchtst/modeling_patchtst.py +5 -4
- transformers/models/pe_audio/__init__.py +30 -0
- transformers/models/pe_audio/configuration_pe_audio.py +206 -0
- transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
- transformers/models/pe_audio/modeling_pe_audio.py +820 -0
- transformers/models/pe_audio/modular_pe_audio.py +299 -0
- transformers/models/pe_audio/processing_pe_audio.py +24 -0
- transformers/models/pe_audio_video/__init__.py +29 -0
- transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
- transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
- transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
- transformers/models/pe_video/__init__.py +30 -0
- transformers/models/pe_video/configuration_pe_video.py +211 -0
- transformers/models/pe_video/modeling_pe_video.py +636 -0
- transformers/models/pe_video/modular_pe_video.py +219 -0
- transformers/models/pe_video/processing_pe_video.py +10 -0
- transformers/models/pe_video/video_processing_pe_video.py +66 -0
- transformers/models/pegasus/configuration_pegasus.py +1 -0
- transformers/models/pegasus/modeling_pegasus.py +3 -0
- transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
- transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
- transformers/models/perceiver/modeling_perceiver.py +5 -1
- transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
- transformers/models/perception_lm/modeling_perception_lm.py +7 -3
- transformers/models/perception_lm/modular_perception_lm.py +7 -3
- transformers/models/persimmon/modeling_persimmon.py +1 -1
- transformers/models/phi/modeling_phi.py +1 -1
- transformers/models/phi3/modeling_phi3.py +1 -1
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
- transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
- transformers/models/phimoe/modeling_phimoe.py +12 -4
- transformers/models/phimoe/modular_phimoe.py +1 -1
- transformers/models/pix2struct/processing_pix2struct.py +0 -4
- transformers/models/pixio/__init__.py +30 -0
- transformers/models/pixio/configuration_pixio.py +151 -0
- transformers/models/pixio/modeling_pixio.py +507 -0
- transformers/models/pixio/modular_pixio.py +404 -0
- transformers/models/pixtral/modeling_pixtral.py +1 -1
- transformers/models/pixtral/processing_pixtral.py +3 -1
- transformers/models/plbart/configuration_plbart.py +1 -0
- transformers/models/plbart/modeling_plbart.py +7 -0
- transformers/models/plbart/modular_plbart.py +6 -0
- transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
- transformers/models/poolformer/modeling_poolformer.py +11 -1
- transformers/models/pop2piano/configuration_pop2piano.py +0 -1
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
- transformers/models/prophetnet/modeling_prophetnet.py +2 -1
- transformers/models/qwen2/modeling_qwen2.py +1 -1
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
- transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
- transformers/models/qwen3/modeling_qwen3.py +1 -1
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
- transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
- transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
- transformers/models/rag/configuration_rag.py +0 -8
- transformers/models/rag/modeling_rag.py +7 -9
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
- transformers/models/reformer/modeling_reformer.py +9 -1
- transformers/models/regnet/modeling_regnet.py +4 -0
- transformers/models/rembert/modeling_rembert.py +7 -1
- transformers/models/resnet/modeling_resnet.py +8 -3
- transformers/models/roberta/modeling_roberta.py +3 -0
- transformers/models/roberta/modular_roberta.py +3 -0
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
- transformers/models/roc_bert/modeling_roc_bert.py +3 -0
- transformers/models/rt_detr/configuration_rt_detr.py +1 -1
- transformers/models/rt_detr/modeling_rt_detr.py +4 -0
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
- transformers/models/rwkv/modeling_rwkv.py +1 -1
- transformers/models/sam/configuration_sam.py +1 -0
- transformers/models/sam/image_processing_sam_fast.py +0 -1
- transformers/models/sam/modeling_sam.py +4 -1
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +5 -1
- transformers/models/sam2/modular_sam2.py +5 -1
- transformers/models/sam2_video/modeling_sam2_video.py +51 -43
- transformers/models/sam2_video/modular_sam2_video.py +31 -18
- transformers/models/sam3/configuration_sam3.py +21 -1
- transformers/models/sam3/modeling_sam3.py +23 -0
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
- transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
- transformers/models/sam3_video/configuration_sam3_video.py +14 -0
- transformers/models/sam3_video/modeling_sam3_video.py +3 -3
- transformers/models/sam3_video/processing_sam3_video.py +1 -1
- transformers/models/sam_hq/configuration_sam_hq.py +1 -0
- transformers/models/sam_hq/modeling_sam_hq.py +26 -23
- transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
- transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
- transformers/models/seed_oss/modeling_seed_oss.py +1 -1
- transformers/models/segformer/image_processing_segformer_fast.py +0 -1
- transformers/models/segformer/modeling_segformer.py +2 -2
- transformers/models/segformer/modular_segformer.py +0 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
- transformers/models/siglip/modeling_siglip.py +24 -2
- transformers/models/siglip2/modeling_siglip2.py +63 -41
- transformers/models/smollm3/modeling_smollm3.py +1 -1
- transformers/models/smolvlm/modeling_smolvlm.py +5 -1
- transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
- transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
- transformers/models/speecht5/modeling_speecht5.py +28 -0
- transformers/models/splinter/modeling_splinter.py +9 -3
- transformers/models/squeezebert/modeling_squeezebert.py +2 -0
- transformers/models/stablelm/modeling_stablelm.py +1 -1
- transformers/models/starcoder2/modeling_starcoder2.py +1 -1
- transformers/models/superglue/image_processing_superglue_fast.py +1 -2
- transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
- transformers/models/swiftformer/modeling_swiftformer.py +4 -0
- transformers/models/swin/modeling_swin.py +16 -12
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
- transformers/models/swin2sr/modeling_swin2sr.py +49 -33
- transformers/models/swinv2/modeling_swinv2.py +41 -33
- transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
- transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
- transformers/models/t5/configuration_t5.py +7 -1
- transformers/models/t5/modeling_t5.py +1 -7
- transformers/models/t5gemma/modeling_t5gemma.py +1 -1
- transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
- transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
- transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
- transformers/models/table_transformer/configuration_table_transformer.py +1 -1
- transformers/models/table_transformer/modeling_table_transformer.py +1 -1
- transformers/models/textnet/image_processing_textnet_fast.py +0 -1
- transformers/models/timesfm/modeling_timesfm.py +12 -0
- transformers/models/timesfm/modular_timesfm.py +12 -0
- transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
- transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
- transformers/models/trocr/modeling_trocr.py +1 -2
- transformers/models/tvp/configuration_tvp.py +5 -1
- transformers/models/tvp/modeling_tvp.py +4 -4
- transformers/models/udop/configuration_udop.py +1 -0
- transformers/models/udop/modeling_udop.py +3 -7
- transformers/models/umt5/configuration_umt5.py +2 -2
- transformers/models/umt5/modeling_umt5.py +0 -6
- transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
- transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
- transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
- transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
- transformers/models/video_llava/modeling_video_llava.py +7 -3
- transformers/models/vilt/configuration_vilt.py +2 -2
- transformers/models/vilt/modeling_vilt.py +7 -0
- transformers/models/vipllava/modeling_vipllava.py +7 -3
- transformers/models/visual_bert/modeling_visual_bert.py +2 -0
- transformers/models/vitmatte/configuration_vitmatte.py +1 -1
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
- transformers/models/vitmatte/modeling_vitmatte.py +4 -0
- transformers/models/vitpose/configuration_vitpose.py +1 -1
- transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +2 -2
- transformers/models/voxtral/modular_voxtral.py +2 -2
- transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
- transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
- transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
- transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
- transformers/models/whisper/generation_whisper.py +1 -0
- transformers/models/whisper/modeling_whisper.py +5 -3
- transformers/models/x_clip/modeling_x_clip.py +2 -0
- transformers/models/xcodec/modeling_xcodec.py +5 -0
- transformers/models/xglm/modeling_xglm.py +10 -0
- transformers/models/xlm/modeling_xlm.py +13 -14
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
- transformers/models/xlnet/modeling_xlnet.py +3 -1
- transformers/models/xmod/modeling_xmod.py +3 -0
- transformers/models/yoso/modeling_yoso.py +4 -1
- transformers/models/zamba/modeling_zamba.py +2 -1
- transformers/models/zamba2/modeling_zamba2.py +3 -2
- transformers/models/zoedepth/configuration_zoedepth.py +1 -1
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
- transformers/models/zoedepth/modeling_zoedepth.py +7 -0
- transformers/pipelines/__init__.py +9 -6
- transformers/pipelines/automatic_speech_recognition.py +20 -12
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/document_question_answering.py +1 -1
- transformers/pipelines/question_answering.py +1 -1
- transformers/pipelines/text_to_audio.py +2 -2
- transformers/processing_utils.py +127 -56
- transformers/quantizers/auto.py +2 -4
- transformers/quantizers/base.py +9 -64
- transformers/quantizers/quantizer_aqlm.py +1 -18
- transformers/quantizers/quantizer_auto_round.py +1 -10
- transformers/quantizers/quantizer_awq.py +3 -8
- transformers/quantizers/quantizer_bitnet.py +1 -6
- transformers/quantizers/quantizer_bnb_4bit.py +9 -49
- transformers/quantizers/quantizer_bnb_8bit.py +9 -19
- transformers/quantizers/quantizer_compressed_tensors.py +1 -4
- transformers/quantizers/quantizer_eetq.py +2 -12
- transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
- transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
- transformers/quantizers/quantizer_fp_quant.py +4 -4
- transformers/quantizers/quantizer_gptq.py +1 -4
- transformers/quantizers/quantizer_higgs.py +2 -6
- transformers/quantizers/quantizer_mxfp4.py +2 -28
- transformers/quantizers/quantizer_quanto.py +14 -14
- transformers/quantizers/quantizer_spqr.py +3 -8
- transformers/quantizers/quantizer_torchao.py +28 -124
- transformers/quantizers/quantizer_vptq.py +1 -10
- transformers/testing_utils.py +28 -12
- transformers/tokenization_mistral_common.py +3 -2
- transformers/tokenization_utils_base.py +3 -2
- transformers/tokenization_utils_tokenizers.py +25 -2
- transformers/trainer.py +24 -2
- transformers/trainer_callback.py +8 -0
- transformers/trainer_seq2seq.py +4 -0
- transformers/training_args.py +8 -10
- transformers/utils/__init__.py +4 -0
- transformers/utils/attention_visualizer.py +4 -4
- transformers/utils/auto_docstring.py +34 -25
- transformers/utils/generic.py +20 -0
- transformers/utils/import_utils.py +51 -9
- transformers/utils/kernel_config.py +71 -18
- transformers/utils/quantization_config.py +8 -8
- transformers/video_processing_utils.py +16 -12
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -259,7 +259,7 @@ class ContinuousBatchProcessor:
|
|
|
259
259
|
self.cumulative_seqlens_q = torch.empty((self.max_batch_tokens + 1,), **self.tensor_metadata)
|
|
260
260
|
self.max_seqlen_q = 0
|
|
261
261
|
self.logits_indices = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
|
|
262
|
-
self.output_ids = torch.empty((
|
|
262
|
+
self.output_ids = torch.empty((self.max_batch_tokens,), **self.tensor_metadata)
|
|
263
263
|
|
|
264
264
|
# For some kwargs, we have a dict of tensors with as many items as there are attention types
|
|
265
265
|
layer_types = getattr(self.config, "layer_types", None)
|
|
@@ -311,7 +311,7 @@ class ContinuousBatchProcessor:
|
|
|
311
311
|
self.cumulative_seqlens_q[: b_size + 1].zero_()
|
|
312
312
|
self.max_seqlen_q = 0
|
|
313
313
|
self.logits_indices[:q_len].fill_(-1)
|
|
314
|
-
self.output_ids[
|
|
314
|
+
self.output_ids[:q_len].fill_(-1)
|
|
315
315
|
|
|
316
316
|
# Reset the attributes that are either tensors or dict of tensors
|
|
317
317
|
for layer_type in self.cumulative_seqlens_k:
|
|
@@ -447,7 +447,7 @@ class ContinuousBatchProcessor:
|
|
|
447
447
|
self.metrics.record_batch_metrics(self.requests_in_batch)
|
|
448
448
|
|
|
449
449
|
# Reset the static tensors used for storage
|
|
450
|
-
self.reset_static_tensors() #
|
|
450
|
+
self.reset_static_tensors() # FIXME: why does this make the generation faster?
|
|
451
451
|
|
|
452
452
|
# Prepare accumulators
|
|
453
453
|
self.actual_query_length = 0
|
|
@@ -557,13 +557,10 @@ class ContinuousBatchProcessor:
|
|
|
557
557
|
self.actual_index_sizes[i] = (len(group_read_indices), len(group_write_indices))
|
|
558
558
|
|
|
559
559
|
@traced
|
|
560
|
-
def
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
except Exception:
|
|
565
|
-
return [0, 1]
|
|
566
|
-
return [0, 0]
|
|
560
|
+
def _get_new_tokens(self, num_new_tokens: int) -> list[int]:
|
|
561
|
+
indices = self.logits_indices[:num_new_tokens]
|
|
562
|
+
new_tokens = self.output_ids[indices]
|
|
563
|
+
return new_tokens.tolist()
|
|
567
564
|
|
|
568
565
|
@traced
|
|
569
566
|
def _maybe_send_output(self, state: RequestState) -> None:
|
|
@@ -574,29 +571,56 @@ class ContinuousBatchProcessor:
|
|
|
574
571
|
@traced
|
|
575
572
|
def update_batch(self) -> None:
|
|
576
573
|
"""Update request states based on generated tokens."""
|
|
577
|
-
|
|
578
|
-
|
|
574
|
+
new_tokens = self._get_new_tokens(len(self.requests_in_batch))
|
|
575
|
+
current_logits_index = 0
|
|
576
|
+
for state in self.requests_in_batch:
|
|
579
577
|
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
|
|
580
578
|
if len(state.remaining_prefill_tokens) == 0:
|
|
581
|
-
|
|
582
|
-
state.
|
|
583
|
-
|
|
579
|
+
# If there are no generated tokens yet, it means prefill just ended
|
|
580
|
+
if state.generated_len() == 0:
|
|
581
|
+
self.metrics.record_ttft_metric(state.created_time, state.request_id)
|
|
582
|
+
state.status = RequestStatus.DECODING
|
|
583
|
+
|
|
584
|
+
token = new_tokens[current_logits_index]
|
|
584
585
|
state.tokens_to_process = [token]
|
|
586
|
+
current_logits_index += 1
|
|
587
|
+
|
|
585
588
|
# Update the request and stop if it is complete
|
|
586
589
|
is_finished = state.update_and_check_completion(token)
|
|
587
590
|
# We mark the completed blocks as such
|
|
588
|
-
self.cache.
|
|
591
|
+
self.cache.mark_shareable_blocks_as_complete(state)
|
|
589
592
|
if is_finished:
|
|
590
593
|
self.metrics.record_request_completion(state.created_time, state.request_id)
|
|
591
594
|
self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
|
|
592
595
|
self._maybe_send_output(state)
|
|
593
596
|
# Otherwise, the request is still prefilling, but the prefill has been split
|
|
594
597
|
elif state.status == RequestStatus.PREFILLING_SPLIT:
|
|
595
|
-
self.cache.
|
|
598
|
+
self.cache.mark_shareable_blocks_as_complete(state)
|
|
596
599
|
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
|
|
597
600
|
else:
|
|
598
601
|
raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")
|
|
599
602
|
|
|
603
|
+
# If some requests need to be forked, we do it now
|
|
604
|
+
copy_source, copy_destination = [], []
|
|
605
|
+
while self.scheduler._requests_to_fork:
|
|
606
|
+
# Get the number of children and reset it so it's not forked again
|
|
607
|
+
state = self.scheduler._requests_to_fork.pop()
|
|
608
|
+
num_children = state.num_children
|
|
609
|
+
state.num_children = 0
|
|
610
|
+
# Create the new requests and add them to the scheduler
|
|
611
|
+
new_request_ids = [f"{state.request_id}__child#{i}" for i in range(num_children)]
|
|
612
|
+
for new_request_id in new_request_ids:
|
|
613
|
+
self.scheduler.active_requests[new_request_id] = state.fork(new_request_id)
|
|
614
|
+
# Fork the cache
|
|
615
|
+
copy_src, copy_dst = self.cache.fork_request(state.request_id, new_request_ids)
|
|
616
|
+
copy_source.extend(copy_src)
|
|
617
|
+
copy_destination.extend(copy_dst)
|
|
618
|
+
# FIXME: if fork can't be done, create a new pending request without forking instead of crashing everything
|
|
619
|
+
|
|
620
|
+
# The copy induced by the fork is done in one go (if it's even needed)
|
|
621
|
+
if copy_source:
|
|
622
|
+
self.cache.copy_cache(copy_source, copy_destination)
|
|
623
|
+
|
|
600
624
|
if self.cache.get_num_free_blocks() == 0:
|
|
601
625
|
raise ValueError("No more free blocks")
|
|
602
626
|
|
|
@@ -727,12 +751,11 @@ class ContinuousBatchProcessor:
|
|
|
727
751
|
probs = nn.functional.softmax(probs, dim=-1)
|
|
728
752
|
# probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
|
|
729
753
|
next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1) # Now [seq_len]
|
|
730
|
-
# Add batch dimension back to match argmax output
|
|
731
|
-
next_tokens = next_tokens.unsqueeze(0) # Now [1, seq_len]
|
|
732
754
|
else:
|
|
733
|
-
next_tokens = torch.argmax(probs, dim=-1) #
|
|
734
|
-
|
|
735
|
-
|
|
755
|
+
next_tokens = torch.argmax(probs, dim=-1) # shape is [1, seq_len]
|
|
756
|
+
next_tokens = next_tokens.squeeze(0) # shape is [seq_len]
|
|
757
|
+
tokens = next_tokens.size(0) # Get seq_len dimension
|
|
758
|
+
self.output_ids[:tokens].copy_(next_tokens)
|
|
736
759
|
|
|
737
760
|
|
|
738
761
|
# Manager Class (User Interface)
|
|
@@ -752,7 +775,7 @@ class ContinuousBatchingManager:
|
|
|
752
775
|
max_queue_size: int = 0,
|
|
753
776
|
num_q_padding_intervals: int = 0,
|
|
754
777
|
num_kv_padding_intervals: int = 0,
|
|
755
|
-
|
|
778
|
+
allow_block_sharing: bool = True,
|
|
756
779
|
) -> None:
|
|
757
780
|
"""Initialize the continuous batching manager.
|
|
758
781
|
|
|
@@ -762,30 +785,37 @@ class ContinuousBatchingManager:
|
|
|
762
785
|
max_queue_size: Maximum size of the request queue (0 = unlimited)
|
|
763
786
|
num_q_padding_intervals: (optional) Number of intervals used to pad the query dimension
|
|
764
787
|
num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
|
|
765
|
-
|
|
788
|
+
allow_block_sharing: (optional) Whether to allow block sharing if the model has some full attention layers
|
|
766
789
|
"""
|
|
767
|
-
#
|
|
790
|
+
# Reload paged version of the attention implementation if necessary
|
|
768
791
|
if "paged|" not in model.config._attn_implementation:
|
|
769
792
|
model.set_attn_implementation(f"paged|{model.config._attn_implementation}")
|
|
770
793
|
|
|
794
|
+
# Internal arguments
|
|
771
795
|
self.model = model.eval()
|
|
772
|
-
|
|
773
|
-
self.
|
|
796
|
+
self.manual_eviction = manual_eviction
|
|
797
|
+
self._allow_block_sharing = allow_block_sharing
|
|
798
|
+
self._use_prefix_sharing = allow_block_sharing # approximation until the cache is created
|
|
799
|
+
|
|
774
800
|
self.input_queue = queue.Queue(maxsize=max_queue_size)
|
|
775
801
|
self.output_queue = queue.Queue()
|
|
776
802
|
self.stop_event = threading.Event()
|
|
777
|
-
self.
|
|
803
|
+
self.batch_processor: ContinuousBatchProcessor | None = None
|
|
778
804
|
self._generation_thread = None
|
|
779
805
|
self._request_counter = 0
|
|
780
806
|
self._request_lock = threading.Lock()
|
|
781
|
-
|
|
807
|
+
|
|
808
|
+
# Generation config related arguments
|
|
809
|
+
generation_config = model.generation_config if generation_config is None else generation_config
|
|
810
|
+
self.generation_config = generation_config
|
|
811
|
+
self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
|
|
782
812
|
self.do_sample = getattr(generation_config, "do_sample", True)
|
|
783
813
|
self.logit_processor = self.model._get_logits_processor(generation_config)
|
|
784
|
-
self.
|
|
785
|
-
|
|
786
|
-
self.
|
|
787
|
-
self._allow_prefix_sharing = allow_prefix_sharing
|
|
814
|
+
self.num_return_sequences = getattr(generation_config, "num_return_sequences", 1)
|
|
815
|
+
|
|
816
|
+
# self.model.generation_config.top_p = None NOTE: figure out why this was here
|
|
788
817
|
|
|
818
|
+
# Cuda graph behavior is determined below using either user-specified arguments or heuristics
|
|
789
819
|
self.use_cuda_graph = self._decide_use_cuda_graphs(
|
|
790
820
|
use_cuda_graph=getattr(generation_config, "use_cuda_graph", None),
|
|
791
821
|
num_q_padding_intervals=num_q_padding_intervals,
|
|
@@ -799,6 +829,7 @@ class ContinuousBatchingManager:
|
|
|
799
829
|
num_kv_padding_intervals if num_kv_padding_intervals > 0 else NUM_KV_PADDING_INTERVALS
|
|
800
830
|
)
|
|
801
831
|
|
|
832
|
+
# Log probability generation is not supported yet (TODO)
|
|
802
833
|
if self.log_prob_generation:
|
|
803
834
|
raise NotImplementedError("log_prob_generation is not supported yet")
|
|
804
835
|
|
|
@@ -932,6 +963,7 @@ class ContinuousBatchingManager:
|
|
|
932
963
|
state = RequestState(
|
|
933
964
|
request_id=request_id,
|
|
934
965
|
initial_tokens=list(input_ids),
|
|
966
|
+
num_children=self.num_return_sequences - 1,
|
|
935
967
|
record_timestamps=record_timestamps,
|
|
936
968
|
tokens_to_process=list(input_ids),
|
|
937
969
|
max_new_tokens=max_new_tokens,
|
|
@@ -950,6 +982,10 @@ class ContinuousBatchingManager:
|
|
|
950
982
|
streaming: bool = False,
|
|
951
983
|
record_timestamps: bool = False,
|
|
952
984
|
) -> None:
|
|
985
|
+
# If there is prefix sharing, we sort the inputs to maximize cache hits
|
|
986
|
+
if self._use_prefix_sharing:
|
|
987
|
+
inputs = sorted(inputs, reverse=True)
|
|
988
|
+
# Add requests in order
|
|
953
989
|
for input_ids in inputs:
|
|
954
990
|
self.add_request(
|
|
955
991
|
input_ids, max_new_tokens=max_new_tokens, streaming=streaming, record_timestamps=record_timestamps
|
|
@@ -1020,8 +1056,9 @@ class ContinuousBatchingManager:
|
|
|
1020
1056
|
self.model.device,
|
|
1021
1057
|
self.model.dtype,
|
|
1022
1058
|
tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting
|
|
1023
|
-
|
|
1059
|
+
allow_block_sharing=self._allow_block_sharing,
|
|
1024
1060
|
)
|
|
1061
|
+
self._use_prefix_sharing = paged_attention_cache.use_prefix_sharing # update the approximation
|
|
1025
1062
|
logger.debug(f"PagedAttentionCache created in {perf_counter() - t0} seconds")
|
|
1026
1063
|
|
|
1027
1064
|
scheduler = None
|
|
@@ -1080,10 +1117,6 @@ class ContinuousBatchingManager:
|
|
|
1080
1117
|
)
|
|
1081
1118
|
|
|
1082
1119
|
self._generation_step()
|
|
1083
|
-
|
|
1084
|
-
if torch.cuda.is_available():
|
|
1085
|
-
torch.cuda.synchronize() # FIXME: why is this needed?
|
|
1086
|
-
# Processor updates the batch after generation step is truly over
|
|
1087
1120
|
batch_processor.update_batch()
|
|
1088
1121
|
|
|
1089
1122
|
@traced
|
|
@@ -1125,7 +1158,7 @@ class ContinuousMixin:
|
|
|
1125
1158
|
max_queue_size: int = 0,
|
|
1126
1159
|
num_q_cuda_graphs: int = 0,
|
|
1127
1160
|
num_kv_cuda_graphs: int = 0,
|
|
1128
|
-
|
|
1161
|
+
allow_block_sharing: bool = True,
|
|
1129
1162
|
block: bool = True,
|
|
1130
1163
|
timeout: float | None = None,
|
|
1131
1164
|
) -> Generator[ContinuousBatchingManager]:
|
|
@@ -1135,7 +1168,7 @@ class ContinuousMixin:
|
|
|
1135
1168
|
max_queue_size,
|
|
1136
1169
|
num_q_cuda_graphs,
|
|
1137
1170
|
num_kv_cuda_graphs,
|
|
1138
|
-
|
|
1171
|
+
allow_block_sharing,
|
|
1139
1172
|
)
|
|
1140
1173
|
manager.start()
|
|
1141
1174
|
try:
|
|
@@ -1154,7 +1187,7 @@ class ContinuousMixin:
|
|
|
1154
1187
|
max_queue_size: int = 0,
|
|
1155
1188
|
num_q_padding_intervals: int = 0,
|
|
1156
1189
|
num_kv_padding_intervals: int = 0,
|
|
1157
|
-
|
|
1190
|
+
allow_block_sharing: bool = True,
|
|
1158
1191
|
) -> ContinuousBatchingManager:
|
|
1159
1192
|
"""Initialize a manager for continuous batching inference.
|
|
1160
1193
|
|
|
@@ -1164,7 +1197,7 @@ class ContinuousMixin:
|
|
|
1164
1197
|
max_queue_size: Maximum size of the input request queue
|
|
1165
1198
|
num_q_padding_intervals: Number of intervals used to pad the query dimension
|
|
1166
1199
|
num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
|
|
1167
|
-
|
|
1200
|
+
allow_block_sharing: A flag to allow block sharing if the model has some full attention layers
|
|
1168
1201
|
|
|
1169
1202
|
Returns:
|
|
1170
1203
|
`ContinuousBatchingManager`: The manager instance to add requests and retrieve results.
|
|
@@ -1188,7 +1221,7 @@ class ContinuousMixin:
|
|
|
1188
1221
|
max_queue_size=max_queue_size,
|
|
1189
1222
|
num_q_padding_intervals=num_q_padding_intervals,
|
|
1190
1223
|
num_kv_padding_intervals=num_kv_padding_intervals,
|
|
1191
|
-
|
|
1224
|
+
allow_block_sharing=allow_block_sharing,
|
|
1192
1225
|
)
|
|
1193
1226
|
|
|
1194
1227
|
# TODO: support streaming
|
|
@@ -1200,7 +1233,7 @@ class ContinuousMixin:
|
|
|
1200
1233
|
generation_config: GenerationConfig | None = None,
|
|
1201
1234
|
num_q_padding_intervals: int = 0,
|
|
1202
1235
|
num_kv_padding_intervals: int = 0,
|
|
1203
|
-
|
|
1236
|
+
allow_block_sharing: bool = True,
|
|
1204
1237
|
record_timestamps: bool = False,
|
|
1205
1238
|
progress_bar: bool = True,
|
|
1206
1239
|
**kwargs,
|
|
@@ -1212,7 +1245,7 @@ class ContinuousMixin:
|
|
|
1212
1245
|
generation_config: Optional generation configuration
|
|
1213
1246
|
num_q_padding_intervals: Number of intervals used to pad the query dimension
|
|
1214
1247
|
num_kv_padding_intervals: Number of intervals used to pad the keys/values dimension
|
|
1215
|
-
|
|
1248
|
+
allow_block_sharing: A flag to allow block sharing if the model has some full attention layers
|
|
1216
1249
|
record_timestamps: If set to true, the requests will have a timestamp for each token generated
|
|
1217
1250
|
progress_bar: If set to true, a progress bar will be displayed
|
|
1218
1251
|
**kwargs: Additional generation parameters
|
|
@@ -1228,26 +1261,30 @@ class ContinuousMixin:
|
|
|
1228
1261
|
|
|
1229
1262
|
# Initialize manager with the batch inputs
|
|
1230
1263
|
results = {}
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
)
|
|
1264
|
+
gen_cfg = self.generation_config if generation_config is None else generation_config
|
|
1265
|
+
num_requests = len(inputs) * gen_cfg.num_return_sequences
|
|
1266
|
+
# Prepare context managers for the main loop
|
|
1267
|
+
manager_cm = self.continuous_batching_context_manager(
|
|
1268
|
+
generation_config=generation_config,
|
|
1269
|
+
num_q_cuda_graphs=num_q_padding_intervals,
|
|
1270
|
+
num_kv_cuda_graphs=num_kv_padding_intervals,
|
|
1271
|
+
allow_block_sharing=allow_block_sharing,
|
|
1272
|
+
block=True,
|
|
1273
|
+
timeout=5,
|
|
1274
|
+
)
|
|
1275
|
+
logging_cm = logging_redirect_tqdm([logger])
|
|
1276
|
+
pbar_cm = tqdm(
|
|
1277
|
+
total=num_requests,
|
|
1278
|
+
disable=(not progress_bar),
|
|
1279
|
+
desc=f"Solving {num_requests} requests",
|
|
1280
|
+
unit="request",
|
|
1281
|
+
)
|
|
1282
|
+
# Main loop
|
|
1283
|
+
with manager_cm as manager, logging_cm, pbar_cm as pbar:
|
|
1249
1284
|
try:
|
|
1250
|
-
manager.add_requests(
|
|
1285
|
+
manager.add_requests(
|
|
1286
|
+
inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"), record_timestamps=record_timestamps
|
|
1287
|
+
)
|
|
1251
1288
|
finished_count = 0
|
|
1252
1289
|
while finished_count < num_requests:
|
|
1253
1290
|
result = manager.get_result(timeout=1)
|
|
@@ -101,6 +101,8 @@ class RequestState:
|
|
|
101
101
|
|
|
102
102
|
Attributes:
|
|
103
103
|
request_id (str): The ID of the generation request.
|
|
104
|
+
initial_tokens (list[int]): The initial prompt tokens.
|
|
105
|
+
num_children (int): The number of children requests
|
|
104
106
|
full_prompt_ids (list[int] | None): The tokens IDs of the full prompt.
|
|
105
107
|
prompt_ids (list[int] | None): The tokens IDs currently being processed.
|
|
106
108
|
remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests).
|
|
@@ -121,6 +123,7 @@ class RequestState:
|
|
|
121
123
|
initial_tokens: list[int] # Initial prompt tokens
|
|
122
124
|
# Optional fields
|
|
123
125
|
record_timestamps: bool = False # Whether to record timestamps for the generated tokens
|
|
126
|
+
num_children: int = 0 # Number of children requests
|
|
124
127
|
# Internal fields
|
|
125
128
|
tokens_to_process: list[int] | None = None # Tokens IDs currently being processed
|
|
126
129
|
remaining_prefill_tokens: list[int] = field(default_factory=list) # For split requests, prefill left to process
|
|
@@ -181,7 +184,7 @@ class RequestState:
|
|
|
181
184
|
Returns:
|
|
182
185
|
bool: True if the request is now complete, False otherwise
|
|
183
186
|
"""
|
|
184
|
-
# Only update if we're in decoding state
|
|
187
|
+
# Only update if we're in decoding state # TODO: seems useless (always true) -- remove this
|
|
185
188
|
if self.status != RequestStatus.DECODING:
|
|
186
189
|
return False
|
|
187
190
|
|
|
@@ -227,3 +230,27 @@ class RequestState:
|
|
|
227
230
|
error=self.error,
|
|
228
231
|
timestamps=self.timestamps,
|
|
229
232
|
)
|
|
233
|
+
|
|
234
|
+
def fork(self, new_request_id: str) -> "RequestState":
|
|
235
|
+
"""Fork the request into a new request with the same state except for request_id, created_time and lifespan."""
|
|
236
|
+
t = time.perf_counter()
|
|
237
|
+
new_request = RequestState(
|
|
238
|
+
request_id=new_request_id,
|
|
239
|
+
initial_tokens=self.initial_tokens,
|
|
240
|
+
num_children=self.num_children,
|
|
241
|
+
tokens_to_process=self.tokens_to_process[:],
|
|
242
|
+
remaining_prefill_tokens=self.remaining_prefill_tokens[:],
|
|
243
|
+
generated_tokens=self.generated_tokens[:],
|
|
244
|
+
allocated_blocks=self.allocated_blocks,
|
|
245
|
+
position_offset=self.position_offset,
|
|
246
|
+
status=self.status,
|
|
247
|
+
max_new_tokens=self.max_new_tokens,
|
|
248
|
+
eos_token_id=self.eos_token_id,
|
|
249
|
+
streaming=self.streaming,
|
|
250
|
+
created_time=t,
|
|
251
|
+
lifespan=(t, -1),
|
|
252
|
+
timestamps=None if self.timestamps is None else self.timestamps[:],
|
|
253
|
+
error=self.error,
|
|
254
|
+
record_timestamps=self.record_timestamps,
|
|
255
|
+
)
|
|
256
|
+
return new_request
|
|
@@ -36,6 +36,7 @@ class Scheduler(ABC):
|
|
|
36
36
|
self.retain_cache_on_finish = retain_cache_on_finish
|
|
37
37
|
self._cancellation_lock = threading.Lock()
|
|
38
38
|
self._requests_to_cancel: set[str] = set()
|
|
39
|
+
self._requests_to_fork: list[RequestState] = []
|
|
39
40
|
|
|
40
41
|
@traced
|
|
41
42
|
def add_waiting_request(self, state: RequestState):
|
|
@@ -151,8 +152,13 @@ class Scheduler(ABC):
|
|
|
151
152
|
else:
|
|
152
153
|
request_tokens = state.tokens_to_process
|
|
153
154
|
|
|
155
|
+
# If the request has one or more children we make sure not to prefill it entirely
|
|
156
|
+
if state.num_children > 0 and token_budget >= len(request_tokens) - 1:
|
|
157
|
+
token_budget = len(request_tokens) - 1
|
|
158
|
+
self._requests_to_fork.append(state)
|
|
159
|
+
|
|
160
|
+
# Case: we can process the entire prompt/remainder
|
|
154
161
|
if len(request_tokens) < token_budget:
|
|
155
|
-
# Can process the entire prompt/remainder
|
|
156
162
|
if state.status == RequestStatus.PENDING:
|
|
157
163
|
self.active_requests[state.request_id] = state
|
|
158
164
|
state.status = RequestStatus.PREFILLING
|
|
@@ -161,8 +167,9 @@ class Scheduler(ABC):
|
|
|
161
167
|
state.status = RequestStatus.PREFILLING
|
|
162
168
|
state.tokens_to_process = state.remaining_prefill_tokens
|
|
163
169
|
state.remaining_prefill_tokens = []
|
|
170
|
+
|
|
171
|
+
# Otherwise: we need to split the request
|
|
164
172
|
else:
|
|
165
|
-
# Need to split the request
|
|
166
173
|
if state.status == RequestStatus.PENDING:
|
|
167
174
|
self.active_requests[state.request_id] = state
|
|
168
175
|
state.status = RequestStatus.PREFILLING_SPLIT
|
|
@@ -229,7 +236,7 @@ class FIFOScheduler(Scheduler):
|
|
|
229
236
|
# Update the token budget
|
|
230
237
|
token_budget -= request_len
|
|
231
238
|
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
|
|
232
|
-
if self.cache.
|
|
239
|
+
if self.cache.allow_block_sharing:
|
|
233
240
|
tokens_in_current_block = state.current_len() % self.cache.block_size
|
|
234
241
|
tokens_after_forward = tokens_in_current_block + request_len
|
|
235
242
|
complete_blocks = tokens_after_forward // self.cache.block_size
|
|
@@ -295,7 +302,7 @@ class PrefillFirstScheduler(Scheduler):
|
|
|
295
302
|
# Update the token budget
|
|
296
303
|
token_budget -= request_len
|
|
297
304
|
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
|
|
298
|
-
if self.cache.
|
|
305
|
+
if self.cache.allow_block_sharing:
|
|
299
306
|
tokens_in_current_block = state.current_len() % self.cache.block_size
|
|
300
307
|
tokens_after_forward = tokens_in_current_block + request_len
|
|
301
308
|
complete_blocks = tokens_after_forward // self.cache.block_size
|
|
@@ -430,7 +430,7 @@ class StopStringCriteria(StoppingCriteria):
|
|
|
430
430
|
initial_match = end_lengths > 0
|
|
431
431
|
|
|
432
432
|
# Tokens continue the string if the cumsum() so far is one of the valid positions for that token
|
|
433
|
-
# Note that we're actually tracking one cumsum() for
|
|
433
|
+
# Note that we're actually tracking one cumsum() for each possible end_length
|
|
434
434
|
later_match = torch.any(cumsum[:, :-1, :, None] == valid_positions[:, :, :, :, None], axis=-2)
|
|
435
435
|
|
|
436
436
|
# The match vector is a boolean vector that indicates which positions have valid tokens
|