sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
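A note on the penaltylib entries above: the individual penalizers moved up one package level (`sglang/srt/sampling/penaltylib/penalizers/*.py` → `sglang/srt/sampling/penaltylib/*.py`), so code that imported them by module path only needs a path update. A hedged sketch of that change follows; the class name is an assumption based on the 0.4.3 penaltylib and is not taken from this diff, so verify it against the installed wheel:

```python
# 0.4.3.post1 (old layout) -- class name assumed, not shown in this diff:
# from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import BatchedFrequencyPenalizer

# 0.4.3.post3 (new layout):
from sglang.srt.sampling.penaltylib.frequency_penalty import BatchedFrequencyPenalizer
```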
sglang/srt/managers/schedule_batch.py

@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """

+import copy
 import dataclasses
 import logging
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
@@ -43,14 +44,15 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

 if TYPE_CHECKING:
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -65,7 +67,11 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
+    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+    "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
 }

 logger = logging.getLogger(__name__)
@@ -228,12 +234,14 @@ class Req:
         sampling_params: SamplingParams,
         return_logprob: bool = False,
         top_logprobs_num: int = 0,
+        token_ids_logprob: List[int] = None,
         stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
         custom_logit_processor: Optional[str] = None,
+        return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -253,16 +261,27 @@ class Req:
         self.input_embeds = input_embeds

         # Sampling info
+        if isinstance(sampling_params.custom_params, dict):
+            sampling_params = copy.copy(sampling_params)
+            sampling_params.custom_params = sampling_params.custom_params | {
+                "__req__": self
+            }
         self.sampling_params = sampling_params
+
         self.custom_logit_processor = custom_logit_processor
+        self.return_hidden_states = return_hidden_states

         # Memory pool info
-        self.req_pool_idx = None
+        self.req_pool_idx: Optional[int] = None

         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # If we want to abort the request in the middle of the event loop, set this to true
+        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
+        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
+        self.to_abort_message: str = "Unknown error"
         self.stream = stream
         self.eos_token_ids = eos_token_ids

@@ -275,7 +294,6 @@ class Req:
         # 1: surr_offset
         # 2: read_offset
         # 3: last token
-        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
         self.decoded_text = ""
@@ -284,47 +302,58 @@ class Req:
         self.image_inputs: Optional[ImageInputs] = None

         # Prefix info
+        # The indices to kv cache for the shared prefix.
         self.prefix_indices = []
-        #
-        # Updated if chunked.
+        # Number of tokens to run prefill.
         self.extend_input_len = 0
+        # The relative logprob_start_len in an extend batch
+        self.extend_logprob_start_len = 0
         self.last_node = None

-        #
-
+        # Whether or not if it is chunked. It increments whenever
+        # it is chunked, and decrement whenever chunked request is
+        # processed.
+        self.is_chunked = 0

         # For retraction
         self.is_retracted = False

         # Logprobs (arguments)
         self.return_logprob = return_logprob
+        # Start index to compute logprob from.
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
+        self.token_ids_logprob = token_ids_logprob

         # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
         self.input_top_logprobs_idx: Optional[List[int]] = None
+        self.input_token_ids_logprobs_val: Optional[List[float]] = None
+        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
+        # Temporary holder to store input_token_logprobs.
+        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
+        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
+        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
+        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
+        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
             self.output_token_logprobs_val = []
             self.output_token_logprobs_idx = []
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
+            self.output_token_ids_logprobs_val = []
+            self.output_token_ids_logprobs_idx = []
         else:
             self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                 self.output_top_logprobs_val
-            ) = self.output_top_logprobs_idx =
+            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
+                self.output_token_ids_logprobs_idx
+            ) = None
         self.hidden_states = []

-        # Logprobs (internal values)
-        # The tokens is prefilled but need to be considered as decode tokens
-        # and should be updated for the decode logprobs
-        self.last_update_decode_tokens = 0
-        # The relative logprob_start_len in an extend batch
-        self.extend_logprob_start_len = 0
-
         # Embedding (return values)
         self.embedding = None

@@ -340,6 +369,10 @@ class Req:
         self.spec_verify_ct = 0
         self.lora_path = lora_path

+    @property
+    def seqlen(self):
+        return len(self.origin_input_ids) + len(self.output_ids)
+
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
             self.image_inputs = image_inputs
@@ -417,7 +450,9 @@ class Req:
             return

         if self.to_abort:
-            self.finished_reason = FINISH_ABORT(
+            self.finished_reason = FINISH_ABORT(
+                message=self.to_abort_message,
+            )
             return

         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
@@ -457,81 +492,22 @@ class Req:
             self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
             return

-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        if self.origin_input_text is None:
-            # Recovering text can only use unpadded ids
-            self.origin_input_text = self.tokenizer.decode(
-                self.origin_input_ids_unpadded
-            )
-
-        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
-        all_ids = self.tokenizer.encode(all_text)
-        if not all_ids:
-            logger.warning("Encoded all_text resulted in empty all_ids")
-            return False
-
-        prompt_tokens = len(self.origin_input_ids_unpadded)
-        if prompt_tokens > len(all_ids):
-            logger.warning("prompt_tokens is larger than encoded all_ids")
-            return False
-
-        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
-            # TODO(lsyin): fix token fusion
-            logger.warning(
-                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
-            )
-            return False
-
-        old_output_ids = self.output_ids
-        self.output_ids = all_ids[prompt_tokens:]
-        self.decoded_text = self.decoded_text + jump_forward_str
-        self.surr_offset = prompt_tokens
-        self.read_offset = len(all_ids)
-
-        # NOTE: A trick to reduce the surrouding tokens decoding overhead
-        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
-            surr_text_ = self.tokenizer.decode(
-                all_ids[self.read_offset - i : self.read_offset]
-            )
-            if not surr_text_.endswith("�"):
-                self.surr_offset = self.read_offset - i
-                break
-
-        # update the inner state of the grammar
-        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
-
-        if self.return_logprob:
-            # For fast-forward part's logprobs
-            k = 0
-            for i, old_id in enumerate(old_output_ids):
-                if old_id == self.output_ids[i]:
-                    k = k + 1
-                else:
-                    break
-            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
-            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
-            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
-            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
-            self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.output_ids) - k
-
-        return True
-
     def reset_for_retract(self):
         self.prefix_indices = []
         self.last_node = None
         self.extend_input_len = 0
         self.is_retracted = True
-
-
-
-        self.
-        self.
+        self.input_token_logprobs = None
+        self.temp_input_top_logprobs_val = None
+        self.temp_input_top_logprobs_idx = None
+        self.extend_logprob_start_len = 0
+        self.is_chunked = 0
+        self.req_pool_idx = None

     def __repr__(self):
         return (
-            f"rid
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+            f"Req(rid={self.rid}, "
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
         )


@@ -545,7 +521,7 @@ class ScheduleBatch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-
+    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None

     # Batch configs
@@ -571,11 +547,13 @@ class ScheduleBatch:

     # For DP attention
     global_num_tokens: Optional[List[int]] = None
+    global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False

     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None

     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
@@ -583,6 +561,8 @@ class ScheduleBatch:
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None
     extend_logprob_start_lens: List[int] = None
+    # It comes empty list if logprob is not required.
+    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

     # For encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -601,12 +581,12 @@ class ScheduleBatch:

     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[
+    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

     # Enable custom logit processor
     enable_custom_logit_processor: bool = False

-    #
+    # Whether to return hidden states
    return_hidden_states: bool = False

     @classmethod
@@ -614,18 +594,17 @@ class ScheduleBatch:
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
-        return_hidden_states: bool = False,
     ):
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
-
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
             tree_cache=tree_cache,
             model_config=model_config,
             enable_overlap=enable_overlap,
@@ -635,7 +614,7 @@ class ScheduleBatch:
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
-            return_hidden_states=return_hidden_states,
+            return_hidden_states=any(req.return_hidden_states for req in reqs),
         )

     def batch_size(self):
@@ -648,25 +627,27 @@ class ScheduleBatch:
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
-                "
-                "Please set a smaller number for `--max-running-requests`."
+                "alloc_req_slots runs out of memory. "
+                "Please set a smaller number for `--max-running-requests`. "
+                f"{self.req_to_token_pool.available_size()=}, "
+                f"{num_reqs=}, "
             )
         return req_pool_indices

     def alloc_token_slots(self, num_tokens: int):
-        out_cache_loc = self.
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

         if out_cache_loc is None:
             if self.tree_cache is not None:
-                self.tree_cache.evict(num_tokens, self.
-                out_cache_loc = self.
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
+                out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

             if out_cache_loc is None:
                 phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                 logger.error(
                     f"{phase_str} out of memory. Try to lower your batch size.\n"
                     f"Try to allocate {num_tokens} tokens.\n"
-                    f"Avaliable tokens: {self.
+                    f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                 )
                 if self.tree_cache is not None:
                     self.tree_cache.pretty_print()
@@ -760,6 +741,7 @@ class ScheduleBatch:
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)

         input_embeds = []
+        extend_input_logprob_token_ids = []

         pt = 0
         for i, req in enumerate(reqs):
@@ -778,22 +760,64 @@ class ScheduleBatch:
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

-            if req.return_logprob:
-                # Compute the relative logprob_start_len in an extend batch
-                if req.logprob_start_len >= pre_len:
-                    extend_logprob_start_len = min(
-                        req.logprob_start_len - pre_len, req.extend_input_len - 1
-                    )
-                else:
-                    raise RuntimeError(
-                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
-                    )
-                req.extend_logprob_start_len = extend_logprob_start_len
-
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
             pre_lens.append(pre_len)
+            # Compute the relative logprob_start_len in an extend batch
+            if req.logprob_start_len >= pre_len:
+                req.extend_logprob_start_len = min(
+                    req.logprob_start_len - pre_len,
+                    req.extend_input_len,
+                    req.seqlen - 1,
+                )
+            else:
+                req.extend_logprob_start_len = 0
+
+            if self.return_logprob:
+                # Find input logprob token ids.
+                # First, find a global index within origin_input_ids and slide it by 1
+                # to compute input logprobs. It is because you need the next token
+                # to compute input logprobs. E.g., (chunk size 2)
+                #
+                # input_logprobs = [1, 2, 3, 4]
+                # fill_ids = [1, 2]
+                # extend_input_logprob_token_id = [2, 3]
+                #
+                # Note that it can also overflow. In this case, we pad it with 0.
+                # input_logprobs = [1, 2, 3, 4]
+                # fill_ids = [3, 4]
+                # extend_input_logprob_token_id = [4, 0]
+                global_start_idx, global_end_idx = (
+                    len(req.prefix_indices),
+                    len(req.fill_ids),
+                )
+                # Apply logprob_start_len
+                if global_start_idx < req.logprob_start_len:
+                    global_start_idx = req.logprob_start_len
+
+                logprob_token_ids = req.origin_input_ids[
+                    global_start_idx + 1 : global_end_idx + 1
+                ]
+                extend_input_logprob_token_ids.extend(logprob_token_ids)
+
+                # We will need req.extend_input_len - req.extend_logprob_start_len number of
+                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
+                extend_input_logprob_token_ids.extend(
+                    [0]
+                    * (
+                        req.extend_input_len
+                        - req.extend_logprob_start_len
+                        - len(logprob_token_ids)
+                    )
+                )
+
+        if self.return_logprob:
+            extend_input_logprob_token_ids = torch.tensor(
+                extend_input_logprob_token_ids
+            )
+        else:
+            extend_input_logprob_token_ids = None

         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -816,10 +840,12 @@ class ScheduleBatch:
         self.seq_lens_sum = sum(seq_lens)
         if self.return_logprob:
             self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
         self.extend_num_tokens = extend_num_tokens
         self.prefix_lens = [len(r.prefix_indices) for r in reqs]
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

         # Write to req_to_token_pool
         pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
@@ -855,7 +881,6 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=self.enable_overlap,
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -890,41 +915,60 @@ class ScheduleBatch:

     def check_decode_mem(self, buf_multiplier=1):
         bs = len(self.reqs) * buf_multiplier
-        if self.
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True

-        self.tree_cache.evict(bs, self.
+        self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)

-        if self.
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True

         return False

-    def retract_decode(self):
+    def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
         sorted_indices = [i for i in range(len(self.reqs))]

         # TODO(lsyin): improve retraction policy for radix cache
-
-
-
-
-
-
-
+        # For spec decoding, filter_batch API can only filter
+        # requests from the back, so we can only retract from the back.
+        # TODO(sang): Clean up finish path and support better retract
+        # policy.
+        if not server_args.speculative_algorithm:
+            sorted_indices.sort(
+                key=lambda i: (
+                    len(self.reqs[i].output_ids),
+                    -len(self.reqs[i].origin_input_ids),
+                ),
+                reverse=True,
+            )

         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
+
+        def get_required_tokens(num_reqs: int):
+            headroom_for_spec_decode = 0
+            if server_args.speculative_algorithm:
+                headroom_for_spec_decode += (
+                    num_reqs
+                    * server_args.speculative_eagle_topk
+                    * server_args.speculative_num_steps
+                    + num_reqs * server_args.speculative_num_draft_tokens
+                )
+            return (
+                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
+            )
+
         while (
-            self.
-            < len(sorted_indices)
+            self.token_to_kv_pool_allocator.available_size()
+            < get_required_tokens(len(sorted_indices))
             or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
                 assert (
-                    self.
+                    self.token_to_kv_pool_allocator.available_size() > 0
                 ), "No space left for only one request"
                 break

@@ -938,7 +982,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, : seq_lens_cpu[idx]
                 ]
-                self.
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
@@ -947,7 +991,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
-                self.
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)

                 # release the last node
@@ -956,10 +1000,13 @@ class ScheduleBatch:
             # NOTE(lsyin): we should use the newly evictable memory instantly.
             residual_size = (
                 len(sorted_indices) * global_config.retract_decode_steps
-                - self.
+                - self.token_to_kv_pool_allocator.available_size()
             )
             residual_size = max(0, residual_size)
-            self.tree_cache.evict(
+            self.tree_cache.evict(
+                residual_size, self.token_to_kv_pool_allocator.free
+            )
+
             req.reset_for_retract()

         self.filter_batch(keep_indices=sorted_indices)
@@ -975,59 +1022,6 @@ class ScheduleBatch:

         return retracted_reqs, new_estimate_ratio

-    def check_for_jump_forward(self, pad_input_ids_func):
-        jump_forward_reqs = []
-        keep_indices = set(i for i in range(len(self.reqs)))
-
-        for i, req in enumerate(self.reqs):
-            if req.grammar is not None:
-                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
-                if jump_helper:
-                    suffix_ids, _ = jump_helper
-
-                    # Current ids, for cache and revert
-                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
-                    cur_output_ids = req.output_ids
-
-                    req.output_ids.extend(suffix_ids)
-                    decode_res, new_text = req.get_next_inc_detokenization()
-                    if not decode_res:
-                        req.output_ids = cur_output_ids
-                        continue
-
-                    (
-                        jump_forward_str,
-                        next_state,
-                    ) = req.grammar.jump_forward_str_state(jump_helper)
-
-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
-                    jump_forward_str = new_text + jump_forward_str
-                    if not req.jump_forward_and_retokenize(
-                        jump_forward_str, next_state
-                    ):
-                        req.output_ids = cur_output_ids
-                        continue
-
-                    # The decode status has diverged from detokenizer_manager
-                    req.vid += 1
-
-                    # insert the old request into tree_cache
-                    self.tree_cache.cache_finished_req(req, cur_all_ids)
-
-                    # re-applying image padding
-                    if req.image_inputs is not None:
-                        req.origin_input_ids = pad_input_ids_func(
-                            req.origin_input_ids_unpadded, req.image_inputs
-                        )
-
-                    jump_forward_reqs.append(req)
-                    keep_indices.remove(i)
-
-        self.filter_batch(keep_indices=list(keep_indices))
-
-        return jump_forward_reqs
-
     def prepare_encoder_info_decode(self):
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)
@@ -1043,17 +1037,40 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=self.enable_overlap,
         )

     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
         if self.spec_algorithm.is_eagle():
+            # if spec decoding is used, the decode batch is prepared inside
+            # `forward_batch_speculative_generation` after running draft models.
             return

+        if self.sampling_info.penalizer_orchestrator.is_required:
+            if self.enable_overlap:
+                # TODO: this can be slow, optimize this.
+                delayed_output_ids = torch.tensor(
+                    [
+                        (
+                            req.output_ids[-1]
+                            if len(req.output_ids)
+                            else req.origin_input_ids[-1]
+                        )
+                        for req in self.reqs
+                    ],
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    delayed_output_ids
+                )
+            else:
+                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    self.output_ids.to(torch.int64)
+                )
+
         self.input_ids = self.output_ids
         self.output_ids = None
-        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

         # Alloc mem
         bs = len(self.reqs)
@@ -1081,14 +1098,15 @@ class ScheduleBatch:

     def filter_batch(
         self,
-
+        chunked_req_to_exclude: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
-                if not self.reqs[i].finished()
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
@@ -1100,31 +1118,34 @@ class ScheduleBatch:
             # No need to filter
             return

+        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+
         if self.model_config.is_encoder_decoder:
-            self.encoder_lens = self.encoder_lens[
+            self.encoder_lens = self.encoder_lens[keep_indices_device]
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

         self.reqs = [self.reqs[i] for i in keep_indices]
-
-
-        )
-        self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.seq_lens = self.seq_lens[new_indices]
+        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
+        self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
-        self.output_ids = self.output_ids[
+        self.output_ids = self.output_ids[keep_indices_device]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
             self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
+            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
         else:
             self.top_logprobs_nums = None
+            self.token_ids_logprobs = None

         self.has_stream = any(req.stream for req in self.reqs)
         self.has_grammar = any(req.grammar for req in self.reqs)

-        self.sampling_info.filter_batch(keep_indices,
+        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
         if self.spec_info:
-            self.spec_info.filter_batch(
+            self.spec_info.filter_batch(keep_indices_device)

     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1147,23 +1168,32 @@ class ScheduleBatch:
         self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
+            self.token_ids_logprobs.extend(other.token_ids_logprobs)
         elif self.return_logprob:
             self.top_logprobs_nums.extend([0] * len(other.reqs))
+            self.token_ids_logprobs.extend([None] * len(other.reqs))
         elif other.return_logprob:
             self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
+            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)

         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
         self.has_grammar |= other.has_grammar
+        self.return_hidden_states |= other.return_hidden_states

         if self.spec_info:
             self.spec_info.merge_batch(other.spec_info)

-    def get_model_worker_batch(self):
+    def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                decode_seq_lens = self.seq_lens.cpu()
+            else:
+                decode_seq_lens = None
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
+            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
@@ -1186,8 +1216,11 @@ class ScheduleBatch:
             seq_lens_sum=self.seq_lens_sum,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
+            token_ids_logprobs=self.token_ids_logprobs,
             global_num_tokens=self.global_num_tokens,
+            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            decode_seq_lens=decode_seq_lens,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1213,6 +1246,7 @@ class ScheduleBatch:
                     else CaptureHiddenMode.NULL
                 )
             ),
+            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
         )

     def copy(self):
@@ -1247,7 +1281,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
-    # The indices of output tokens in the
+    # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor

     # The sum of all sequence lengths
@@ -1256,16 +1290,22 @@ class ModelWorkerBatch:
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
+    token_ids_logprobs: Optional[List[List[int]]]

     # For DP attention
     global_num_tokens: Optional[List[int]]
+    global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool

+    # For decode
+    decode_seq_lens: Optional[torch.Tensor]
+
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
     extend_prefix_lens: Optional[List[int]]
     extend_logprob_start_lens: Optional[List[int]]
+    extend_input_logprob_token_ids: Optional[torch.Tensor]

     # For multimodal
     image_inputs: Optional[List[ImageInputs]]
@@ -1287,7 +1327,8 @@ class ModelWorkerBatch:

     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+    # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None

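One of the larger behavioral changes in the hunks above is the new `extend_input_logprob_token_ids` bookkeeping in `ScheduleBatch.prepare_for_extend`: for each prefill chunk the prompt ids are shifted by one position (the logprob of token i is scored against token i + 1) and the tail is zero-padded once the chunk runs past the end of the prompt. The sketch below is a minimal, self-contained illustration of that shifting rule only; the function name is hypothetical, and it omits the `logprob_start_len` clamping and `extend_logprob_start_len` accounting that the real code also applies.

```python
from typing import List


def shifted_logprob_token_ids(
    origin_input_ids: List[int], prefix_len: int, fill_len: int
) -> List[int]:
    """Label ids used for the input logprobs of one prefill chunk.

    Position i is scored against the token at position i + 1, so the slice is
    shifted by one; when the chunk reaches the end of the prompt there is no
    next token, and the tail is padded with 0.
    """
    start, end = prefix_len, prefix_len + fill_len
    labels = list(origin_input_ids[start + 1 : end + 1])
    labels += [0] * (fill_len - len(labels))  # pad the overflow with zeros
    return labels


# The worked example from the diff's own comment (chunk size 2):
prompt = [1, 2, 3, 4]
assert shifted_logprob_token_ids(prompt, prefix_len=0, fill_len=2) == [2, 3]
assert shifted_logprob_token_ids(prompt, prefix_len=2, fill_len=2) == [4, 0]
```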