sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -17,10 +17,11 @@ import faulthandler
|
|
17
17
|
import logging
|
18
18
|
import os
|
19
19
|
import signal
|
20
|
+
import sys
|
20
21
|
import threading
|
21
22
|
import time
|
22
23
|
import warnings
|
23
|
-
from collections import deque
|
24
|
+
from collections import defaultdict, deque
|
24
25
|
from concurrent import futures
|
25
26
|
from dataclasses import dataclass
|
26
27
|
from http import HTTPStatus
|
@@ -44,17 +45,24 @@ from sglang.srt.managers.io_struct import (
|
|
44
45
|
BatchTokenIDOut,
|
45
46
|
CloseSessionReqInput,
|
46
47
|
FlushCacheReq,
|
48
|
+
GetInternalStateReq,
|
49
|
+
GetInternalStateReqOutput,
|
47
50
|
GetWeightsByNameReqInput,
|
48
51
|
GetWeightsByNameReqOutput,
|
52
|
+
HealthCheckOutput,
|
49
53
|
InitWeightsUpdateGroupReqInput,
|
50
54
|
InitWeightsUpdateGroupReqOutput,
|
51
55
|
OpenSessionReqInput,
|
52
56
|
OpenSessionReqOutput,
|
53
57
|
ProfileReq,
|
58
|
+
ProfileReqOutput,
|
59
|
+
ProfileReqType,
|
54
60
|
ReleaseMemoryOccupationReqInput,
|
55
61
|
ReleaseMemoryOccupationReqOutput,
|
56
62
|
ResumeMemoryOccupationReqInput,
|
57
63
|
ResumeMemoryOccupationReqOutput,
|
64
|
+
SetInternalStateReq,
|
65
|
+
SetInternalStateReqOutput,
|
58
66
|
TokenizedEmbeddingReqInput,
|
59
67
|
TokenizedGenerateReqInput,
|
60
68
|
UpdateWeightFromDiskReqInput,
|
@@ -82,6 +90,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
|
|
82
90
|
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
83
91
|
from sglang.srt.managers.utils import validate_input_length
|
84
92
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
93
|
+
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
85
94
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
86
95
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
87
96
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
@@ -94,6 +103,7 @@ from sglang.srt.utils import (
|
|
94
103
|
crash_on_warnings,
|
95
104
|
get_bool_env_var,
|
96
105
|
get_zmq_socket,
|
106
|
+
pyspy_dump_schedulers,
|
97
107
|
set_gpu_proc_affinity,
|
98
108
|
set_random_seed,
|
99
109
|
suppress_other_loggers,
|
@@ -103,13 +113,16 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
|
103
113
|
logger = logging.getLogger(__name__)
|
104
114
|
|
105
115
|
# Test retract decode for debugging purposes
|
106
|
-
|
116
|
+
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
117
|
+
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
107
118
|
|
108
119
|
|
109
120
|
@dataclass
|
110
121
|
class GenerationBatchResult:
|
111
122
|
logits_output: LogitsProcessorOutput
|
112
123
|
next_token_ids: List[int]
|
124
|
+
extend_input_len_per_req: List[int]
|
125
|
+
extend_logprob_start_len_per_req: List[int]
|
113
126
|
bid: int
|
114
127
|
|
115
128
|
|
@@ -135,20 +148,16 @@ class Scheduler:
|
|
135
148
|
self.tp_rank = tp_rank
|
136
149
|
self.tp_size = server_args.tp_size
|
137
150
|
self.schedule_policy = server_args.schedule_policy
|
138
|
-
self.disable_jump_forward = server_args.disable_jump_forward
|
139
151
|
self.lora_paths = server_args.lora_paths
|
140
152
|
self.max_loras_per_batch = server_args.max_loras_per_batch
|
141
153
|
self.enable_overlap = not server_args.disable_overlap_schedule
|
142
154
|
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
143
155
|
self.enable_metrics = server_args.enable_metrics
|
156
|
+
self.stream_interval = server_args.stream_interval
|
144
157
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
145
158
|
server_args.speculative_algorithm
|
146
159
|
)
|
147
|
-
self.
|
148
|
-
self.server_args.speculative_num_draft_tokens
|
149
|
-
if not self.spec_algorithm.is_none()
|
150
|
-
else 1
|
151
|
-
)
|
160
|
+
self.gpu_id = gpu_id
|
152
161
|
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
153
162
|
|
154
163
|
# Distributed rank info
|
@@ -188,49 +197,16 @@ class Scheduler:
|
|
188
197
|
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
189
198
|
|
190
199
|
# Init tokenizer
|
191
|
-
self.
|
192
|
-
server_args.model_path,
|
193
|
-
trust_remote_code=server_args.trust_remote_code,
|
194
|
-
revision=server_args.revision,
|
195
|
-
context_length=server_args.context_length,
|
196
|
-
model_override_args=server_args.json_model_override_args,
|
197
|
-
is_embedding=server_args.is_embedding,
|
198
|
-
dtype=server_args.dtype,
|
199
|
-
quantization=server_args.quantization,
|
200
|
-
)
|
201
|
-
self.is_generation = self.model_config.is_generation
|
202
|
-
|
203
|
-
if server_args.skip_tokenizer_init:
|
204
|
-
self.tokenizer = self.processor = None
|
205
|
-
else:
|
206
|
-
if self.model_config.is_multimodal:
|
207
|
-
self.processor = get_processor(
|
208
|
-
server_args.tokenizer_path,
|
209
|
-
tokenizer_mode=server_args.tokenizer_mode,
|
210
|
-
trust_remote_code=server_args.trust_remote_code,
|
211
|
-
revision=server_args.revision,
|
212
|
-
)
|
213
|
-
self.tokenizer = self.processor.tokenizer
|
214
|
-
else:
|
215
|
-
self.tokenizer = get_tokenizer(
|
216
|
-
server_args.tokenizer_path,
|
217
|
-
tokenizer_mode=server_args.tokenizer_mode,
|
218
|
-
trust_remote_code=server_args.trust_remote_code,
|
219
|
-
revision=server_args.revision,
|
220
|
-
)
|
200
|
+
self.init_tokenizer()
|
221
201
|
|
222
202
|
# Check whether overlap can be enabled
|
223
203
|
if not self.is_generation:
|
224
204
|
self.enable_overlap = False
|
225
205
|
logger.info("Overlap scheduler is disabled for embedding models.")
|
226
|
-
|
227
206
|
if self.model_config.is_multimodal:
|
228
207
|
self.enable_overlap = False
|
229
208
|
logger.info("Overlap scheduler is disabled for multimodal models.")
|
230
209
|
|
231
|
-
if self.enable_overlap:
|
232
|
-
self.disable_jump_forward = True
|
233
|
-
|
234
210
|
# Launch a tensor parallel worker
|
235
211
|
if self.enable_overlap:
|
236
212
|
TpWorkerClass = TpModelWorkerClient
|
@@ -245,7 +221,7 @@ class Scheduler:
|
|
245
221
|
nccl_port=port_args.nccl_port,
|
246
222
|
)
|
247
223
|
|
248
|
-
# Launch a worker for speculative decoding
|
224
|
+
# Launch a draft worker for speculative decoding
|
249
225
|
if self.spec_algorithm.is_eagle():
|
250
226
|
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
251
227
|
|
@@ -279,6 +255,7 @@ class Scheduler:
|
|
279
255
|
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
|
280
256
|
global_server_args_dict.update(worker_global_server_args_dict)
|
281
257
|
set_random_seed(self.random_seed)
|
258
|
+
|
282
259
|
# Print debug info
|
283
260
|
logger.info(
|
284
261
|
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
@@ -289,27 +266,11 @@ class Scheduler:
|
|
289
266
|
)
|
290
267
|
|
291
268
|
# Init memory pool and cache
|
292
|
-
self.
|
293
|
-
|
294
|
-
if (
|
295
|
-
server_args.chunked_prefill_size is not None
|
296
|
-
and server_args.disable_radix_cache
|
297
|
-
):
|
298
|
-
self.tree_cache = ChunkCache(
|
299
|
-
req_to_token_pool=self.req_to_token_pool,
|
300
|
-
token_to_kv_pool=self.token_to_kv_pool,
|
301
|
-
)
|
302
|
-
else:
|
303
|
-
self.tree_cache = RadixCache(
|
304
|
-
req_to_token_pool=self.req_to_token_pool,
|
305
|
-
token_to_kv_pool=self.token_to_kv_pool,
|
306
|
-
disable=server_args.disable_radix_cache,
|
307
|
-
)
|
308
|
-
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
309
|
-
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
|
269
|
+
self.init_memory_pool_and_cache()
|
310
270
|
|
311
271
|
# Init running status
|
312
272
|
self.waiting_queue: List[Req] = []
|
273
|
+
self.staging_reqs = {}
|
313
274
|
# The running decoding batch for continuous batching
|
314
275
|
self.running_batch: Optional[ScheduleBatch] = None
|
315
276
|
# The current forward batch
|
@@ -319,22 +280,20 @@ class Scheduler:
|
|
319
280
|
self.forward_ct = 0
|
320
281
|
self.forward_ct_decode = 0
|
321
282
|
self.num_generated_tokens = 0
|
322
|
-
self.spec_num_total_accepted_tokens = 0
|
323
|
-
self.spec_num_total_forward_ct = 0
|
324
283
|
self.last_decode_stats_tic = time.time()
|
325
|
-
self.
|
284
|
+
self.return_health_check_ct = 0
|
326
285
|
self.current_stream = torch.get_device_module(self.device).current_stream()
|
327
286
|
if self.device == "cpu":
|
328
287
|
self.current_stream.synchronize = lambda: None # No-op for CPU
|
329
288
|
|
330
|
-
#
|
289
|
+
# Init session info
|
331
290
|
self.sessions: Dict[str, Session] = {}
|
332
291
|
|
333
292
|
# Init chunked prefill
|
334
293
|
self.chunked_prefill_size = server_args.chunked_prefill_size
|
335
294
|
if self.chunked_prefill_size <= 0: # -1 means disable
|
336
295
|
self.chunked_prefill_size = None
|
337
|
-
self.
|
296
|
+
self.chunked_req = None
|
338
297
|
self.is_mixed_chunk = (
|
339
298
|
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
340
299
|
)
|
@@ -348,11 +307,11 @@ class Scheduler:
|
|
348
307
|
else:
|
349
308
|
self.grammar_backend = None
|
350
309
|
|
351
|
-
# Init new token estimation
|
310
|
+
# Init schedule policy and new token estimation
|
311
|
+
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
|
352
312
|
assert (
|
353
313
|
server_args.schedule_conservativeness >= 0
|
354
314
|
), "Invalid schedule_conservativeness"
|
355
|
-
|
356
315
|
self.init_new_token_ratio = min(
|
357
316
|
global_config.default_init_new_token_ratio
|
358
317
|
* server_args.schedule_conservativeness,
|
@@ -368,7 +327,7 @@ class Scheduler:
|
|
368
327
|
) / global_config.default_new_token_ratio_decay_steps
|
369
328
|
self.new_token_ratio = self.init_new_token_ratio
|
370
329
|
|
371
|
-
#
|
330
|
+
# Tell whether the current running batch is full so that we can skip
|
372
331
|
# the check of whether to prefill new requests.
|
373
332
|
# This is an optimization to reduce the overhead of the prefill check.
|
374
333
|
self.batch_is_full = False
|
@@ -379,41 +338,19 @@ class Scheduler:
|
|
379
338
|
t.start()
|
380
339
|
self.parent_process = psutil.Process().parent()
|
381
340
|
|
341
|
+
# Init memory saver
|
382
342
|
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
383
343
|
enable=server_args.enable_memory_saver
|
384
344
|
)
|
385
345
|
|
386
346
|
# Init profiler
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
logger.info(
|
392
|
-
"Profiling enabled. Traces will be saved to: %s",
|
393
|
-
self.torch_profiler_trace_dir,
|
394
|
-
)
|
395
|
-
self.profiler = torch.profiler.profile(
|
396
|
-
activities=[
|
397
|
-
torch.profiler.ProfilerActivity.CPU,
|
398
|
-
torch.profiler.ProfilerActivity.CUDA,
|
399
|
-
],
|
400
|
-
with_stack=True,
|
401
|
-
)
|
347
|
+
self.torch_profiler = None
|
348
|
+
self.torch_profiler_output_dir: Optional[str] = None
|
349
|
+
self.torch_profiler_activities: Optional[List[str]] = None
|
350
|
+
self.profiler_target_forward_ct: Optional[int] = None
|
402
351
|
|
403
352
|
# Init metrics stats
|
404
|
-
self.
|
405
|
-
if self.enable_metrics:
|
406
|
-
self.metrics_collector = SchedulerMetricsCollector(
|
407
|
-
labels={
|
408
|
-
"model_name": self.server_args.served_model_name,
|
409
|
-
# TODO: Add lora name/path in the future,
|
410
|
-
},
|
411
|
-
)
|
412
|
-
|
413
|
-
# The largest prefill length of a single request
|
414
|
-
self._largest_prefill_len: int = 0
|
415
|
-
# The largest context length (prefill + generation) of a single request
|
416
|
-
self._largest_prefill_decode_len: int = 0
|
353
|
+
self.init_metrics()
|
417
354
|
|
418
355
|
# Init request dispatcher
|
419
356
|
self._request_dispatcher = TypeBasedDispatcher(
|
@@ -422,6 +359,8 @@ class Scheduler:
|
|
422
359
|
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
423
360
|
(FlushCacheReq, self.flush_cache_wrapped),
|
424
361
|
(AbortReq, self.abort_request),
|
362
|
+
(OpenSessionReqInput, self.open_session),
|
363
|
+
(CloseSessionReqInput, self.close_session),
|
425
364
|
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
426
365
|
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
427
366
|
(
|
@@ -430,39 +369,108 @@ class Scheduler:
|
|
430
369
|
),
|
431
370
|
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
|
432
371
|
(GetWeightsByNameReqInput, self.get_weights_by_name),
|
372
|
+
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
|
373
|
+
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
433
374
|
(ProfileReq, self.profile),
|
434
|
-
(
|
435
|
-
(
|
436
|
-
(
|
437
|
-
ReleaseMemoryOccupationReqInput,
|
438
|
-
lambda _: self.release_memory_occupation(),
|
439
|
-
),
|
440
|
-
(
|
441
|
-
ResumeMemoryOccupationReqInput,
|
442
|
-
lambda _: self.resume_memory_occupation(),
|
443
|
-
),
|
375
|
+
(GetInternalStateReq, self.get_internal_state),
|
376
|
+
(SetInternalStateReq, self.set_internal_state),
|
444
377
|
]
|
445
378
|
)
|
446
379
|
|
447
|
-
def
|
448
|
-
|
449
|
-
self.watchdog_last_forward_ct = 0
|
450
|
-
self.watchdog_last_time = time.time()
|
380
|
+
def init_tokenizer(self):
|
381
|
+
server_args = self.server_args
|
451
382
|
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
383
|
+
self.model_config = ModelConfig(
|
384
|
+
server_args.model_path,
|
385
|
+
trust_remote_code=server_args.trust_remote_code,
|
386
|
+
revision=server_args.revision,
|
387
|
+
context_length=server_args.context_length,
|
388
|
+
model_override_args=server_args.json_model_override_args,
|
389
|
+
is_embedding=server_args.is_embedding,
|
390
|
+
dtype=server_args.dtype,
|
391
|
+
quantization=server_args.quantization,
|
392
|
+
)
|
393
|
+
self.is_generation = self.model_config.is_generation
|
394
|
+
|
395
|
+
if server_args.skip_tokenizer_init:
|
396
|
+
self.tokenizer = self.processor = None
|
397
|
+
else:
|
398
|
+
if self.model_config.is_multimodal:
|
399
|
+
self.processor = get_processor(
|
400
|
+
server_args.tokenizer_path,
|
401
|
+
tokenizer_mode=server_args.tokenizer_mode,
|
402
|
+
trust_remote_code=server_args.trust_remote_code,
|
403
|
+
revision=server_args.revision,
|
404
|
+
)
|
405
|
+
self.tokenizer = self.processor.tokenizer
|
406
|
+
else:
|
407
|
+
self.tokenizer = get_tokenizer(
|
408
|
+
server_args.tokenizer_path,
|
409
|
+
tokenizer_mode=server_args.tokenizer_mode,
|
410
|
+
trust_remote_code=server_args.trust_remote_code,
|
411
|
+
revision=server_args.revision,
|
412
|
+
)
|
413
|
+
|
414
|
+
def init_memory_pool_and_cache(self):
|
415
|
+
server_args = self.server_args
|
416
|
+
|
417
|
+
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
418
|
+
self.tp_worker.get_memory_pool()
|
419
|
+
)
|
420
|
+
|
421
|
+
if (
|
422
|
+
server_args.chunked_prefill_size is not None
|
423
|
+
and server_args.disable_radix_cache
|
424
|
+
):
|
425
|
+
self.tree_cache = ChunkCache(
|
426
|
+
req_to_token_pool=self.req_to_token_pool,
|
427
|
+
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
428
|
+
)
|
429
|
+
else:
|
430
|
+
if self.enable_hierarchical_cache:
|
431
|
+
self.tree_cache = HiRadixCache(
|
432
|
+
req_to_token_pool=self.req_to_token_pool,
|
433
|
+
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
434
|
+
)
|
435
|
+
else:
|
436
|
+
self.tree_cache = RadixCache(
|
437
|
+
req_to_token_pool=self.req_to_token_pool,
|
438
|
+
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
439
|
+
disable=server_args.disable_radix_cache,
|
440
|
+
)
|
441
|
+
|
442
|
+
self.decode_mem_cache_buf_multiplier = (
|
443
|
+
1
|
444
|
+
if self.spec_algorithm.is_none()
|
445
|
+
else (
|
446
|
+
server_args.speculative_num_draft_tokens
|
447
|
+
+ (
|
448
|
+
server_args.speculative_eagle_topk
|
449
|
+
* server_args.speculative_num_steps
|
450
|
+
)
|
451
|
+
)
|
452
|
+
)
|
453
|
+
|
454
|
+
def init_metrics(self):
|
455
|
+
# The largest prefill length of a single request
|
456
|
+
self._largest_prefill_len: int = 0
|
457
|
+
# The largest context length (prefill + generation) of a single request
|
458
|
+
self._largest_prefill_decode_len: int = 0
|
459
|
+
self.last_gen_throughput: float = 0.0
|
460
|
+
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
|
461
|
+
self.spec_num_total_accepted_tokens = 0
|
462
|
+
self.spec_num_total_forward_ct = 0
|
463
|
+
self.cum_spec_accept_length = 0
|
464
|
+
self.cum_spec_accept_count = 0
|
465
|
+
self.stats = SchedulerStats()
|
466
|
+
if self.enable_metrics:
|
467
|
+
engine_type = "unified"
|
468
|
+
self.metrics_collector = SchedulerMetricsCollector(
|
469
|
+
labels={
|
470
|
+
"model_name": self.server_args.served_model_name,
|
471
|
+
"engine_type": engine_type,
|
472
|
+
},
|
473
|
+
)
|
466
474
|
|
467
475
|
@torch.no_grad()
|
468
476
|
def event_loop_normal(self):
|
@@ -577,6 +585,13 @@ class Scheduler:
|
|
577
585
|
|
578
586
|
def process_input_requests(self, recv_reqs: List):
|
579
587
|
for recv_req in recv_reqs:
|
588
|
+
# If it is a health check generation request and there are running requests, ignore it.
|
589
|
+
if is_health_check_generate_req(recv_req) and (
|
590
|
+
self.chunked_req is not None or self.running_batch is not None
|
591
|
+
):
|
592
|
+
self.return_health_check_ct += 1
|
593
|
+
continue
|
594
|
+
|
580
595
|
output = self._request_dispatcher(recv_req)
|
581
596
|
if output is not None:
|
582
597
|
self.send_to_tokenizer.send_pyobj(output)
|
@@ -591,7 +606,6 @@ class Scheduler:
|
|
591
606
|
or recv_req.session_params.id is None
|
592
607
|
or recv_req.session_params.id not in self.sessions
|
593
608
|
):
|
594
|
-
|
595
609
|
if recv_req.input_embeds is not None:
|
596
610
|
# Generate fake input_ids based on the length of input_embeds
|
597
611
|
seq_length = len(recv_req.input_embeds)
|
@@ -618,10 +632,12 @@ class Scheduler:
|
|
618
632
|
recv_req.sampling_params,
|
619
633
|
return_logprob=recv_req.return_logprob,
|
620
634
|
top_logprobs_num=recv_req.top_logprobs_num,
|
635
|
+
token_ids_logprob=recv_req.token_ids_logprob,
|
621
636
|
stream=recv_req.stream,
|
622
637
|
lora_path=recv_req.lora_path,
|
623
638
|
input_embeds=recv_req.input_embeds,
|
624
639
|
custom_logit_processor=custom_logit_processor,
|
640
|
+
return_hidden_states=recv_req.return_hidden_states,
|
625
641
|
eos_token_ids=self.model_config.hf_eos_token_id,
|
626
642
|
)
|
627
643
|
req.tokenizer = self.tokenizer
|
@@ -633,14 +649,14 @@ class Scheduler:
|
|
633
649
|
req.finished_reason = FINISH_ABORT(
|
634
650
|
f"Invalid request: session id {recv_req.session_params.id} does not exist"
|
635
651
|
)
|
636
|
-
self.
|
652
|
+
self._add_request_to_queue(req)
|
637
653
|
return
|
638
654
|
else:
|
639
655
|
# Create a new request from a previous session
|
640
656
|
session = self.sessions[recv_req.session_params.id]
|
641
657
|
req = session.create_req(recv_req, self.tokenizer)
|
642
658
|
if isinstance(req.finished_reason, FINISH_ABORT):
|
643
|
-
self.
|
659
|
+
self._add_request_to_queue(req)
|
644
660
|
return
|
645
661
|
|
646
662
|
# Handle multimodal inputs
|
@@ -664,7 +680,7 @@ class Scheduler:
|
|
664
680
|
req.finished_reason = FINISH_ABORT(
|
665
681
|
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
666
682
|
)
|
667
|
-
self.
|
683
|
+
self._add_request_to_queue(req)
|
668
684
|
return
|
669
685
|
|
670
686
|
# Validate prompts length
|
@@ -674,16 +690,28 @@ class Scheduler:
|
|
674
690
|
self.server_args.allow_auto_truncate,
|
675
691
|
)
|
676
692
|
if error_msg:
|
677
|
-
|
693
|
+
req.origin_input_ids = [0]
|
694
|
+
req.sampling_params.max_new_tokens = 0
|
695
|
+
self._add_request_to_queue(req)
|
678
696
|
return
|
679
697
|
|
680
698
|
# Copy more attributes
|
681
|
-
if recv_req.logprob_start_len == -1:
|
699
|
+
if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
|
682
700
|
# By default, only return the logprobs for output tokens
|
683
701
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
684
702
|
else:
|
685
703
|
req.logprob_start_len = recv_req.logprob_start_len
|
686
704
|
|
705
|
+
if req.logprob_start_len >= len(req.origin_input_ids):
|
706
|
+
req.finished_reason = FINISH_ABORT(
|
707
|
+
f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
|
708
|
+
HTTPStatus.BAD_REQUEST,
|
709
|
+
"BadRequestError",
|
710
|
+
)
|
711
|
+
req.logprob_start_len = len(req.origin_input_ids) - 1
|
712
|
+
self._add_request_to_queue(req)
|
713
|
+
return
|
714
|
+
|
687
715
|
req.sampling_params.max_new_tokens = min(
|
688
716
|
(
|
689
717
|
req.sampling_params.max_new_tokens
|
@@ -699,6 +727,7 @@ class Scheduler:
|
|
699
727
|
req.sampling_params.json_schema is not None
|
700
728
|
or req.sampling_params.regex is not None
|
701
729
|
or req.sampling_params.ebnf is not None
|
730
|
+
or req.sampling_params.structural_tag is not None
|
702
731
|
):
|
703
732
|
assert self.grammar_backend is not None
|
704
733
|
if req.sampling_params.json_schema is not None:
|
@@ -707,6 +736,8 @@ class Scheduler:
|
|
707
736
|
key = ("regex", req.sampling_params.regex)
|
708
737
|
elif req.sampling_params.ebnf is not None:
|
709
738
|
key = ("ebnf", req.sampling_params.ebnf)
|
739
|
+
elif req.sampling_params.structural_tag:
|
740
|
+
key = ("structural_tag", req.sampling_params.structural_tag)
|
710
741
|
|
711
742
|
req.grammar = self.grammar_backend.get_cached_value(key)
|
712
743
|
if not req.grammar:
|
@@ -716,7 +747,13 @@ class Scheduler:
|
|
716
747
|
if add_to_grammar_queue:
|
717
748
|
self.grammar_queue.append(req)
|
718
749
|
else:
|
719
|
-
self.
|
750
|
+
self._add_request_to_queue(req)
|
751
|
+
|
752
|
+
def _add_request_to_queue(self, req: Req):
|
753
|
+
self.waiting_queue.append(req)
|
754
|
+
|
755
|
+
def _extend_requests_to_queue(self, reqs: List[Req]):
|
756
|
+
self.waiting_queue.extend(reqs)
|
720
757
|
|
721
758
|
def handle_embedding_request(
|
722
759
|
self,
|
@@ -737,61 +774,64 @@ class Scheduler:
|
|
737
774
|
self.server_args.allow_auto_truncate,
|
738
775
|
)
|
739
776
|
if error_msg:
|
740
|
-
self.
|
777
|
+
self._add_request_to_queue(req)
|
741
778
|
return
|
742
779
|
|
743
780
|
# Copy more attributes
|
744
781
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
745
|
-
self.
|
782
|
+
self._add_request_to_queue(req)
|
746
783
|
|
747
784
|
def log_prefill_stats(
|
748
785
|
self,
|
749
786
|
adder: PrefillAdder,
|
750
787
|
can_run_list: List[Req],
|
751
|
-
running_bs:
|
752
|
-
has_being_chunked: bool,
|
788
|
+
running_bs: int,
|
753
789
|
):
|
754
|
-
self.tree_cache_metrics["total"] += (
|
755
|
-
adder.log_input_tokens + adder.log_hit_tokens
|
756
|
-
) / 10**9
|
757
|
-
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
|
758
|
-
tree_cache_hit_rate = (
|
759
|
-
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
760
|
-
)
|
761
|
-
|
762
790
|
num_used = self.max_total_num_tokens - (
|
763
|
-
self.
|
791
|
+
self.token_to_kv_pool_allocator.available_size()
|
792
|
+
+ self.tree_cache.evictable_size()
|
793
|
+
)
|
794
|
+
self._largest_prefill_len = max(
|
795
|
+
self._largest_prefill_len, adder.log_input_tokens
|
764
796
|
)
|
765
797
|
|
766
|
-
|
798
|
+
f = (
|
767
799
|
f"Prefill batch. "
|
768
800
|
f"#new-seq: {len(can_run_list)}, "
|
769
801
|
f"#new-token: {adder.log_input_tokens}, "
|
770
802
|
f"#cached-token: {adder.log_hit_tokens}, "
|
771
|
-
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
772
803
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
773
804
|
f"#running-req: {running_bs}, "
|
774
|
-
f"#queue-req: {len(self.waiting_queue)
|
805
|
+
f"#queue-req: {len(self.waiting_queue)}, "
|
775
806
|
)
|
807
|
+
logger.info(f)
|
776
808
|
|
777
809
|
if self.enable_metrics:
|
810
|
+
cache_hit_rate = adder.log_hit_tokens / (
|
811
|
+
adder.log_input_tokens + adder.log_hit_tokens
|
812
|
+
)
|
778
813
|
self.stats.num_running_reqs = running_bs
|
779
814
|
self.stats.num_used_tokens = num_used
|
780
815
|
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
|
781
|
-
self.stats.num_queue_reqs = len(self.waiting_queue)
|
782
|
-
self.stats.cache_hit_rate =
|
816
|
+
self.stats.num_queue_reqs = len(self.waiting_queue)
|
817
|
+
self.stats.cache_hit_rate = cache_hit_rate
|
783
818
|
self.metrics_collector.log_stats(self.stats)
|
784
819
|
|
785
820
|
def log_decode_stats(self):
|
786
|
-
|
787
|
-
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
788
|
-
)
|
789
|
-
gen_throughput = self.num_generated_tokens / (
|
790
|
-
time.time() - self.last_decode_stats_tic
|
791
|
-
)
|
792
|
-
self.num_generated_tokens = 0
|
821
|
+
gap_latency = time.time() - self.last_decode_stats_tic
|
793
822
|
self.last_decode_stats_tic = time.time()
|
823
|
+
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
824
|
+
self.num_generated_tokens = 0
|
794
825
|
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
826
|
+
num_used = self.max_total_num_tokens - (
|
827
|
+
self.token_to_kv_pool_allocator.available_size()
|
828
|
+
+ self.tree_cache.evictable_size()
|
829
|
+
)
|
830
|
+
|
831
|
+
if RECORD_STEP_TIME:
|
832
|
+
self.step_time_dict[num_running_reqs].append(
|
833
|
+
gap_latency / self.server_args.decode_log_interval
|
834
|
+
)
|
795
835
|
|
796
836
|
if self.spec_algorithm.is_none():
|
797
837
|
msg = (
|
@@ -799,14 +839,17 @@ class Scheduler:
|
|
799
839
|
f"#running-req: {num_running_reqs}, "
|
800
840
|
f"#token: {num_used}, "
|
801
841
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
802
|
-
f"gen throughput (token/s): {
|
803
|
-
f"
|
842
|
+
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
|
843
|
+
f"largest-len: {self._largest_prefill_decode_len}, "
|
844
|
+
f"#queue-req: {len(self.waiting_queue)}, "
|
804
845
|
)
|
805
846
|
spec_accept_length = 0
|
806
847
|
else:
|
807
848
|
spec_accept_length = (
|
808
849
|
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
|
809
850
|
)
|
851
|
+
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
|
852
|
+
self.cum_spec_accept_count += self.spec_num_total_forward_ct
|
810
853
|
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
|
811
854
|
msg = (
|
812
855
|
f"Decode batch. "
|
@@ -814,8 +857,9 @@ class Scheduler:
|
|
814
857
|
f"#token: {num_used}, "
|
815
858
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
816
859
|
f"accept len: {spec_accept_length:.2f}, "
|
817
|
-
f"gen throughput (token/s): {
|
818
|
-
f"
|
860
|
+
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
|
861
|
+
f"largest-len: {self._largest_prefill_decode_len}, "
|
862
|
+
f"#queue-req: {len(self.waiting_queue)}, "
|
819
863
|
)
|
820
864
|
|
821
865
|
logger.info(msg)
|
@@ -823,14 +867,16 @@ class Scheduler:
|
|
823
867
|
self.stats.num_running_reqs = num_running_reqs
|
824
868
|
self.stats.num_used_tokens = num_used
|
825
869
|
self.stats.token_usage = num_used / self.max_total_num_tokens
|
826
|
-
self.stats.
|
870
|
+
self.stats.cache_hit_rate = 0.0
|
871
|
+
self.stats.gen_throughput = self.last_gen_throughput
|
827
872
|
self.stats.num_queue_reqs = len(self.waiting_queue)
|
828
873
|
self.stats.spec_accept_length = spec_accept_length
|
829
874
|
self.metrics_collector.log_stats(self.stats)
|
830
875
|
|
831
876
|
def check_memory(self):
|
832
877
|
available_size = (
|
833
|
-
self.
|
878
|
+
self.token_to_kv_pool_allocator.available_size()
|
879
|
+
+ self.tree_cache.evictable_size()
|
834
880
|
)
|
835
881
|
protected_size = self.tree_cache.protected_size()
|
836
882
|
memory_leak = available_size != (
|
@@ -857,21 +903,42 @@ class Scheduler:
|
|
857
903
|
if crash_on_warnings():
|
858
904
|
raise ValueError(msg)
|
859
905
|
|
906
|
+
if (
|
907
|
+
self.enable_metrics
|
908
|
+
and self.attn_tp_rank == 0
|
909
|
+
and time.time() > self.metrics_collector.last_log_time + 30
|
910
|
+
):
|
911
|
+
# During idle time, also collect metrics every 30 seconds.
|
912
|
+
num_used = self.max_total_num_tokens - (
|
913
|
+
self.token_to_kv_pool_allocator.available_size()
|
914
|
+
+ self.tree_cache.evictable_size()
|
915
|
+
)
|
916
|
+
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
917
|
+
self.stats.num_running_reqs = num_running_reqs
|
918
|
+
self.stats.num_used_tokens = num_used
|
919
|
+
self.stats.token_usage = num_used / self.max_total_num_tokens
|
920
|
+
self.stats.gen_throughput = 0
|
921
|
+
self.stats.num_queue_reqs = len(self.waiting_queue)
|
922
|
+
self.metrics_collector.log_stats(self.stats)
|
923
|
+
|
860
924
|
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
861
925
|
# Merge the prefill batch into the running batch
|
862
926
|
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
863
|
-
if self.
|
864
|
-
# Move the chunked request out of the batch
|
865
|
-
|
866
|
-
self.
|
867
|
-
|
868
|
-
|
927
|
+
if self.chunked_req:
|
928
|
+
# Move the chunked request out of the batch so that we can merge
|
929
|
+
# only finished requests to running_batch.
|
930
|
+
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
|
931
|
+
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
932
|
+
# chunked request keeps its rid but will get a new req_pool_idx
|
933
|
+
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
869
934
|
self.batch_is_full = False
|
870
935
|
|
936
|
+
self.last_batch.filter_batch()
|
871
937
|
if not self.last_batch.is_empty():
|
872
938
|
if self.running_batch is None:
|
873
939
|
self.running_batch = self.last_batch
|
874
940
|
else:
|
941
|
+
# merge running_batch with prefill batch
|
875
942
|
self.running_batch.merge_batch(self.last_batch)
|
876
943
|
|
877
944
|
new_batch = self.get_new_batch_prefill()
|
@@ -900,7 +967,7 @@ class Scheduler:
|
|
900
967
|
# Handle the cases where prefill is not allowed
|
901
968
|
if (
|
902
969
|
self.batch_is_full or len(self.waiting_queue) == 0
|
903
|
-
) and self.
|
970
|
+
) and self.chunked_req is None:
|
904
971
|
return None
|
905
972
|
|
906
973
|
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
@@ -914,7 +981,7 @@ class Scheduler:
|
|
914
981
|
# Prefill policy
|
915
982
|
adder = PrefillAdder(
|
916
983
|
self.tree_cache,
|
917
|
-
self.
|
984
|
+
self.token_to_kv_pool_allocator,
|
918
985
|
self.running_batch,
|
919
986
|
self.new_token_ratio,
|
920
987
|
self.max_prefill_tokens,
|
@@ -922,10 +989,10 @@ class Scheduler:
|
|
922
989
|
running_bs if self.is_mixed_chunk else 0,
|
923
990
|
)
|
924
991
|
|
925
|
-
|
926
|
-
if
|
927
|
-
self.
|
928
|
-
self.
|
992
|
+
is_chunked = self.chunked_req is not None
|
993
|
+
if is_chunked:
|
994
|
+
self.chunked_req.init_next_round_input()
|
995
|
+
self.chunked_req = adder.add_chunked_req(self.chunked_req)
|
929
996
|
|
930
997
|
if self.lora_paths:
|
931
998
|
lora_set = (
|
@@ -933,7 +1000,6 @@ class Scheduler:
|
|
933
1000
|
if self.running_batch is not None
|
934
1001
|
else set([])
|
935
1002
|
)
|
936
|
-
|
937
1003
|
# Get requests from the waiting queue to a new prefill batch
|
938
1004
|
for req in self.waiting_queue:
|
939
1005
|
if (
|
@@ -953,7 +1019,31 @@ class Scheduler:
|
|
953
1019
|
break
|
954
1020
|
|
955
1021
|
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
956
|
-
|
1022
|
+
|
1023
|
+
if self.enable_hierarchical_cache and req.last_node is not None:
|
1024
|
+
if req.last_node.evicted:
|
1025
|
+
# loading KV cache for the request
|
1026
|
+
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
|
1027
|
+
req.last_node,
|
1028
|
+
req.prefix_indices,
|
1029
|
+
adder.rem_total_tokens,
|
1030
|
+
)
|
1031
|
+
if req.last_node.loading:
|
1032
|
+
# to prevent frequent cache invalidation
|
1033
|
+
if req.rid in self.staging_reqs:
|
1034
|
+
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
|
1035
|
+
self.tree_cache.inc_lock_ref(req.last_node)
|
1036
|
+
self.staging_reqs[req.rid] = req.last_node
|
1037
|
+
continue
|
1038
|
+
elif req.last_node.loading:
|
1039
|
+
if not self.tree_cache.loading_complete(req.last_node):
|
1040
|
+
continue
|
1041
|
+
|
1042
|
+
if req.rid in self.staging_reqs:
|
1043
|
+
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
|
1044
|
+
del self.staging_reqs[req.rid]
|
1045
|
+
|
1046
|
+
res = adder.add_one_req(req, self.chunked_req)
|
957
1047
|
if res != AddReqResult.CONTINUE:
|
958
1048
|
if res == AddReqResult.NO_TOKEN:
|
959
1049
|
if self.enable_hierarchical_cache:
|
@@ -965,39 +1055,36 @@ class Scheduler:
|
|
965
1055
|
else:
|
966
1056
|
self.batch_is_full = True
|
967
1057
|
break
|
968
|
-
if self.server_args.prefill_only_one_req:
|
969
|
-
break
|
970
1058
|
|
971
1059
|
# Update waiting queue
|
972
|
-
can_run_list = adder.can_run_list
|
1060
|
+
can_run_list: List[Req] = adder.can_run_list
|
973
1061
|
if len(can_run_list) == 0:
|
974
1062
|
return None
|
975
1063
|
self.waiting_queue = [
|
976
1064
|
x for x in self.waiting_queue if x not in set(can_run_list)
|
977
1065
|
]
|
978
1066
|
|
979
|
-
if adder.
|
980
|
-
assert self.
|
981
|
-
self.
|
1067
|
+
if adder.new_chunked_req is not None:
|
1068
|
+
assert self.chunked_req is None
|
1069
|
+
self.chunked_req = adder.new_chunked_req
|
982
1070
|
|
983
|
-
if self.
|
984
|
-
self.
|
1071
|
+
if self.chunked_req:
|
1072
|
+
self.chunked_req.is_chunked += 1
|
985
1073
|
|
986
1074
|
# Print stats
|
987
1075
|
if self.attn_tp_rank == 0:
|
988
|
-
self.log_prefill_stats(adder, can_run_list, running_bs
|
1076
|
+
self.log_prefill_stats(adder, can_run_list, running_bs)
|
989
1077
|
|
990
1078
|
# Create a new batch
|
991
1079
|
new_batch = ScheduleBatch.init_new(
|
992
1080
|
can_run_list,
|
993
1081
|
self.req_to_token_pool,
|
994
|
-
self.
|
1082
|
+
self.token_to_kv_pool_allocator,
|
995
1083
|
self.tree_cache,
|
996
1084
|
self.model_config,
|
997
1085
|
self.enable_overlap,
|
998
1086
|
self.spec_algorithm,
|
999
1087
|
self.server_args.enable_custom_logit_processor,
|
1000
|
-
self.server_args.return_hidden_states,
|
1001
1088
|
)
|
1002
1089
|
new_batch.prepare_for_extend()
|
1003
1090
|
|
@@ -1021,8 +1108,6 @@ class Scheduler:
|
|
1021
1108
|
|
1022
1109
|
def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
|
1023
1110
|
"""Update the current running decoding batch."""
|
1024
|
-
global test_retract
|
1025
|
-
|
1026
1111
|
initial_bs = batch.batch_size()
|
1027
1112
|
|
1028
1113
|
batch.filter_batch()
|
@@ -1032,35 +1117,25 @@ class Scheduler:
|
|
1032
1117
|
|
1033
1118
|
# Check if decode out of memory
|
1034
1119
|
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
|
1035
|
-
|
1120
|
+
TEST_RETRACT and batch.batch_size() > 10
|
1036
1121
|
):
|
1037
1122
|
old_ratio = self.new_token_ratio
|
1038
1123
|
|
1039
|
-
retracted_reqs, new_token_ratio = batch.retract_decode()
|
1124
|
+
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
|
1040
1125
|
self.new_token_ratio = new_token_ratio
|
1041
|
-
if self.draft_worker:
|
1042
|
-
self.draft_worker.finish_request(retracted_reqs)
|
1043
1126
|
|
1044
1127
|
logger.info(
|
1045
1128
|
"Decode out of memory happened. "
|
1046
1129
|
f"#retracted_reqs: {len(retracted_reqs)}, "
|
1047
1130
|
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
|
1048
1131
|
)
|
1049
|
-
self.
|
1132
|
+
self._extend_requests_to_queue(retracted_reqs)
|
1050
1133
|
else:
|
1051
1134
|
self.new_token_ratio = max(
|
1052
1135
|
self.new_token_ratio - self.new_token_ratio_decay,
|
1053
1136
|
self.min_new_token_ratio,
|
1054
1137
|
)
|
1055
1138
|
|
1056
|
-
# Check for jump-forward
|
1057
|
-
if not self.disable_jump_forward and batch.has_grammar:
|
1058
|
-
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
1059
|
-
self.waiting_queue.extend(jump_forward_reqs)
|
1060
|
-
if batch.is_empty():
|
1061
|
-
self.batch_is_full = False
|
1062
|
-
return None
|
1063
|
-
|
1064
1139
|
if batch.batch_size() < initial_bs:
|
1065
1140
|
self.batch_is_full = False
|
1066
1141
|
|
@@ -1074,17 +1149,26 @@ class Scheduler:
|
|
1074
1149
|
"""Run a batch."""
|
1075
1150
|
self.forward_ct += 1
|
1076
1151
|
|
1152
|
+
# Check profiler
|
1153
|
+
if (
|
1154
|
+
self.profiler_target_forward_ct
|
1155
|
+
and self.profiler_target_forward_ct <= self.forward_ct
|
1156
|
+
):
|
1157
|
+
self.stop_profile()
|
1158
|
+
|
1159
|
+
# Run forward
|
1077
1160
|
if self.is_generation:
|
1078
1161
|
if self.spec_algorithm.is_none():
|
1079
1162
|
model_worker_batch = batch.get_model_worker_batch()
|
1080
1163
|
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
1081
1164
|
model_worker_batch
|
1082
1165
|
)
|
1166
|
+
bid = model_worker_batch.bid
|
1083
1167
|
else:
|
1084
1168
|
(
|
1085
1169
|
logits_output,
|
1086
1170
|
next_token_ids,
|
1087
|
-
|
1171
|
+
bid,
|
1088
1172
|
num_accepted_tokens,
|
1089
1173
|
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
1090
1174
|
self.spec_num_total_accepted_tokens += (
|
@@ -1094,10 +1178,24 @@ class Scheduler:
|
|
1094
1178
|
self.num_generated_tokens += num_accepted_tokens
|
1095
1179
|
batch.output_ids = next_token_ids
|
1096
1180
|
|
1181
|
+
# These 2 values are needed for processing the output, but the values can be
|
1182
|
+
# modified by overlap schedule. So we have to copy them here so that
|
1183
|
+
# we can use the correct values in output processing.
|
1184
|
+
if batch.return_logprob:
|
1185
|
+
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
|
1186
|
+
extend_logprob_start_len_per_req = [
|
1187
|
+
req.extend_logprob_start_len for req in batch.reqs
|
1188
|
+
]
|
1189
|
+
else:
|
1190
|
+
extend_input_len_per_req = None
|
1191
|
+
extend_logprob_start_len_per_req = None
|
1192
|
+
|
1097
1193
|
ret = GenerationBatchResult(
|
1098
1194
|
logits_output=logits_output,
|
1099
1195
|
next_token_ids=next_token_ids,
|
1100
|
-
|
1196
|
+
extend_input_len_per_req=extend_input_len_per_req,
|
1197
|
+
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
1198
|
+
bid=bid,
|
1101
1199
|
)
|
1102
1200
|
else: # embedding or reward model
|
1103
1201
|
model_worker_batch = batch.get_model_worker_batch()
|
@@ -1121,11 +1219,22 @@ class Scheduler:
|
|
1121
1219
|
elif batch.forward_mode.is_idle():
|
1122
1220
|
if self.enable_overlap:
|
1123
1221
|
self.tp_worker.resolve_batch_result(result.bid)
|
1222
|
+
if batch.next_batch_sampling_info:
|
1223
|
+
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
1224
|
+
self.current_stream.synchronize()
|
1225
|
+
batch.next_batch_sampling_info.sampling_info_done.set()
|
1124
1226
|
elif batch.forward_mode.is_dummy_first():
|
1125
1227
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
1126
1228
|
self.current_stream.synchronize()
|
1127
1229
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
1128
1230
|
|
1231
|
+
if self.return_health_check_ct:
|
1232
|
+
# Return some signal for the health check.
|
1233
|
+
# This is used to prevent the health check signal being blocked by long context prefill.
|
1234
|
+
# However, one minor issue is that this code path does not check the status of detokenizer manager.
|
1235
|
+
self.return_health_check_ct -= 1
|
1236
|
+
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
|
1237
|
+
|
1129
1238
|
def process_batch_result_prefill(
|
1130
1239
|
self,
|
1131
1240
|
batch: ScheduleBatch,
|
@@ -1137,10 +1246,14 @@ class Scheduler:
|
|
1137
1246
|
(
|
1138
1247
|
logits_output,
|
1139
1248
|
next_token_ids,
|
1249
|
+
extend_input_len_per_req,
|
1250
|
+
extend_logprob_start_len_per_req,
|
1140
1251
|
bid,
|
1141
1252
|
) = (
|
1142
1253
|
result.logits_output,
|
1143
1254
|
result.next_token_ids,
|
1255
|
+
result.extend_input_len_per_req,
|
1256
|
+
result.extend_logprob_start_len_per_req,
|
1144
1257
|
result.bid,
|
1145
1258
|
)
|
1146
1259
|
|
@@ -1150,12 +1263,14 @@ class Scheduler:
|
|
1150
1263
|
# Move next_token_ids and logprobs to cpu
|
1151
1264
|
next_token_ids = next_token_ids.tolist()
|
1152
1265
|
if batch.return_logprob:
|
1153
|
-
logits_output.next_token_logprobs
|
1154
|
-
logits_output.next_token_logprobs
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1266
|
+
if logits_output.next_token_logprobs is not None:
|
1267
|
+
logits_output.next_token_logprobs = (
|
1268
|
+
logits_output.next_token_logprobs.tolist()
|
1269
|
+
)
|
1270
|
+
if logits_output.input_token_logprobs is not None:
|
1271
|
+
logits_output.input_token_logprobs = tuple(
|
1272
|
+
logits_output.input_token_logprobs.tolist()
|
1273
|
+
)
|
1159
1274
|
|
1160
1275
|
hidden_state_offset = 0
|
1161
1276
|
|
@@ -1168,25 +1283,38 @@ class Scheduler:
|
|
1168
1283
|
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
1169
1284
|
# Free the one delayed token for the mixed decode batch
|
1170
1285
|
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
1171
|
-
self.
|
1286
|
+
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
|
1172
1287
|
continue
|
1173
1288
|
|
1174
|
-
if req.
|
1289
|
+
if req.is_chunked <= 0:
|
1290
|
+
# req output_ids are set here
|
1175
1291
|
req.output_ids.append(next_token_id)
|
1176
1292
|
req.check_finished()
|
1177
1293
|
|
1178
1294
|
if req.finished():
|
1179
1295
|
self.tree_cache.cache_finished_req(req)
|
1180
1296
|
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
1297
|
+
# This updates radix so others can match
|
1181
1298
|
self.tree_cache.cache_unfinished_req(req)
|
1182
1299
|
|
1183
1300
|
if req.return_logprob:
|
1184
|
-
|
1185
|
-
|
1301
|
+
assert extend_logprob_start_len_per_req is not None
|
1302
|
+
assert extend_input_len_per_req is not None
|
1303
|
+
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
1304
|
+
extend_input_len = extend_input_len_per_req[i]
|
1305
|
+
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
1306
|
+
self.add_logprob_return_values(
|
1307
|
+
i,
|
1308
|
+
req,
|
1309
|
+
logprob_pt,
|
1310
|
+
next_token_ids,
|
1311
|
+
num_input_logprobs,
|
1312
|
+
logits_output,
|
1186
1313
|
)
|
1314
|
+
logprob_pt += num_input_logprobs
|
1187
1315
|
|
1188
1316
|
if (
|
1189
|
-
|
1317
|
+
req.return_hidden_states
|
1190
1318
|
and logits_output.hidden_states is not None
|
1191
1319
|
):
|
1192
1320
|
req.hidden_states.append(
|
@@ -1205,12 +1333,31 @@ class Scheduler:
                         req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
-                    req.
+                    req.is_chunked -= 1
                     # There is only at most one request being currently chunked.
                     # Because this request does not finish prefill,
                     # we don't want to stream the request currently being chunked.
                     skip_stream_req = req

+                    # Incrementally update input logprobs.
+                    if req.return_logprob:
+                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                        extend_input_len = extend_input_len_per_req[i]
+                        if extend_logprob_start_len < extend_input_len:
+                            # Update input logprobs.
+                            num_input_logprobs = (
+                                extend_input_len - extend_logprob_start_len
+                            )
+                            self.add_input_logprob_return_values(
+                                i,
+                                req,
+                                logits_output,
+                                logprob_pt,
+                                num_input_logprobs,
+                                last_prefill_chunk=False,
+                            )
+                            logprob_pt += num_input_logprobs
+
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
                 self.current_stream.synchronize()
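The two hunks above thread a running offset, `logprob_pt`, through the per-request loop so that input logprobs computed for a chunked prefill are sliced out of one flat per-batch buffer, with the final chunk handled by `add_logprob_return_values(..., last_prefill_chunk=True)`. A minimal sketch of just that pointer arithmetic, using made-up lengths and logprob values rather than the scheduler's real state:

# Hypothetical illustration of the logprob_pt bookkeeping in the hunks above.
# The buffer holds input logprobs for every request in the batch, concatenated.
flat_logprobs = (-0.5, -1.2, -0.3, -2.0, -0.8, -0.1, -0.9)
extend_input_len_per_req = [4, 6]          # prompt tokens processed this forward pass
extend_logprob_start_len_per_req = [1, 2]  # positions that do not need logprobs

logprob_pt = 0
for i in range(len(extend_input_len_per_req)):
    num_input_logprobs = (
        extend_input_len_per_req[i] - extend_logprob_start_len_per_req[i]
    )
    chunk = flat_logprobs[logprob_pt : logprob_pt + num_input_logprobs]
    print(f"req {i}: {chunk}")
    logprob_pt += num_input_logprobs  # same advance as in the diff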
@@ -1226,7 +1373,7 @@ class Scheduler:
                     continue

                 req.embedding = embeddings[i]
-                if req.
+                if req.is_chunked <= 0:
                     # Dummy output token for embedding models
                     req.output_ids.append(0)
                     req.check_finished()
@@ -1237,7 +1384,7 @@ class Scheduler:
                         self.tree_cache.cache_unfinished_req(req)
                 else:
                     # being chunked reqs' prefill is not finished
-                    req.
+                    req.is_chunked -= 1

         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

@@ -1254,23 +1401,27 @@ class Scheduler:
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
+            assert batch.spec_algorithm.is_none()
             logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
-
+        elif batch.spec_algorithm.is_none():
+            # spec decoding handles output logprobs inside verify process.
             next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 next_token_logprobs = logits_output.next_token_logprobs.tolist()

-        self.
+        self.token_to_kv_pool_allocator.free_group_begin()

         # Check finish condition
+        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
+        # We should ignore using next_token_ids for spec decoding cases.
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue

             if self.enable_overlap and req.finished():
                 # Free the one delayed token
-                self.
+                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
                 continue

             if batch.spec_algorithm.is_none():
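Several call sites above switch from the old pool attribute to `self.token_to_kv_pool_allocator` and bracket the per-request frees of the decode loop with `free_group_begin()` / `free_group_end()`. A sketch of what such a grouped-free interface can look like; this is a hypothetical toy allocator written for illustration, not sglang's implementation:

from typing import List, Optional

class GroupedFreeAllocator:
    """Toy allocator whose free() calls can be batched between group markers."""

    def __init__(self) -> None:
        self.free_slots: List[int] = []
        self._pending_group: Optional[List[int]] = None

    def free_group_begin(self) -> None:
        self._pending_group = []

    def free(self, indices: List[int]) -> None:
        if self._pending_group is not None:
            self._pending_group.extend(indices)  # defer until the group ends
        else:
            self.free_slots.extend(indices)

    def free_group_end(self) -> None:
        if self._pending_group is not None:
            self.free_slots.extend(self._pending_group)  # one bulk release
            self._pending_group = None

allocator = GroupedFreeAllocator()
allocator.free_group_begin()
allocator.free([3])
allocator.free([7])
allocator.free_group_end()
print(allocator.free_slots)  # [3, 7]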
@@ -1278,11 +1429,11 @@ class Scheduler:
                 req.output_ids.append(next_token_id)

             req.check_finished()
-
             if req.finished():
                 self.tree_cache.cache_finished_req(req)

-            if req.return_logprob:
+            if req.return_logprob and batch.spec_algorithm.is_none():
+                # speculative worker handles logprob in speculative decoding
                 req.output_token_logprobs_val.append(next_token_logprobs[i])
                 req.output_token_logprobs_idx.append(next_token_id)
                 if req.top_logprobs_num > 0:
@@ -1292,14 +1443,18 @@ class Scheduler:
                     req.output_top_logprobs_idx.append(
                         logits_output.next_token_top_logprobs_idx[i]
                     )
+                if req.token_ids_logprob is not None:
+                    req.output_token_ids_logprobs_val.append(
+                        logits_output.next_token_token_ids_logprobs_val[i]
+                    )
+                    req.output_token_ids_logprobs_idx.append(
+                        logits_output.next_token_token_ids_logprobs_idx[i]
+                    )

-            if
-                self.server_args.return_hidden_states
-                and logits_output.hidden_states is not None
-            ):
+            if req.return_hidden_states and logits_output.hidden_states is not None:
                 req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())

-            if req.grammar is not None:
+            if req.grammar is not None and batch.spec_algorithm.is_none():
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()

@@ -1310,7 +1465,7 @@ class Scheduler:

         self.stream_output(batch.reqs, batch.return_logprob)

-        self.
+        self.token_to_kv_pool_allocator.free_group_end()

         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (
@@ -1319,86 +1474,169 @@ class Scheduler:
         ):
             self.log_decode_stats()

-    def
+    def add_input_logprob_return_values(
         self,
         i: int,
         req: Req,
-        pt: int,
-        next_token_ids: List[int],
         output: LogitsProcessorOutput,
+        logprob_pt: int,
+        num_input_logprobs: int,
+        last_prefill_chunk: bool,  # If True, it means prefill is finished.
     ):
-        """
-
-
+        """Incrementally add input logprobs to `req`.
+
+        Args:
+            i: The request index in a batch.
+            req: The request. Input logprobs inside req are modified as a
+                consequence of the API
+            fill_ids: The prefill ids processed.
+            output: Logit processor output that's used to compute input logprobs
+            last_prefill_chunk: True if it is the last prefill (when chunked).
+                Some of input logprob operation should only happen at the last
+                prefill (e.g., computing input token logprobs).
+        """
+        assert output.input_token_logprobs is not None
+        if req.input_token_logprobs is None:
+            req.input_token_logprobs = []
+        if req.temp_input_top_logprobs_val is None:
+            req.temp_input_top_logprobs_val = []
+        if req.temp_input_top_logprobs_idx is None:
+            req.temp_input_top_logprobs_idx = []
+        if req.temp_input_token_ids_logprobs_val is None:
+            req.temp_input_token_ids_logprobs_val = []
+        if req.temp_input_token_ids_logprobs_idx is None:
+            req.temp_input_token_ids_logprobs_idx = []
+
+        if req.input_token_logprobs_val is not None:
+            # The input logprob has been already computed. It only happens
+            # upon retract.
+            if req.top_logprobs_num > 0:
+                assert req.input_token_logprobs_val is not None
+            return

-        #
-
+        # Important for the performance.
+        assert isinstance(output.input_token_logprobs, tuple)
+        input_token_logprobs: Tuple[int] = output.input_token_logprobs
+        input_token_logprobs = input_token_logprobs[
+            logprob_pt : logprob_pt + num_input_logprobs
+        ]
+        req.input_token_logprobs.extend(input_token_logprobs)

-        if req.
-
-
-        ]
+        if req.top_logprobs_num > 0:
+            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
+            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])

-
-
-
-
-
-
+        if req.token_ids_logprob is not None:
+            req.temp_input_token_ids_logprobs_val.append(
+                output.input_token_ids_logprobs_val[i]
+            )
+            req.temp_input_token_ids_logprobs_idx.append(
+                output.input_token_ids_logprobs_idx[i]
+            )
+
+        if last_prefill_chunk:
+            input_token_logprobs = req.input_token_logprobs
+            req.input_token_logprobs = None
+            assert req.input_token_logprobs_val is None
+            assert req.input_token_logprobs_idx is None
+            assert req.input_top_logprobs_val is None
+            assert req.input_top_logprobs_idx is None
+
+            # Compute input_token_logprobs_val
+            # Always pad the first one with None.
+            req.input_token_logprobs_val = [None]
+            req.input_token_logprobs_val.extend(input_token_logprobs)
+            # The last input logprob is for sampling, so just pop it out.
+            req.input_token_logprobs_val.pop()
+
+            # Compute input_token_logprobs_idx
+            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
             # Clip the padded hash values from image tokens.
             # Otherwise, it will lead to detokenization errors.
             input_token_logprobs_idx = [
                 x if x < self.model_config.vocab_size - 1 else 0
                 for x in input_token_logprobs_idx
             ]
+            req.input_token_logprobs_idx = input_token_logprobs_idx

-        if
-            req.
-
-
-
+            if req.top_logprobs_num > 0:
+                req.input_top_logprobs_val = [None]
+                req.input_top_logprobs_idx = [None]
+                assert len(req.temp_input_token_ids_logprobs_val) == len(
+                    req.temp_input_token_ids_logprobs_idx
+                )
+                for val, idx in zip(
+                    req.temp_input_top_logprobs_val,
+                    req.temp_input_top_logprobs_idx,
+                    strict=True,
+                ):
+                    req.input_top_logprobs_val.extend(val)
+                    req.input_top_logprobs_idx.extend(idx)
+
+                # Last token is a sample token.
+                req.input_top_logprobs_val.pop()
+                req.input_top_logprobs_idx.pop()
+                req.temp_input_top_logprobs_idx = None
+                req.temp_input_top_logprobs_val = None
+
+            if req.token_ids_logprob is not None:
+                req.input_token_ids_logprobs_val = [None]
+                req.input_token_ids_logprobs_idx = [None]
+
+                for val, idx in zip(
+                    req.temp_input_token_ids_logprobs_val,
+                    req.temp_input_token_ids_logprobs_idx,
+                    strict=True,
+                ):
+                    req.input_token_ids_logprobs_val.extend(val)
+                    req.input_token_ids_logprobs_idx.extend(idx)

-
-
+                # Last token is a sample token.
+                req.input_token_ids_logprobs_val.pop()
+                req.input_token_ids_logprobs_idx.pop()
+                req.temp_input_token_ids_logprobs_idx = None
+                req.temp_input_token_ids_logprobs_val = None

-
-
-
-
-
-
-
-
-
-
-            ],
-        )
-        req.output_token_logprobs_idx.extend(
-            req.fill_ids[
-                len(req.fill_ids)
-                - req.last_update_decode_tokens : len(req.fill_ids)
-            ]
-        )
+            if req.return_logprob:
+                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
+                assert len(req.input_token_logprobs_val) == relevant_tokens_len
+                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
+                if req.top_logprobs_num > 0:
+                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
+                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
+                if req.token_ids_logprob is not None:
+                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
+                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def add_logprob_return_values(
+        self,
+        i: int,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        num_input_logprobs: int,
+        output: LogitsProcessorOutput,
+    ):
+        """Attach logprobs to the return values."""
+        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+        req.output_token_logprobs_idx.append(next_token_ids[i])
+
+        self.add_input_logprob_return_values(
+            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+        )

+        if req.top_logprobs_num > 0:
             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

+        if req.token_ids_logprob is not None:
+            req.output_token_ids_logprobs_val.append(
+                output.next_token_token_ids_logprobs_val[i]
+            )
+            req.output_token_ids_logprobs_idx.append(
+                output.next_token_token_ids_logprobs_idx[i]
+            )
+
         return num_input_logprobs

     def stream_output(
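The new `add_input_logprob_return_values` above accumulates per-chunk logprobs into temporary lists and, on the last prefill chunk, pads the result with a leading `None` (the first prompt token has no logprob) and pops the trailing entry (the last logprob belongs to the token about to be sampled). A tiny self-contained sketch of that alignment, with made-up numbers and `logprob_start_len` assumed to be 0:

# Align input logprobs with prompt tokens, mirroring the pad/pop in the diff.
prompt_ids = [101, 5, 9, 42]
scored_logprobs = [-1.3, -0.7, -2.1, -0.4]  # last one scores the sampled token

input_token_logprobs_val = [None]           # first token has no logprob
input_token_logprobs_val.extend(scored_logprobs)
input_token_logprobs_val.pop()              # drop the sampling logprob

assert len(input_token_logprobs_val) == len(prompt_ids)
print(list(zip(prompt_ids, input_token_logprobs_val)))
# [(101, None), (5, -1.3), (9, -0.7), (42, -2.1)]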
@@ -1409,7 +1647,6 @@ class Scheduler:
         finished_reasons: List[BaseFinishReason] = []

         if self.is_generation:
-            vids = []
             decoded_texts = []
             decode_ids_list = []
             read_offsets = []
@@ -1422,7 +1659,7 @@ class Scheduler:
             completion_tokens = []
             cached_tokens = []
             spec_verify_ct = []
-
+            output_hidden_states = None

             if return_logprob:
                 input_token_logprobs_val = []
@@ -1433,33 +1670,46 @@ class Scheduler:
                 input_top_logprobs_idx = []
                 output_top_logprobs_val = []
                 output_top_logprobs_idx = []
+                input_token_ids_logprobs_val = []
+                input_token_ids_logprobs_idx = []
+                output_token_ids_logprobs_val = []
+                output_token_ids_logprobs_idx = []
             else:
                 input_token_logprobs_val = input_token_logprobs_idx = (
                     output_token_logprobs_val
                 ) = output_token_logprobs_idx = input_top_logprobs_val = (
                     input_top_logprobs_idx
-                ) = output_top_logprobs_val = output_top_logprobs_idx =
+                ) = output_top_logprobs_val = output_top_logprobs_idx = (
+                    input_token_ids_logprobs_val
+                ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
+                    output_token_ids_logprobs_idx
+                ) = None

             for req in reqs:
                 if req is skip_req:
                     continue

-                #
+                # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
+                if self.model_config.is_multimodal_gen and req.to_abort:
+                    continue
+
                 if (
                     req.finished()
                     # If stream, follow the given stream_interval
                     or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                     # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-
+                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
+                    # always increase one-by-one.
+                    or (
+                        not req.stream
+                        and len(req.output_ids) % 50 == 0
+                        and not self.model_config.is_multimodal_gen
+                    )
                 ):
-                    if self.draft_worker and req.finished():
-                        self.draft_worker.finish_request(req)
-
                     rids.append(req.rid)
                     finished_reasons.append(
                         req.finished_reason.to_json() if req.finished_reason else None
                     )
-                    vids.append(req.vid)
                     decoded_texts.append(req.decoded_text)
                     decode_ids, read_offset = req.init_incremental_detokenize()
                     decode_ids_list.append(decode_ids)
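The hunk above tightens the flush condition in `stream_output`: finished requests always flush, streaming requests flush every `stream_interval` tokens, and non-streaming requests now flush on a fixed 50-token cadence unless the model is a multimodal generator. A standalone restatement of that predicate, with the scheduler fields passed in as plain arguments for illustration:

def should_flush(finished: bool, stream: bool, num_output_ids: int,
                 stream_interval: int, is_multimodal_gen: bool) -> bool:
    # Mirrors the condition in the diff: always flush finished requests,
    # flush streaming requests every `stream_interval` tokens, and flush
    # non-streaming requests every 50 tokens unless generating multimodal output.
    return (
        finished
        or (stream and num_output_ids % stream_interval == 0)
        or (not stream and num_output_ids % 50 == 0 and not is_multimodal_gen)
    )

print(should_flush(False, True, 8, 8, False))    # True: stream interval hit
print(should_flush(False, False, 50, 8, False))  # True: 50-token cadence
print(should_flush(False, False, 50, 8, True))   # False: multimodal generation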
@@ -1488,16 +1738,32 @@ class Scheduler:
                         input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                         output_top_logprobs_val.append(req.output_top_logprobs_val)
                         output_top_logprobs_idx.append(req.output_top_logprobs_idx)
+                        input_token_ids_logprobs_val.append(
+                            req.input_token_ids_logprobs_val
+                        )
+                        input_token_ids_logprobs_idx.append(
+                            req.input_token_ids_logprobs_idx
+                        )
+                        output_token_ids_logprobs_val.append(
+                            req.output_token_ids_logprobs_val
+                        )
+                        output_token_ids_logprobs_idx.append(
+                            req.output_token_ids_logprobs_idx
+                        )

-
+                    if req.return_hidden_states:
+                        if output_hidden_states is None:
+                            output_hidden_states = []
+                        output_hidden_states.append(req.hidden_states)

             # Send to detokenizer
             if rids:
+                if self.model_config.is_multimodal_gen:
+                    raise NotImplementedError()
                 self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
                         rids,
                         finished_reasons,
-                        vids,
                         decoded_texts,
                         decode_ids_list,
                         read_offsets,
@@ -1517,20 +1783,28 @@ class Scheduler:
                         input_top_logprobs_idx,
                         output_top_logprobs_val,
                         output_top_logprobs_idx,
-
+                        input_token_ids_logprobs_val,
+                        input_token_ids_logprobs_idx,
+                        output_token_ids_logprobs_val,
+                        output_token_ids_logprobs_idx,
+                        output_hidden_states,
                     )
                 )
         else:  # embedding or reward model
             embeddings = []
             prompt_tokens = []
+            cached_tokens = []
             for req in reqs:
                 if req.finished():
                     rids.append(req.rid)
                     finished_reasons.append(req.finished_reason.to_json())
                     embeddings.append(req.embedding)
                     prompt_tokens.append(len(req.origin_input_ids))
+                    cached_tokens.append(req.cached_tokens)
             self.send_to_detokenizer.send_pyobj(
-                BatchEmbeddingOut(
+                BatchEmbeddingOut(
+                    rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
+                )
             )

     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
@@ -1575,13 +1849,12 @@ class Scheduler:
         idle_batch = ScheduleBatch.init_new(
             [],
             self.req_to_token_pool,
-            self.
+            self.token_to_kv_pool_allocator,
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
-            self.server_args.return_hidden_states,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -1596,20 +1869,58 @@ class Scheduler:
             except futures._base.TimeoutError:
                 break

-        if self.
+        if self.server_args.enable_dp_attention:
+            tp_size = self.attn_tp_size
+            tp_group = self.attn_tp_cpu_group
+        else:
+            tp_size = self.tp_size
+            tp_group = self.tp_cpu_group
+
+        if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
             tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
             torch.distributed.all_reduce(
-                tensor, op=torch.distributed.ReduceOp.MAX, group=
+                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
             num_ready_reqs_max = tensor.item()
             for i in range(num_ready_reqs, num_ready_reqs_max):
                 self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
             num_ready_reqs = num_ready_reqs_max

-        self.
+        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
+        self.watchdog_last_forward_ct = 0
+        self.watchdog_last_time = time.time()
+
+        while True:
+            current = time.time()
+            if self.cur_batch is not None:
+                if self.watchdog_last_forward_ct == self.forward_ct:
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
+                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+                        break
+                else:
+                    self.watchdog_last_forward_ct = self.forward_ct
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)
+
+        # Print batch size and memory pool info to check whether there are de-sync issues.
+        logger.error(
+            f"{self.cur_batch.batch_size()=}, "
+            f"{self.cur_batch.reqs=}, "
+            f"{self.token_to_kv_pool_allocator.available_size()=}, "
+            f"{self.tree_cache.evictable_size()=}, "
+        )
+        # Wait for some time so that the parent process can print the error.
+        pyspy_dump_schedulers()
+        print(file=sys.stderr, flush=True)
+        print(file=sys.stdout, flush=True)
+        time.sleep(5)
+        self.parent_process.send_signal(signal.SIGQUIT)
+
     def flush_cache_wrapped(self, recv_req: FlushCacheReq):
         self.flush_cache()

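The new `watchdog_thread` above compares `forward_ct` against the value seen at the previous check; if the counter has not advanced within `watchdog_timeout` seconds while a batch is in flight, it logs, dumps state, and signals the parent process. A trimmed-down, self-contained sketch of the same liveness check (it prints instead of sending SIGQUIT, and the workload is simulated):

import threading
import time

class Watchdog:
    """Hypothetical stand-alone version of the liveness check in the diff."""

    def __init__(self, timeout: float) -> None:
        self.timeout = timeout
        self.forward_ct = 0          # incremented by the "scheduler" loop
        self.last_forward_ct = 0
        self.last_time = time.time()

    def run(self) -> None:
        while True:
            now = time.time()
            if self.last_forward_ct == self.forward_ct:
                # No forward progress since the last check.
                if now > self.last_time + self.timeout:
                    print(f"watchdog timeout after {self.timeout}s without progress")
                    return
            else:
                self.last_forward_ct = self.forward_ct
                self.last_time = now
            time.sleep(self.timeout / 2)

dog = Watchdog(timeout=1.0)
threading.Thread(target=dog.run, daemon=True).start()
for _ in range(3):
    dog.forward_ct += 1  # simulate batches being scheduled
    time.sleep(0.3)
time.sleep(2.5)  # stop making progress; the watchdog fires and returns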
@@ -1618,21 +1929,24 @@ class Scheduler:
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
+            self.cur_batch = None
+            self.last_batch = None
             self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
             if self.grammar_backend:
                 self.grammar_backend.reset()
             self.req_to_token_pool.clear()
-            self.
+            self.token_to_kv_pool_allocator.clear()

             if not self.spec_algorithm.is_none():
                 self.draft_worker.model_runner.req_to_token_pool.clear()
-                self.draft_worker.model_runner.
+                self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

             self.num_generated_tokens = 0
             self.forward_ct_decode = 0
             self.spec_num_total_accepted_tokens = 0
             self.spec_num_total_forward_ct = 0
+            self.cum_spec_accept_length = 0
+            self.cum_spec_accept_count = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
@@ -1645,6 +1959,49 @@ class Scheduler:
             if_success = False
         return if_success

+    def get_internal_state(self, recv_req: GetInternalStateReq):
+        ret = dict(global_server_args_dict)
+        ret["last_gen_throughput"] = self.last_gen_throughput
+        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+            ret["avg_spec_accept_length"] = (
+                self.cum_spec_accept_length / self.cum_spec_accept_count
+            )
+
+        if RECORD_STEP_TIME:
+            ret["step_time_dict"] = self.step_time_dict
+        return GetInternalStateReqOutput(
+            internal_state=ret,
+        )
+
+    def set_internal_state(self, recv_req: SetInternalStateReq):
+        server_args_dict = recv_req.server_args
+        args_allow_update = set(
+            [
+                "speculative_accept_threshold_single",
+                "speculative_accept_threshold_acc",
+            ]
+        )
+        if_success = True
+        for k, v in server_args_dict.items():
+            if k not in args_allow_update:
+                logging.warning(f"Updating {k} is not supported.")
+                if_success = False
+                break
+        if if_success:
+            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+                avg_spec_accept_length = (
+                    self.cum_spec_accept_length / self.cum_spec_accept_count
+                )
+                logger.info(f"{avg_spec_accept_length=}")
+            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
+            for k, v in server_args_dict.items():
+                global_server_args_dict[k] = v
+            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
+        return SetInternalStateReqOutput(
+            updated=True,
+            server_args=global_server_args_dict,
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = None
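`set_internal_state` above only applies keys from a small allow-list and rejects the whole update otherwise. A compact sketch of that validate-then-apply pattern, using the same two allow-listed keys from the diff but a hypothetical state dictionary:

import logging
from typing import Any, Dict

ALLOWED_KEYS = {
    "speculative_accept_threshold_single",
    "speculative_accept_threshold_acc",
}

def apply_update(current: Dict[str, Any], update: Dict[str, Any]) -> bool:
    """Apply `update` to `current` only if every key is on the allow-list."""
    for key in update:
        if key not in ALLOWED_KEYS:
            logging.warning("Updating %s is not supported.", key)
            return False
    current.update(update)
    return True

state = {"speculative_accept_threshold_single": 1.0}
print(apply_update(state, {"speculative_accept_threshold_single": 0.5}))  # True
print(apply_update(state, {"tp_size": 4}))  # False, nothing applied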
@@ -1666,6 +2023,9 @@ class Scheduler:
                     req.to_abort = True
                     break

+    def _pause_engine(self) -> Tuple[List[Req], int]:
+        raise NotImplementedError()
+
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         """In-place update of the weights from disk."""
         success, message = self.tp_worker.update_weights_from_disk(recv_req)
@@ -1674,7 +2034,7 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return UpdateWeightFromDiskReqOutput(success, message)
+        return UpdateWeightFromDiskReqOutput(success, message, 0)

     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
@@ -1699,8 +2059,9 @@ class Scheduler:
         success, message = self.tp_worker.update_weights_from_tensor(recv_req)
         # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
         if success:
-
-
+            if recv_req.flush_cache:
+                flash_cache_success = self.flush_cache()
+                assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromTensorReqOutput(success, message)
@@ -1709,7 +2070,7 @@ class Scheduler:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return GetWeightsByNameReqOutput(parameter)

-    def release_memory_occupation(self):
+    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
         self.stashed_model_static_state = _export_static_state(
             self.tp_worker.worker.model_runner.model
         )
@@ -1717,7 +2078,7 @@ class Scheduler:
         self.flush_cache()
         return ReleaseMemoryOccupationReqOutput()

-    def resume_memory_occupation(self):
+    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
         self.memory_saver_adapter.resume()
         _import_static_state(
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
@@ -1726,24 +2087,96 @@ class Scheduler:
         return ResumeMemoryOccupationReqOutput()

     def profile(self, recv_req: ProfileReq):
-        if recv_req ==
-            self.start_profile(
+        if recv_req.type == ProfileReqType.START_PROFILE:
+            return self.start_profile(
+                recv_req.output_dir, recv_req.num_steps, recv_req.activities
+            )
         else:
-            self.stop_profile()
+            return self.stop_profile()
+
+    def start_profile(
+        self,
+        output_dir: Optional[str],
+        num_steps: Optional[int],
+        activities: Optional[List[str]],
+    ) -> None:
+        if self.torch_profiler_activities:
+            return ProfileReqOutput(
+                success=False,
+                message="Profiling is already in progress. Call /stop_profile first.",
+            )
+
+        if output_dir is None:
+            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
+        if activities is None:
+            activities = ["CPU", "GPU"]
+
+        self.torch_profiler_output_dir = output_dir
+        self.torch_profiler_activities = activities
+        logger.info(
+            "Profiling starts. Traces will be saved to: %s",
+            self.torch_profiler_output_dir,
+        )
+
+        activity_map = {
+            "CPU": torch.profiler.ProfilerActivity.CPU,
+            "GPU": torch.profiler.ProfilerActivity.CUDA,
+        }
+        torchprof_activities = [
+            activity_map[a] for a in activities if a in activity_map
+        ]
+
+        if torchprof_activities:
+            self.torch_profiler = torch.profiler.profile(
+                activities=torchprof_activities,
+                with_stack=True,
+            )
+            self.torch_profiler.start()
+
+        if "MEM" in activities:
+            torch.cuda.memory._record_memory_history(max_entries=100000)

-
-
-
-
+        if num_steps:
+            self.profiler_target_forward_ct = self.forward_ct + num_steps
+            # The caller will be notified when reaching profiler_target_forward_ct
+        else:
+            self.profiler_target_forward_ct = None
+            return ProfileReqOutput(success=True, message="Succeeded")

     def stop_profile(self) -> None:
-        if self.
-
-
-
-
+        if self.torch_profiler_activities is None:
+            return
+
+        logger.info("Stop profiling...")
+        if self.torch_profiler is not None:
+            self.torch_profiler.stop()
+            self.torch_profiler.export_chrome_trace(
+                os.path.join(
+                    self.torch_profiler_output_dir,
+                    str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
+                )
+            )
+
+        if "MEM" in self.torch_profiler_activities:
+            memory_profile_path = os.path.join(
+                self.torch_profiler_trace_dir,
+                str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
+            )
+            torch.cuda.memory._dump_snapshot(memory_profile_path)
+            torch.cuda.memory._record_memory_history(enabled=None)
+
+        logger.info(
+            "Profiling done. Traces are saved to: %s",
+            self.torch_profiler_output_dir,
         )
-
+        self.torch_profiler = None
+        self.torch_profiler_output_dir = None
+        self.torch_profiler_activities = None
+
+        if self.profiler_target_forward_ct:
+            self.send_to_tokenizer.send_pyobj(
+                ProfileReqOutput(success=True, message="Succeeded.")
+            )

     def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
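`start_profile` / `stop_profile` above wrap `torch.profiler` and the CUDA memory-history hooks behind the profile request. A minimal standalone use of the same PyTorch APIs the diff calls; the output path and the profiled workload here are made up for the example:

import os
import time
import torch

output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")

# Same "CPU"/"GPU" -> ProfilerActivity mapping as in the diff.
activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(torch.profiler.ProfilerActivity.CUDA)

profiler = torch.profiler.profile(activities=activities, with_stack=True)
profiler.start()

# Stand-in workload for the scheduler's forward batches.
x = torch.randn(512, 512)
for _ in range(10):
    x = torch.tanh(x @ x)

profiler.stop()
profiler.export_chrome_trace(os.path.join(output_dir, f"{time.time()}.trace.json.gz"))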
@@ -1752,7 +2185,7 @@ class Scheduler:
             logger.warning(f"session id {session_id} already exist, cannot open.")
             return OpenSessionReqOutput(session_id, False)
         elif session_id is None:
-            logger.warning(
+            logger.warning("session id is None, cannot open.")
             return OpenSessionReqOutput(session_id, False)
         else:
             self.sessions[session_id] = Session(
@@ -1769,6 +2202,10 @@ class Scheduler:
             del self.sessions[session_id]


+def is_health_check_generate_req(recv_req):
+    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+
+
 def _export_static_state(model):
     return dict(
         buffers=[
@@ -1791,26 +2228,28 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-
+    # Config the process
+    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
+    setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
     faulthandler.enable()
+    parent_process = psutil.Process().parent()

     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])

-    #
+    # Configure the logger
     if dp_rank is None:
-
+        prefix = f" TP{tp_rank}"
     else:
-
+        prefix = f" DP{dp_rank} TP{tp_rank}"
+    configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()

     # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

-    parent_process = psutil.Process().parent()
-
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)