sglang-0.4.3.post2-py3-none-any.whl → sglang-0.4.3.post3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
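For reference, a minimal way to pick up the new release (assuming a standard install from PyPI) is pip install "sglang==0.4.3.post3"; the installed version can then be confirmed from Python with the standard library:

    from importlib.metadata import version
    print(version("sglang"))  # expected to report 0.4.3.post3 after the upgrade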
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -17,10 +17,11 @@ import faulthandler
|
|
17
17
|
import logging
|
18
18
|
import os
|
19
19
|
import signal
|
20
|
+
import sys
|
20
21
|
import threading
|
21
22
|
import time
|
22
23
|
import warnings
|
23
|
-
from collections import deque
|
24
|
+
from collections import defaultdict, deque
|
24
25
|
from concurrent import futures
|
25
26
|
from dataclasses import dataclass
|
26
27
|
from http import HTTPStatus
|
@@ -44,17 +45,24 @@ from sglang.srt.managers.io_struct import (
|
|
44
45
|
BatchTokenIDOut,
|
45
46
|
CloseSessionReqInput,
|
46
47
|
FlushCacheReq,
|
48
|
+
GetInternalStateReq,
|
49
|
+
GetInternalStateReqOutput,
|
47
50
|
GetWeightsByNameReqInput,
|
48
51
|
GetWeightsByNameReqOutput,
|
52
|
+
HealthCheckOutput,
|
49
53
|
InitWeightsUpdateGroupReqInput,
|
50
54
|
InitWeightsUpdateGroupReqOutput,
|
51
55
|
OpenSessionReqInput,
|
52
56
|
OpenSessionReqOutput,
|
53
57
|
ProfileReq,
|
58
|
+
ProfileReqOutput,
|
59
|
+
ProfileReqType,
|
54
60
|
ReleaseMemoryOccupationReqInput,
|
55
61
|
ReleaseMemoryOccupationReqOutput,
|
56
62
|
ResumeMemoryOccupationReqInput,
|
57
63
|
ResumeMemoryOccupationReqOutput,
|
64
|
+
SetInternalStateReq,
|
65
|
+
SetInternalStateReqOutput,
|
58
66
|
TokenizedEmbeddingReqInput,
|
59
67
|
TokenizedGenerateReqInput,
|
60
68
|
UpdateWeightFromDiskReqInput,
|
@@ -82,6 +90,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
|
|
82
90
|
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
83
91
|
from sglang.srt.managers.utils import validate_input_length
|
84
92
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
93
|
+
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
85
94
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
86
95
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
87
96
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
@@ -94,6 +103,7 @@ from sglang.srt.utils import (
|
|
94
103
|
crash_on_warnings,
|
95
104
|
get_bool_env_var,
|
96
105
|
get_zmq_socket,
|
106
|
+
pyspy_dump_schedulers,
|
97
107
|
set_gpu_proc_affinity,
|
98
108
|
set_random_seed,
|
99
109
|
suppress_other_loggers,
|
@@ -103,13 +113,16 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
|
103
113
|
logger = logging.getLogger(__name__)
|
104
114
|
|
105
115
|
# Test retract decode for debugging purposes
|
106
|
-
|
116
|
+
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
117
|
+
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
107
118
|
|
108
119
|
|
109
120
|
@dataclass
|
110
121
|
class GenerationBatchResult:
|
111
122
|
logits_output: LogitsProcessorOutput
|
112
123
|
next_token_ids: List[int]
|
124
|
+
extend_input_len_per_req: List[int]
|
125
|
+
extend_logprob_start_len_per_req: List[int]
|
113
126
|
bid: int
|
114
127
|
|
115
128
|
|
@@ -135,21 +148,28 @@ class Scheduler:
|
|
135
148
|
self.tp_rank = tp_rank
|
136
149
|
self.tp_size = server_args.tp_size
|
137
150
|
self.schedule_policy = server_args.schedule_policy
|
138
|
-
self.disable_jump_forward = server_args.disable_jump_forward
|
139
151
|
self.lora_paths = server_args.lora_paths
|
140
152
|
self.max_loras_per_batch = server_args.max_loras_per_batch
|
141
153
|
self.enable_overlap = not server_args.disable_overlap_schedule
|
142
154
|
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
143
155
|
self.enable_metrics = server_args.enable_metrics
|
156
|
+
self.stream_interval = server_args.stream_interval
|
144
157
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
145
158
|
server_args.speculative_algorithm
|
146
159
|
)
|
160
|
+
self.gpu_id = gpu_id
|
161
|
+
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
147
162
|
self.decode_mem_cache_buf_multiplier = (
|
148
|
-
|
163
|
+
(
|
164
|
+
self.server_args.speculative_num_draft_tokens
|
165
|
+
+ (
|
166
|
+
self.server_args.speculative_eagle_topk
|
167
|
+
* self.server_args.speculative_num_draft_tokens
|
168
|
+
)
|
169
|
+
)
|
149
170
|
if not self.spec_algorithm.is_none()
|
150
171
|
else 1
|
151
172
|
)
|
152
|
-
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
153
173
|
|
154
174
|
# Distributed rank info
|
155
175
|
self.dp_size = server_args.dp_size
|
@@ -228,9 +248,6 @@ class Scheduler:
|
|
228
248
|
self.enable_overlap = False
|
229
249
|
logger.info("Overlap scheduler is disabled for multimodal models.")
|
230
250
|
|
231
|
-
if self.enable_overlap:
|
232
|
-
self.disable_jump_forward = True
|
233
|
-
|
234
251
|
# Launch a tensor parallel worker
|
235
252
|
if self.enable_overlap:
|
236
253
|
TpWorkerClass = TpModelWorkerClient
|
@@ -245,7 +262,7 @@ class Scheduler:
|
|
245
262
|
nccl_port=port_args.nccl_port,
|
246
263
|
)
|
247
264
|
|
248
|
-
# Launch a worker for speculative decoding
|
265
|
+
# Launch a draft worker for speculative decoding
|
249
266
|
if self.spec_algorithm.is_eagle():
|
250
267
|
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
251
268
|
|
@@ -257,8 +274,10 @@ class Scheduler:
|
|
257
274
|
target_worker=self.tp_worker,
|
258
275
|
dp_rank=dp_rank,
|
259
276
|
)
|
277
|
+
self.prefill_only_one_req = True
|
260
278
|
else:
|
261
279
|
self.draft_worker = None
|
280
|
+
self.prefill_only_one_req = False
|
262
281
|
|
263
282
|
# Get token and memory info from the model worker
|
264
283
|
(
|
@@ -279,6 +298,7 @@ class Scheduler:
|
|
279
298
|
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
|
280
299
|
global_server_args_dict.update(worker_global_server_args_dict)
|
281
300
|
set_random_seed(self.random_seed)
|
301
|
+
|
282
302
|
# Print debug info
|
283
303
|
logger.info(
|
284
304
|
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
@@ -289,7 +309,9 @@ class Scheduler:
|
|
289
309
|
)
|
290
310
|
|
291
311
|
# Init memory pool and cache
|
292
|
-
self.req_to_token_pool, self.
|
312
|
+
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
313
|
+
self.tp_worker.get_memory_pool()
|
314
|
+
)
|
293
315
|
|
294
316
|
if (
|
295
317
|
server_args.chunked_prefill_size is not None
|
@@ -297,19 +319,26 @@ class Scheduler:
|
|
297
319
|
):
|
298
320
|
self.tree_cache = ChunkCache(
|
299
321
|
req_to_token_pool=self.req_to_token_pool,
|
300
|
-
|
322
|
+
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
301
323
|
)
|
302
324
|
else:
|
303
|
-
self.
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
325
|
+
if self.enable_hierarchical_cache:
|
326
|
+
self.tree_cache = HiRadixCache(
|
327
|
+
req_to_token_pool=self.req_to_token_pool,
|
328
|
+
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
329
|
+
)
|
330
|
+
else:
|
331
|
+
self.tree_cache = RadixCache(
|
332
|
+
req_to_token_pool=self.req_to_token_pool,
|
333
|
+
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
334
|
+
disable=server_args.disable_radix_cache,
|
335
|
+
)
|
336
|
+
|
309
337
|
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
|
310
338
|
|
311
339
|
# Init running status
|
312
340
|
self.waiting_queue: List[Req] = []
|
341
|
+
self.staging_reqs = {}
|
313
342
|
# The running decoding batch for continuous batching
|
314
343
|
self.running_batch: Optional[ScheduleBatch] = None
|
315
344
|
# The current forward batch
|
@@ -321,12 +350,22 @@ class Scheduler:
|
|
321
350
|
self.num_generated_tokens = 0
|
322
351
|
self.spec_num_total_accepted_tokens = 0
|
323
352
|
self.spec_num_total_forward_ct = 0
|
353
|
+
self.cum_spec_accept_length = 0
|
354
|
+
self.cum_spec_accept_count = 0
|
324
355
|
self.last_decode_stats_tic = time.time()
|
325
|
-
self.
|
356
|
+
self.return_health_check_ct = 0
|
326
357
|
self.current_stream = torch.get_device_module(self.device).current_stream()
|
327
358
|
if self.device == "cpu":
|
328
359
|
self.current_stream.synchronize = lambda: None # No-op for CPU
|
329
360
|
|
361
|
+
# For metrics only.
|
362
|
+
# The largest prefill length of a single request
|
363
|
+
self._largest_prefill_len: int = 0
|
364
|
+
# The largest context length (prefill + generation) of a single request
|
365
|
+
self._largest_prefill_decode_len: int = 0
|
366
|
+
self.last_gen_throughput: float = 0.0
|
367
|
+
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
|
368
|
+
|
330
369
|
# Session info
|
331
370
|
self.sessions: Dict[str, Session] = {}
|
332
371
|
|
@@ -334,7 +373,7 @@ class Scheduler:
|
|
334
373
|
self.chunked_prefill_size = server_args.chunked_prefill_size
|
335
374
|
if self.chunked_prefill_size <= 0: # -1 means disable
|
336
375
|
self.chunked_prefill_size = None
|
337
|
-
self.
|
376
|
+
self.chunked_req = None
|
338
377
|
self.is_mixed_chunk = (
|
339
378
|
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
340
379
|
)
|
@@ -368,7 +407,7 @@ class Scheduler:
|
|
368
407
|
) / global_config.default_new_token_ratio_decay_steps
|
369
408
|
self.new_token_ratio = self.init_new_token_ratio
|
370
409
|
|
371
|
-
#
|
410
|
+
# Tell whether the current running batch is full so that we can skip
|
372
411
|
# the check of whether to prefill new requests.
|
373
412
|
# This is an optimization to reduce the overhead of the prefill check.
|
374
413
|
self.batch_is_full = False
|
@@ -379,26 +418,16 @@ class Scheduler:
|
|
379
418
|
t.start()
|
380
419
|
self.parent_process = psutil.Process().parent()
|
381
420
|
|
421
|
+
# Init memory saver
|
382
422
|
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
383
423
|
enable=server_args.enable_memory_saver
|
384
424
|
)
|
385
425
|
|
386
426
|
# Init profiler
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
logger.info(
|
392
|
-
"Profiling enabled. Traces will be saved to: %s",
|
393
|
-
self.torch_profiler_trace_dir,
|
394
|
-
)
|
395
|
-
self.profiler = torch.profiler.profile(
|
396
|
-
activities=[
|
397
|
-
torch.profiler.ProfilerActivity.CPU,
|
398
|
-
torch.profiler.ProfilerActivity.CUDA,
|
399
|
-
],
|
400
|
-
with_stack=True,
|
401
|
-
)
|
427
|
+
self.torch_profiler = None
|
428
|
+
self.torch_profiler_output_dir: Optional[str] = None
|
429
|
+
self.torch_profiler_activities: Optional[List[str]] = None
|
430
|
+
self.profiler_target_forward_ct: Optional[int] = None
|
402
431
|
|
403
432
|
# Init metrics stats
|
404
433
|
self.stats = SchedulerStats()
|
@@ -410,11 +439,6 @@ class Scheduler:
|
|
410
439
|
},
|
411
440
|
)
|
412
441
|
|
413
|
-
# The largest prefill length of a single request
|
414
|
-
self._largest_prefill_len: int = 0
|
415
|
-
# The largest context length (prefill + generation) of a single request
|
416
|
-
self._largest_prefill_decode_len: int = 0
|
417
|
-
|
418
442
|
# Init request dispatcher
|
419
443
|
self._request_dispatcher = TypeBasedDispatcher(
|
420
444
|
[
|
@@ -422,6 +446,8 @@ class Scheduler:
|
|
422
446
|
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
423
447
|
(FlushCacheReq, self.flush_cache_wrapped),
|
424
448
|
(AbortReq, self.abort_request),
|
449
|
+
(OpenSessionReqInput, self.open_session),
|
450
|
+
(CloseSessionReqInput, self.close_session),
|
425
451
|
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
426
452
|
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
427
453
|
(
|
@@ -430,22 +456,15 @@ class Scheduler:
|
|
430
456
|
),
|
431
457
|
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
|
432
458
|
(GetWeightsByNameReqInput, self.get_weights_by_name),
|
459
|
+
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
|
460
|
+
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
433
461
|
(ProfileReq, self.profile),
|
434
|
-
(
|
435
|
-
(CloseSessionReqInput, self.close_session),
|
436
|
-
(
|
437
|
-
ReleaseMemoryOccupationReqInput,
|
438
|
-
lambda _: self.release_memory_occupation(),
|
439
|
-
),
|
440
|
-
(
|
441
|
-
ResumeMemoryOccupationReqInput,
|
442
|
-
lambda _: self.resume_memory_occupation(),
|
443
|
-
),
|
462
|
+
(GetInternalStateReq, self.get_internal_state),
|
444
463
|
]
|
445
464
|
)
|
446
465
|
|
447
466
|
def watchdog_thread(self):
|
448
|
-
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
|
467
|
+
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
|
449
468
|
self.watchdog_last_forward_ct = 0
|
450
469
|
self.watchdog_last_time = time.time()
|
451
470
|
|
@@ -460,7 +479,18 @@ class Scheduler:
|
|
460
479
|
self.watchdog_last_forward_ct = self.forward_ct
|
461
480
|
self.watchdog_last_time = current
|
462
481
|
time.sleep(self.watchdog_timeout // 2)
|
463
|
-
|
482
|
+
|
483
|
+
# Print batch size and memory pool info to check whether there are de-sync issues.
|
484
|
+
logger.error(
|
485
|
+
f"{self.cur_batch.batch_size()=}, "
|
486
|
+
f"{self.cur_batch.reqs=}, "
|
487
|
+
f"{self.token_to_kv_pool.available_size()=}, "
|
488
|
+
f"{self.tree_cache.evictable_size()=}, "
|
489
|
+
)
|
490
|
+
# Wait for some time so that the parent process can print the error.
|
491
|
+
pyspy_dump_schedulers()
|
492
|
+
print(file=sys.stderr, flush=True)
|
493
|
+
print(file=sys.stdout, flush=True)
|
464
494
|
time.sleep(5)
|
465
495
|
self.parent_process.send_signal(signal.SIGQUIT)
|
466
496
|
|
@@ -577,6 +607,13 @@ class Scheduler:
|
|
577
607
|
|
578
608
|
def process_input_requests(self, recv_reqs: List):
|
579
609
|
for recv_req in recv_reqs:
|
610
|
+
# If it is a health check generation request and there are running requests, ignore it.
|
611
|
+
if is_health_check_generate_req(recv_req) and (
|
612
|
+
self.chunked_req is not None or self.running_batch is not None
|
613
|
+
):
|
614
|
+
self.return_health_check_ct += 1
|
615
|
+
continue
|
616
|
+
|
580
617
|
output = self._request_dispatcher(recv_req)
|
581
618
|
if output is not None:
|
582
619
|
self.send_to_tokenizer.send_pyobj(output)
|
@@ -591,7 +628,6 @@ class Scheduler:
|
|
591
628
|
or recv_req.session_params.id is None
|
592
629
|
or recv_req.session_params.id not in self.sessions
|
593
630
|
):
|
594
|
-
|
595
631
|
if recv_req.input_embeds is not None:
|
596
632
|
# Generate fake input_ids based on the length of input_embeds
|
597
633
|
seq_length = len(recv_req.input_embeds)
|
@@ -618,10 +654,12 @@ class Scheduler:
|
|
618
654
|
recv_req.sampling_params,
|
619
655
|
return_logprob=recv_req.return_logprob,
|
620
656
|
top_logprobs_num=recv_req.top_logprobs_num,
|
657
|
+
token_ids_logprob=recv_req.token_ids_logprob,
|
621
658
|
stream=recv_req.stream,
|
622
659
|
lora_path=recv_req.lora_path,
|
623
660
|
input_embeds=recv_req.input_embeds,
|
624
661
|
custom_logit_processor=custom_logit_processor,
|
662
|
+
return_hidden_states=recv_req.return_hidden_states,
|
625
663
|
eos_token_ids=self.model_config.hf_eos_token_id,
|
626
664
|
)
|
627
665
|
req.tokenizer = self.tokenizer
|
@@ -633,14 +671,14 @@ class Scheduler:
|
|
633
671
|
req.finished_reason = FINISH_ABORT(
|
634
672
|
f"Invalid request: session id {recv_req.session_params.id} does not exist"
|
635
673
|
)
|
636
|
-
self.
|
674
|
+
self._add_request_to_queue(req)
|
637
675
|
return
|
638
676
|
else:
|
639
677
|
# Create a new request from a previous session
|
640
678
|
session = self.sessions[recv_req.session_params.id]
|
641
679
|
req = session.create_req(recv_req, self.tokenizer)
|
642
680
|
if isinstance(req.finished_reason, FINISH_ABORT):
|
643
|
-
self.
|
681
|
+
self._add_request_to_queue(req)
|
644
682
|
return
|
645
683
|
|
646
684
|
# Handle multimodal inputs
|
@@ -664,7 +702,7 @@ class Scheduler:
|
|
664
702
|
req.finished_reason = FINISH_ABORT(
|
665
703
|
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
666
704
|
)
|
667
|
-
self.
|
705
|
+
self._add_request_to_queue(req)
|
668
706
|
return
|
669
707
|
|
670
708
|
# Validate prompts length
|
@@ -674,16 +712,28 @@ class Scheduler:
|
|
674
712
|
self.server_args.allow_auto_truncate,
|
675
713
|
)
|
676
714
|
if error_msg:
|
677
|
-
|
715
|
+
req.origin_input_ids = [0]
|
716
|
+
req.sampling_params.max_new_tokens = 0
|
717
|
+
self._add_request_to_queue(req)
|
678
718
|
return
|
679
719
|
|
680
720
|
# Copy more attributes
|
681
|
-
if recv_req.logprob_start_len == -1:
|
721
|
+
if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
|
682
722
|
# By default, only return the logprobs for output tokens
|
683
723
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
684
724
|
else:
|
685
725
|
req.logprob_start_len = recv_req.logprob_start_len
|
686
726
|
|
727
|
+
if req.logprob_start_len >= len(req.origin_input_ids):
|
728
|
+
req.finished_reason = FINISH_ABORT(
|
729
|
+
f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
|
730
|
+
HTTPStatus.BAD_REQUEST,
|
731
|
+
"BadRequestError",
|
732
|
+
)
|
733
|
+
req.logprob_start_len = len(req.origin_input_ids) - 1
|
734
|
+
self._add_request_to_queue(req)
|
735
|
+
return
|
736
|
+
|
687
737
|
req.sampling_params.max_new_tokens = min(
|
688
738
|
(
|
689
739
|
req.sampling_params.max_new_tokens
|
@@ -699,6 +749,7 @@ class Scheduler:
|
|
699
749
|
req.sampling_params.json_schema is not None
|
700
750
|
or req.sampling_params.regex is not None
|
701
751
|
or req.sampling_params.ebnf is not None
|
752
|
+
or req.sampling_params.structural_tag is not None
|
702
753
|
):
|
703
754
|
assert self.grammar_backend is not None
|
704
755
|
if req.sampling_params.json_schema is not None:
|
@@ -707,6 +758,8 @@ class Scheduler:
|
|
707
758
|
key = ("regex", req.sampling_params.regex)
|
708
759
|
elif req.sampling_params.ebnf is not None:
|
709
760
|
key = ("ebnf", req.sampling_params.ebnf)
|
761
|
+
elif req.sampling_params.structural_tag:
|
762
|
+
key = ("structural_tag", req.sampling_params.structural_tag)
|
710
763
|
|
711
764
|
req.grammar = self.grammar_backend.get_cached_value(key)
|
712
765
|
if not req.grammar:
|
@@ -716,7 +769,13 @@ class Scheduler:
|
|
716
769
|
if add_to_grammar_queue:
|
717
770
|
self.grammar_queue.append(req)
|
718
771
|
else:
|
719
|
-
self.
|
772
|
+
self._add_request_to_queue(req)
|
773
|
+
|
774
|
+
def _add_request_to_queue(self, req: Req):
|
775
|
+
self.waiting_queue.append(req)
|
776
|
+
|
777
|
+
def _extend_requests_to_queue(self, reqs: List[Req]):
|
778
|
+
self.waiting_queue.extend(reqs)
|
720
779
|
|
721
780
|
def handle_embedding_request(
|
722
781
|
self,
|
@@ -737,61 +796,64 @@ class Scheduler:
|
|
737
796
|
self.server_args.allow_auto_truncate,
|
738
797
|
)
|
739
798
|
if error_msg:
|
740
|
-
self.
|
799
|
+
self._add_request_to_queue(req)
|
741
800
|
return
|
742
801
|
|
743
802
|
# Copy more attributes
|
744
803
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
745
|
-
self.
|
804
|
+
self._add_request_to_queue(req)
|
746
805
|
|
747
806
|
def log_prefill_stats(
|
748
807
|
self,
|
749
808
|
adder: PrefillAdder,
|
750
809
|
can_run_list: List[Req],
|
751
|
-
running_bs:
|
752
|
-
has_being_chunked: bool,
|
810
|
+
running_bs: int,
|
753
811
|
):
|
754
|
-
self.tree_cache_metrics["total"] += (
|
755
|
-
adder.log_input_tokens + adder.log_hit_tokens
|
756
|
-
) / 10**9
|
757
|
-
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
|
758
|
-
tree_cache_hit_rate = (
|
759
|
-
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
760
|
-
)
|
761
|
-
|
762
812
|
num_used = self.max_total_num_tokens - (
|
763
|
-
self.
|
813
|
+
self.token_to_kv_pool_allocator.available_size()
|
814
|
+
+ self.tree_cache.evictable_size()
|
815
|
+
)
|
816
|
+
self._largest_prefill_len = max(
|
817
|
+
self._largest_prefill_len, adder.log_input_tokens
|
764
818
|
)
|
765
819
|
|
766
|
-
|
820
|
+
f = (
|
767
821
|
f"Prefill batch. "
|
768
822
|
f"#new-seq: {len(can_run_list)}, "
|
769
823
|
f"#new-token: {adder.log_input_tokens}, "
|
770
824
|
f"#cached-token: {adder.log_hit_tokens}, "
|
771
|
-
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
772
825
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
773
826
|
f"#running-req: {running_bs}, "
|
774
|
-
f"#queue-req: {len(self.waiting_queue)
|
827
|
+
f"#queue-req: {len(self.waiting_queue)}, "
|
775
828
|
)
|
829
|
+
logger.info(f)
|
776
830
|
|
777
831
|
if self.enable_metrics:
|
832
|
+
cache_hit_rate = adder.log_hit_tokens / (
|
833
|
+
adder.log_input_tokens + adder.log_hit_tokens
|
834
|
+
)
|
778
835
|
self.stats.num_running_reqs = running_bs
|
779
836
|
self.stats.num_used_tokens = num_used
|
780
837
|
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
|
781
|
-
self.stats.num_queue_reqs = len(self.waiting_queue)
|
782
|
-
self.stats.cache_hit_rate =
|
838
|
+
self.stats.num_queue_reqs = len(self.waiting_queue)
|
839
|
+
self.stats.cache_hit_rate = cache_hit_rate
|
783
840
|
self.metrics_collector.log_stats(self.stats)
|
784
841
|
|
785
842
|
def log_decode_stats(self):
|
786
|
-
|
787
|
-
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
788
|
-
)
|
789
|
-
gen_throughput = self.num_generated_tokens / (
|
790
|
-
time.time() - self.last_decode_stats_tic
|
791
|
-
)
|
792
|
-
self.num_generated_tokens = 0
|
843
|
+
gap_latency = time.time() - self.last_decode_stats_tic
|
793
844
|
self.last_decode_stats_tic = time.time()
|
845
|
+
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
846
|
+
self.num_generated_tokens = 0
|
794
847
|
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
848
|
+
num_used = self.max_total_num_tokens - (
|
849
|
+
self.token_to_kv_pool_allocator.available_size()
|
850
|
+
+ self.tree_cache.evictable_size()
|
851
|
+
)
|
852
|
+
|
853
|
+
if RECORD_STEP_TIME:
|
854
|
+
self.step_time_dict[num_running_reqs].append(
|
855
|
+
gap_latency / self.server_args.decode_log_interval
|
856
|
+
)
|
795
857
|
|
796
858
|
if self.spec_algorithm.is_none():
|
797
859
|
msg = (
|
@@ -799,14 +861,17 @@ class Scheduler:
|
|
799
861
|
f"#running-req: {num_running_reqs}, "
|
800
862
|
f"#token: {num_used}, "
|
801
863
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
802
|
-
f"gen throughput (token/s): {
|
803
|
-
f"
|
864
|
+
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
|
865
|
+
f"largest-len: {self._largest_prefill_decode_len}, "
|
866
|
+
f"#queue-req: {len(self.waiting_queue)}, "
|
804
867
|
)
|
805
868
|
spec_accept_length = 0
|
806
869
|
else:
|
807
870
|
spec_accept_length = (
|
808
871
|
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
|
809
872
|
)
|
873
|
+
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
|
874
|
+
self.cum_spec_accept_count += self.spec_num_total_forward_ct
|
810
875
|
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
|
811
876
|
msg = (
|
812
877
|
f"Decode batch. "
|
@@ -814,8 +879,9 @@ class Scheduler:
|
|
814
879
|
f"#token: {num_used}, "
|
815
880
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
816
881
|
f"accept len: {spec_accept_length:.2f}, "
|
817
|
-
f"gen throughput (token/s): {
|
818
|
-
f"
|
882
|
+
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
|
883
|
+
f"largest-len: {self._largest_prefill_decode_len}, "
|
884
|
+
f"#queue-req: {len(self.waiting_queue)}, "
|
819
885
|
)
|
820
886
|
|
821
887
|
logger.info(msg)
|
@@ -823,14 +889,16 @@ class Scheduler:
|
|
823
889
|
self.stats.num_running_reqs = num_running_reqs
|
824
890
|
self.stats.num_used_tokens = num_used
|
825
891
|
self.stats.token_usage = num_used / self.max_total_num_tokens
|
826
|
-
self.stats.
|
892
|
+
self.stats.cache_hit_rate = 0.0
|
893
|
+
self.stats.gen_throughput = self.last_gen_throughput
|
827
894
|
self.stats.num_queue_reqs = len(self.waiting_queue)
|
828
895
|
self.stats.spec_accept_length = spec_accept_length
|
829
896
|
self.metrics_collector.log_stats(self.stats)
|
830
897
|
|
831
898
|
def check_memory(self):
|
832
899
|
available_size = (
|
833
|
-
self.
|
900
|
+
self.token_to_kv_pool_allocator.available_size()
|
901
|
+
+ self.tree_cache.evictable_size()
|
834
902
|
)
|
835
903
|
protected_size = self.tree_cache.protected_size()
|
836
904
|
memory_leak = available_size != (
|
@@ -857,21 +925,42 @@ class Scheduler:
|
|
857
925
|
if crash_on_warnings():
|
858
926
|
raise ValueError(msg)
|
859
927
|
|
928
|
+
if (
|
929
|
+
self.enable_metrics
|
930
|
+
and self.attn_tp_rank == 0
|
931
|
+
and time.time() > self.metrics_collector.last_log_time + 30
|
932
|
+
):
|
933
|
+
# During idle time, also collect metrics every 30 seconds.
|
934
|
+
num_used = self.max_total_num_tokens - (
|
935
|
+
self.token_to_kv_pool.available_size()
|
936
|
+
+ self.tree_cache.evictable_size()
|
937
|
+
)
|
938
|
+
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
939
|
+
self.stats.num_running_reqs = num_running_reqs
|
940
|
+
self.stats.num_used_tokens = num_used
|
941
|
+
self.stats.token_usage = num_used / self.max_total_num_tokens
|
942
|
+
self.stats.gen_throughput = 0
|
943
|
+
self.stats.num_queue_reqs = len(self.waiting_queue)
|
944
|
+
self.metrics_collector.log_stats(self.stats)
|
945
|
+
|
860
946
|
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
861
947
|
# Merge the prefill batch into the running batch
|
862
948
|
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
863
|
-
if self.
|
864
|
-
# Move the chunked request out of the batch
|
865
|
-
|
866
|
-
self.
|
867
|
-
|
868
|
-
|
949
|
+
if self.chunked_req:
|
950
|
+
# Move the chunked request out of the batch so that we can merge
|
951
|
+
# only finished requests to running_batch.
|
952
|
+
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
|
953
|
+
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
954
|
+
# chunked request keeps its rid but will get a new req_pool_idx
|
955
|
+
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
869
956
|
self.batch_is_full = False
|
870
957
|
|
958
|
+
self.last_batch.filter_batch()
|
871
959
|
if not self.last_batch.is_empty():
|
872
960
|
if self.running_batch is None:
|
873
961
|
self.running_batch = self.last_batch
|
874
962
|
else:
|
963
|
+
# merge running_batch with prefill batch
|
875
964
|
self.running_batch.merge_batch(self.last_batch)
|
876
965
|
|
877
966
|
new_batch = self.get_new_batch_prefill()
|
@@ -900,7 +989,7 @@ class Scheduler:
|
|
900
989
|
# Handle the cases where prefill is not allowed
|
901
990
|
if (
|
902
991
|
self.batch_is_full or len(self.waiting_queue) == 0
|
903
|
-
) and self.
|
992
|
+
) and self.chunked_req is None:
|
904
993
|
return None
|
905
994
|
|
906
995
|
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
@@ -914,7 +1003,7 @@ class Scheduler:
|
|
914
1003
|
# Prefill policy
|
915
1004
|
adder = PrefillAdder(
|
916
1005
|
self.tree_cache,
|
917
|
-
self.
|
1006
|
+
self.token_to_kv_pool_allocator,
|
918
1007
|
self.running_batch,
|
919
1008
|
self.new_token_ratio,
|
920
1009
|
self.max_prefill_tokens,
|
@@ -922,10 +1011,10 @@ class Scheduler:
|
|
922
1011
|
running_bs if self.is_mixed_chunk else 0,
|
923
1012
|
)
|
924
1013
|
|
925
|
-
|
926
|
-
if
|
927
|
-
self.
|
928
|
-
self.
|
1014
|
+
is_chunked = self.chunked_req is not None
|
1015
|
+
if is_chunked:
|
1016
|
+
self.chunked_req.init_next_round_input()
|
1017
|
+
self.chunked_req = adder.add_chunked_req(self.chunked_req)
|
929
1018
|
|
930
1019
|
if self.lora_paths:
|
931
1020
|
lora_set = (
|
@@ -933,7 +1022,6 @@ class Scheduler:
|
|
933
1022
|
if self.running_batch is not None
|
934
1023
|
else set([])
|
935
1024
|
)
|
936
|
-
|
937
1025
|
# Get requests from the waiting queue to a new prefill batch
|
938
1026
|
for req in self.waiting_queue:
|
939
1027
|
if (
|
@@ -953,7 +1041,31 @@ class Scheduler:
|
|
953
1041
|
break
|
954
1042
|
|
955
1043
|
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
956
|
-
|
1044
|
+
|
1045
|
+
if self.enable_hierarchical_cache and req.last_node is not None:
|
1046
|
+
if req.last_node.evicted:
|
1047
|
+
# loading KV cache for the request
|
1048
|
+
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
|
1049
|
+
req.last_node,
|
1050
|
+
req.prefix_indices,
|
1051
|
+
adder.rem_total_tokens,
|
1052
|
+
)
|
1053
|
+
if req.last_node.loading:
|
1054
|
+
# to prevent frequent cache invalidation
|
1055
|
+
if req.rid in self.staging_reqs:
|
1056
|
+
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
|
1057
|
+
self.tree_cache.inc_lock_ref(req.last_node)
|
1058
|
+
self.staging_reqs[req.rid] = req.last_node
|
1059
|
+
continue
|
1060
|
+
elif req.last_node.loading:
|
1061
|
+
if not self.tree_cache.loading_complete(req.last_node):
|
1062
|
+
continue
|
1063
|
+
|
1064
|
+
if req.rid in self.staging_reqs:
|
1065
|
+
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
|
1066
|
+
del self.staging_reqs[req.rid]
|
1067
|
+
|
1068
|
+
res = adder.add_one_req(req, self.chunked_req)
|
957
1069
|
if res != AddReqResult.CONTINUE:
|
958
1070
|
if res == AddReqResult.NO_TOKEN:
|
959
1071
|
if self.enable_hierarchical_cache:
|
@@ -965,39 +1077,38 @@ class Scheduler:
|
|
965
1077
|
else:
|
966
1078
|
self.batch_is_full = True
|
967
1079
|
break
|
968
|
-
if self.
|
1080
|
+
if self.prefill_only_one_req:
|
969
1081
|
break
|
970
1082
|
|
971
1083
|
# Update waiting queue
|
972
|
-
can_run_list = adder.can_run_list
|
1084
|
+
can_run_list: List[Req] = adder.can_run_list
|
973
1085
|
if len(can_run_list) == 0:
|
974
1086
|
return None
|
975
1087
|
self.waiting_queue = [
|
976
1088
|
x for x in self.waiting_queue if x not in set(can_run_list)
|
977
1089
|
]
|
978
1090
|
|
979
|
-
if adder.
|
980
|
-
assert self.
|
981
|
-
self.
|
1091
|
+
if adder.new_chunked_req is not None:
|
1092
|
+
assert self.chunked_req is None
|
1093
|
+
self.chunked_req = adder.new_chunked_req
|
982
1094
|
|
983
|
-
if self.
|
984
|
-
self.
|
1095
|
+
if self.chunked_req:
|
1096
|
+
self.chunked_req.is_chunked += 1
|
985
1097
|
|
986
1098
|
# Print stats
|
987
1099
|
if self.attn_tp_rank == 0:
|
988
|
-
self.log_prefill_stats(adder, can_run_list, running_bs
|
1100
|
+
self.log_prefill_stats(adder, can_run_list, running_bs)
|
989
1101
|
|
990
1102
|
# Create a new batch
|
991
1103
|
new_batch = ScheduleBatch.init_new(
|
992
1104
|
can_run_list,
|
993
1105
|
self.req_to_token_pool,
|
994
|
-
self.
|
1106
|
+
self.token_to_kv_pool_allocator,
|
995
1107
|
self.tree_cache,
|
996
1108
|
self.model_config,
|
997
1109
|
self.enable_overlap,
|
998
1110
|
self.spec_algorithm,
|
999
1111
|
self.server_args.enable_custom_logit_processor,
|
1000
|
-
self.server_args.return_hidden_states,
|
1001
1112
|
)
|
1002
1113
|
new_batch.prepare_for_extend()
|
1003
1114
|
|
@@ -1021,8 +1132,6 @@ class Scheduler:
|
|
1021
1132
|
|
1022
1133
|
def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
|
1023
1134
|
"""Update the current running decoding batch."""
|
1024
|
-
global test_retract
|
1025
|
-
|
1026
1135
|
initial_bs = batch.batch_size()
|
1027
1136
|
|
1028
1137
|
batch.filter_batch()
|
@@ -1032,35 +1141,25 @@ class Scheduler:
|
|
1032
1141
|
|
1033
1142
|
# Check if decode out of memory
|
1034
1143
|
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
|
1035
|
-
|
1144
|
+
TEST_RETRACT and batch.batch_size() > 10
|
1036
1145
|
):
|
1037
1146
|
old_ratio = self.new_token_ratio
|
1038
1147
|
|
1039
|
-
retracted_reqs, new_token_ratio = batch.retract_decode()
|
1148
|
+
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
|
1040
1149
|
self.new_token_ratio = new_token_ratio
|
1041
|
-
if self.draft_worker:
|
1042
|
-
self.draft_worker.finish_request(retracted_reqs)
|
1043
1150
|
|
1044
1151
|
logger.info(
|
1045
1152
|
"Decode out of memory happened. "
|
1046
1153
|
f"#retracted_reqs: {len(retracted_reqs)}, "
|
1047
1154
|
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
|
1048
1155
|
)
|
1049
|
-
self.
|
1156
|
+
self._extend_requests_to_queue(retracted_reqs)
|
1050
1157
|
else:
|
1051
1158
|
self.new_token_ratio = max(
|
1052
1159
|
self.new_token_ratio - self.new_token_ratio_decay,
|
1053
1160
|
self.min_new_token_ratio,
|
1054
1161
|
)
|
1055
1162
|
|
1056
|
-
# Check for jump-forward
|
1057
|
-
if not self.disable_jump_forward and batch.has_grammar:
|
1058
|
-
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
1059
|
-
self.waiting_queue.extend(jump_forward_reqs)
|
1060
|
-
if batch.is_empty():
|
1061
|
-
self.batch_is_full = False
|
1062
|
-
return None
|
1063
|
-
|
1064
1163
|
if batch.batch_size() < initial_bs:
|
1065
1164
|
self.batch_is_full = False
|
1066
1165
|
|
@@ -1074,17 +1173,25 @@ class Scheduler:
|
|
1074
1173
|
"""Run a batch."""
|
1075
1174
|
self.forward_ct += 1
|
1076
1175
|
|
1176
|
+
# Check profiler
|
1177
|
+
if (
|
1178
|
+
self.profiler_target_forward_ct
|
1179
|
+
and self.profiler_target_forward_ct <= self.forward_ct
|
1180
|
+
):
|
1181
|
+
self.stop_profile()
|
1182
|
+
|
1077
1183
|
if self.is_generation:
|
1078
1184
|
if self.spec_algorithm.is_none():
|
1079
1185
|
model_worker_batch = batch.get_model_worker_batch()
|
1080
1186
|
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
1081
1187
|
model_worker_batch
|
1082
1188
|
)
|
1189
|
+
bid = model_worker_batch.bid
|
1083
1190
|
else:
|
1084
1191
|
(
|
1085
1192
|
logits_output,
|
1086
1193
|
next_token_ids,
|
1087
|
-
|
1194
|
+
bid,
|
1088
1195
|
num_accepted_tokens,
|
1089
1196
|
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
1090
1197
|
self.spec_num_total_accepted_tokens += (
|
@@ -1093,11 +1200,24 @@ class Scheduler:
|
|
1093
1200
|
self.spec_num_total_forward_ct += batch.batch_size()
|
1094
1201
|
self.num_generated_tokens += num_accepted_tokens
|
1095
1202
|
batch.output_ids = next_token_ids
|
1203
|
+
# These 2 values are needed for processing the output, but the values can be
|
1204
|
+
# modified by overlap schedule. So we have to copy them here so that
|
1205
|
+
# we can use the correct values in output processing.
|
1206
|
+
if batch.return_logprob:
|
1207
|
+
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
|
1208
|
+
extend_logprob_start_len_per_req = [
|
1209
|
+
req.extend_logprob_start_len for req in batch.reqs
|
1210
|
+
]
|
1211
|
+
else:
|
1212
|
+
extend_input_len_per_req = None
|
1213
|
+
extend_logprob_start_len_per_req = None
|
1096
1214
|
|
1097
1215
|
ret = GenerationBatchResult(
|
1098
1216
|
logits_output=logits_output,
|
1099
1217
|
next_token_ids=next_token_ids,
|
1100
|
-
|
1218
|
+
extend_input_len_per_req=extend_input_len_per_req,
|
1219
|
+
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
1220
|
+
bid=bid,
|
1101
1221
|
)
|
1102
1222
|
else: # embedding or reward model
|
1103
1223
|
model_worker_batch = batch.get_model_worker_batch()
|
@@ -1113,6 +1233,7 @@ class Scheduler:
|
|
1113
1233
|
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
1114
1234
|
):
|
1115
1235
|
if batch.forward_mode.is_decode():
|
1236
|
+
assert isinstance(result, GenerationBatchResult)
|
1116
1237
|
self.process_batch_result_decode(batch, result)
|
1117
1238
|
if batch.is_empty():
|
1118
1239
|
self.running_batch = None
|
@@ -1121,11 +1242,22 @@ class Scheduler:
|
|
1121
1242
|
elif batch.forward_mode.is_idle():
|
1122
1243
|
if self.enable_overlap:
|
1123
1244
|
self.tp_worker.resolve_batch_result(result.bid)
|
1245
|
+
if batch.next_batch_sampling_info:
|
1246
|
+
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
1247
|
+
self.current_stream.synchronize()
|
1248
|
+
batch.next_batch_sampling_info.sampling_info_done.set()
|
1124
1249
|
elif batch.forward_mode.is_dummy_first():
|
1125
1250
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
1126
1251
|
self.current_stream.synchronize()
|
1127
1252
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
1128
1253
|
|
1254
|
+
if self.return_health_check_ct:
|
1255
|
+
# Return some signal for the health check.
|
1256
|
+
# This is used to prevent the health check signal being blocked by long context prefill.
|
1257
|
+
# However, one minor issue is that this code path does not check the status of detokenizer manager.
|
1258
|
+
self.return_health_check_ct -= 1
|
1259
|
+
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
|
1260
|
+
|
1129
1261
|
def process_batch_result_prefill(
|
1130
1262
|
self,
|
1131
1263
|
batch: ScheduleBatch,
|
@@ -1137,10 +1269,14 @@ class Scheduler:
|
|
1137
1269
|
(
|
1138
1270
|
logits_output,
|
1139
1271
|
next_token_ids,
|
1272
|
+
extend_input_len_per_req,
|
1273
|
+
extend_logprob_start_len_per_req,
|
1140
1274
|
bid,
|
1141
1275
|
) = (
|
1142
1276
|
result.logits_output,
|
1143
1277
|
result.next_token_ids,
|
1278
|
+
result.extend_input_len_per_req,
|
1279
|
+
result.extend_logprob_start_len_per_req,
|
1144
1280
|
result.bid,
|
1145
1281
|
)
|
1146
1282
|
|
@@ -1150,12 +1286,14 @@ class Scheduler:
|
|
1150
1286
|
# Move next_token_ids and logprobs to cpu
|
1151
1287
|
next_token_ids = next_token_ids.tolist()
|
1152
1288
|
if batch.return_logprob:
|
1153
|
-
logits_output.next_token_logprobs
|
1154
|
-
logits_output.next_token_logprobs
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1289
|
+
if logits_output.next_token_logprobs is not None:
|
1290
|
+
logits_output.next_token_logprobs = (
|
1291
|
+
logits_output.next_token_logprobs.tolist()
|
1292
|
+
)
|
1293
|
+
if logits_output.input_token_logprobs is not None:
|
1294
|
+
logits_output.input_token_logprobs = tuple(
|
1295
|
+
logits_output.input_token_logprobs.tolist()
|
1296
|
+
)
|
1159
1297
|
|
1160
1298
|
hidden_state_offset = 0
|
1161
1299
|
|
@@ -1168,25 +1306,38 @@ class Scheduler:
|
|
1168
1306
|
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
1169
1307
|
# Free the one delayed token for the mixed decode batch
|
1170
1308
|
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
1171
|
-
self.
|
1309
|
+
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
|
1172
1310
|
continue
|
1173
1311
|
|
1174
|
-
if req.
|
1312
|
+
if req.is_chunked <= 0:
|
1313
|
+
# req output_ids are set here
|
1175
1314
|
req.output_ids.append(next_token_id)
|
1176
1315
|
req.check_finished()
|
1177
1316
|
|
1178
1317
|
if req.finished():
|
1179
1318
|
self.tree_cache.cache_finished_req(req)
|
1180
1319
|
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
1320
|
+
# This updates radix so others can match
|
1181
1321
|
self.tree_cache.cache_unfinished_req(req)
|
1182
1322
|
|
1183
1323
|
if req.return_logprob:
|
1184
|
-
|
1185
|
-
|
1324
|
+
assert extend_logprob_start_len_per_req is not None
|
1325
|
+
assert extend_input_len_per_req is not None
|
1326
|
+
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
1327
|
+
extend_input_len = extend_input_len_per_req[i]
|
1328
|
+
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
1329
|
+
self.add_logprob_return_values(
|
1330
|
+
i,
|
1331
|
+
req,
|
1332
|
+
logprob_pt,
|
1333
|
+
next_token_ids,
|
1334
|
+
num_input_logprobs,
|
1335
|
+
logits_output,
|
1186
1336
|
)
|
1337
|
+
logprob_pt += num_input_logprobs
|
1187
1338
|
|
1188
1339
|
if (
|
1189
|
-
|
1340
|
+
req.return_hidden_states
|
1190
1341
|
and logits_output.hidden_states is not None
|
1191
1342
|
):
|
1192
1343
|
req.hidden_states.append(
|
@@ -1205,12 +1356,31 @@ class Scheduler:
|
|
1205
1356
|
req.grammar.finished = req.finished()
|
1206
1357
|
else:
|
1207
1358
|
# being chunked reqs' prefill is not finished
|
1208
|
-
req.
|
1359
|
+
req.is_chunked -= 1
|
1209
1360
|
# There is only at most one request being currently chunked.
|
1210
1361
|
# Because this request does not finish prefill,
|
1211
1362
|
# we don't want to stream the request currently being chunked.
|
1212
1363
|
skip_stream_req = req
|
1213
1364
|
|
1365
|
+
# Incrementally update input logprobs.
|
1366
|
+
if req.return_logprob:
|
1367
|
+
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
1368
|
+
extend_input_len = extend_input_len_per_req[i]
|
1369
|
+
if extend_logprob_start_len < extend_input_len:
|
1370
|
+
# Update input logprobs.
|
1371
|
+
num_input_logprobs = (
|
1372
|
+
extend_input_len - extend_logprob_start_len
|
1373
|
+
)
|
1374
|
+
self.add_input_logprob_return_values(
|
1375
|
+
i,
|
1376
|
+
req,
|
1377
|
+
logits_output,
|
1378
|
+
logprob_pt,
|
1379
|
+
num_input_logprobs,
|
1380
|
+
last_prefill_chunk=False,
|
1381
|
+
)
|
1382
|
+
logprob_pt += num_input_logprobs
|
1383
|
+
|
1214
1384
|
if batch.next_batch_sampling_info:
|
1215
1385
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
1216
1386
|
self.current_stream.synchronize()
|
@@ -1226,7 +1396,7 @@ class Scheduler:
|
|
1226
1396
|
continue
|
1227
1397
|
|
1228
1398
|
req.embedding = embeddings[i]
|
1229
|
-
if req.
|
1399
|
+
if req.is_chunked <= 0:
|
1230
1400
|
# Dummy output token for embedding models
|
1231
1401
|
req.output_ids.append(0)
|
1232
1402
|
req.check_finished()
|
@@ -1237,7 +1407,7 @@ class Scheduler:
|
|
1237
1407
|
self.tree_cache.cache_unfinished_req(req)
|
1238
1408
|
else:
|
1239
1409
|
# being chunked reqs' prefill is not finished
|
1240
|
-
req.
|
1410
|
+
req.is_chunked -= 1
|
1241
1411
|
|
1242
1412
|
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
|
1243
1413
|
|
@@ -1254,23 +1424,27 @@ class Scheduler:
|
|
1254
1424
|
self.num_generated_tokens += len(batch.reqs)
|
1255
1425
|
|
1256
1426
|
if self.enable_overlap:
|
1427
|
+
assert batch.spec_algorithm.is_none()
|
1257
1428
|
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
1258
1429
|
next_token_logprobs = logits_output.next_token_logprobs
|
1259
|
-
|
1430
|
+
elif batch.spec_algorithm.is_none():
|
1431
|
+
# spec decoding handles output logprobs inside verify process.
|
1260
1432
|
next_token_ids = next_token_ids.tolist()
|
1261
1433
|
if batch.return_logprob:
|
1262
1434
|
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
1263
1435
|
|
1264
|
-
self.
|
1436
|
+
self.token_to_kv_pool_allocator.free_group_begin()
|
1265
1437
|
|
1266
1438
|
# Check finish condition
|
1439
|
+
# NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
|
1440
|
+
# We should ignore using next_token_ids for spec decoding cases.
|
1267
1441
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
1268
1442
|
if req.is_retracted:
|
1269
1443
|
continue
|
1270
1444
|
|
1271
1445
|
if self.enable_overlap and req.finished():
|
1272
1446
|
# Free the one delayed token
|
1273
|
-
self.
|
1447
|
+
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
|
1274
1448
|
continue
|
1275
1449
|
|
1276
1450
|
if batch.spec_algorithm.is_none():
|
@@ -1278,11 +1452,11 @@ class Scheduler:
                 req.output_ids.append(next_token_id)
 
             req.check_finished()
-
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
 
-            if req.return_logprob:
+            if req.return_logprob and batch.spec_algorithm.is_none():
+                # speculative worker handles logprob in speculative decoding
                 req.output_token_logprobs_val.append(next_token_logprobs[i])
                 req.output_token_logprobs_idx.append(next_token_id)
                 if req.top_logprobs_num > 0:
@@ -1292,14 +1466,18 @@ class Scheduler:
                     req.output_top_logprobs_idx.append(
                         logits_output.next_token_top_logprobs_idx[i]
                     )
+                if req.token_ids_logprob is not None:
+                    req.output_token_ids_logprobs_val.append(
+                        logits_output.next_token_token_ids_logprobs_val[i]
+                    )
+                    req.output_token_ids_logprobs_idx.append(
+                        logits_output.next_token_token_ids_logprobs_idx[i]
+                    )
 
-            if
-                self.server_args.return_hidden_states
-                and logits_output.hidden_states is not None
-            ):
+            if req.return_hidden_states and logits_output.hidden_states is not None:
                 req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
 
-            if req.grammar is not None:
+            if req.grammar is not None and batch.spec_algorithm.is_none():
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()
 
@@ -1307,10 +1485,9 @@ class Scheduler:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
-
         self.stream_output(batch.reqs, batch.return_logprob)
 
-        self.
+        self.token_to_kv_pool_allocator.free_group_end()
 
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (
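Several added branches above record `token_ids_logprob`: for a caller-supplied set of token ids, the logprob of each of those ids is kept at every decode step. Conceptually (illustrative math only, not the scheduler's tensor code):

    import math
    from typing import List

    def token_ids_logprobs(step_logits: List[float], token_ids: List[int]) -> List[float]:
        # Convert one step's logits to logprobs and pick out the requested ids.
        m = max(step_logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in step_logits))
        return [step_logits[t] - log_z for t in token_ids]

    print(token_ids_logprobs([2.0, 0.5, -1.0], token_ids=[0, 2]))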
@@ -1319,86 +1496,167 @@ class Scheduler:
         ):
             self.log_decode_stats()
 
-    def
+    def add_input_logprob_return_values(
         self,
         i: int,
         req: Req,
-        pt: int,
-        next_token_ids: List[int],
         output: LogitsProcessorOutput,
+        logprob_pt: int,
+        num_input_logprobs: int,
+        last_prefill_chunk: bool,  # If True, it means prefill is finished.
     ):
-        """
-
-
+        """Incrementally add input logprobs to `req`.
+
+        Args:
+            i: The request index in a batch.
+            req: The request. Input logprobs inside req are modified as a
+                consequence of the API
+            fill_ids: The prefill ids processed.
+            output: Logit processor output that's used to compute input logprobs
+            last_prefill_chunk: True if it is the last prefill (when chunked).
+                Some of input logprob operation should only happen at the last
+                prefill (e.g., computing input token logprobs).
+        """
+        assert output.input_token_logprobs is not None
+        if req.input_token_logprobs is None:
+            req.input_token_logprobs = []
+        if req.temp_input_top_logprobs_val is None:
+            req.temp_input_top_logprobs_val = []
+        if req.temp_input_top_logprobs_idx is None:
+            req.temp_input_top_logprobs_idx = []
+        if req.temp_input_token_ids_logprobs_val is None:
+            req.temp_input_token_ids_logprobs_val = []
+        if req.temp_input_token_ids_logprobs_idx is None:
+            req.temp_input_token_ids_logprobs_idx = []
+
+        if req.input_token_logprobs_val is not None:
+            # The input logprob has been already computed. It only happens
+            # upon retract.
+            if req.top_logprobs_num > 0:
+                assert req.input_token_logprobs_val is not None
+            return
 
-        #
-
+        # Important for the performance.
+        assert isinstance(output.input_token_logprobs, tuple)
+        input_token_logprobs: Tuple[int] = output.input_token_logprobs
+        input_token_logprobs = input_token_logprobs[
+            logprob_pt : logprob_pt + num_input_logprobs
+        ]
+        req.input_token_logprobs.extend(input_token_logprobs)
 
-        if req.
-
-
-        ]
+        if req.top_logprobs_num > 0:
+            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
+            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
 
-
-
-
-
-
-
+        if req.token_ids_logprob is not None:
+            req.temp_input_token_ids_logprobs_val.append(
+                output.input_token_ids_logprobs_val[i]
+            )
+            req.temp_input_token_ids_logprobs_idx.append(
+                output.input_token_ids_logprobs_idx[i]
+            )
+
+        if last_prefill_chunk:
+            input_token_logprobs = req.input_token_logprobs
+            req.input_token_logprobs = None
+            assert req.input_token_logprobs_val is None
+            assert req.input_token_logprobs_idx is None
+            assert req.input_top_logprobs_val is None
+            assert req.input_top_logprobs_idx is None
+
+            # Compute input_token_logprobs_val
+            # Always pad the first one with None.
+            req.input_token_logprobs_val = [None]
+            req.input_token_logprobs_val.extend(input_token_logprobs)
+            # The last input logprob is for sampling, so just pop it out.
+            req.input_token_logprobs_val.pop()
+
+            # Compute input_token_logprobs_idx
+            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
             # Clip the padded hash values from image tokens.
             # Otherwise, it will lead to detokenization errors.
             input_token_logprobs_idx = [
                 x if x < self.model_config.vocab_size - 1 else 0
                 for x in input_token_logprobs_idx
             ]
+            req.input_token_logprobs_idx = input_token_logprobs_idx
 
-        if
-            req.
-
-
-
+            if req.top_logprobs_num > 0:
+                req.input_top_logprobs_val = [None]
+                req.input_top_logprobs_idx = [None]
+                assert len(req.temp_input_token_ids_logprobs_val) == len(
+                    req.temp_input_token_ids_logprobs_idx
+                )
+                for val, idx in zip(
+                    req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
+                ):
+                    req.input_top_logprobs_val.extend(val)
+                    req.input_top_logprobs_idx.extend(idx)
+
+                # Last token is a sample token.
+                req.input_top_logprobs_val.pop()
+                req.input_top_logprobs_idx.pop()
+                req.temp_input_top_logprobs_idx = None
+                req.temp_input_top_logprobs_val = None
+
+            if req.token_ids_logprob is not None:
+                req.input_token_ids_logprobs_val = [None]
+                req.input_token_ids_logprobs_idx = [None]
+
+                for val, idx in zip(
+                    req.temp_input_token_ids_logprobs_val,
+                    req.temp_input_token_ids_logprobs_idx,
+                    strict=True,
+                ):
+                    req.input_token_ids_logprobs_val.extend(val)
+                    req.input_token_ids_logprobs_idx.extend(idx)
 
-
-
+                # Last token is a sample token.
+                req.input_token_ids_logprobs_val.pop()
+                req.input_token_ids_logprobs_idx.pop()
+                req.temp_input_token_ids_logprobs_idx = None
+                req.temp_input_token_ids_logprobs_val = None
 
-
-
-
-
-
-
-
-
-
-
-            ],
-            )
-            req.output_token_logprobs_idx.extend(
-                req.fill_ids[
-                    len(req.fill_ids)
-                    - req.last_update_decode_tokens : len(req.fill_ids)
-                ]
-            )
+            if req.return_logprob:
+                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
+                assert len(req.input_token_logprobs_val) == relevant_tokens_len
+                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
+                if req.top_logprobs_num > 0:
+                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
+                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
+                if req.token_ids_logprob is not None:
+                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
+                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
 
-
-
-
-
-
-
-
-
-
-
-
-            req.output_top_logprobs_idx.extend(
-                output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
-            )
+    def add_logprob_return_values(
+        self,
+        i: int,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        num_input_logprobs: int,
+        output: LogitsProcessorOutput,
+    ):
+        """Attach logprobs to the return values."""
+        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+        req.output_token_logprobs_idx.append(next_token_ids[i])
 
+        self.add_input_logprob_return_values(
+            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+        )
+
+        if req.top_logprobs_num > 0:
             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
 
+        if req.token_ids_logprob is not None:
+            req.output_token_ids_logprobs_val.append(
+                output.next_token_token_ids_logprobs_val[i]
+            )
+            req.output_token_ids_logprobs_idx.append(
+                output.next_token_token_ids_logprobs_idx[i]
+            )
+
         return num_input_logprobs
 
     def stream_output(
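The new `add_input_logprob_return_values` relies on a small convention worth spelling out: the first input position has no logprob (nothing precedes it), so the list is padded with `None`, and the final accumulated entry belongs to the sampled token, so it is popped. A tiny sketch of that finalization step (hypothetical helper):

    from typing import List

    def finalize_input_logprobs(chunk_logprobs: List[float]) -> list:
        vals: list = [None]            # position 0 has no predecessor
        vals.extend(chunk_logprobs)    # logprobs accumulated over all prefill chunks
        vals.pop()                     # last one belongs to the sampled token, not the input
        return vals

    print(finalize_input_logprobs([-0.11, -0.42, -0.07]))  # [None, -0.11, -0.42]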
@@ -1409,7 +1667,6 @@ class Scheduler:
         finished_reasons: List[BaseFinishReason] = []
 
         if self.is_generation:
-            vids = []
             decoded_texts = []
             decode_ids_list = []
             read_offsets = []
@@ -1422,7 +1679,7 @@ class Scheduler:
             completion_tokens = []
             cached_tokens = []
             spec_verify_ct = []
-
+            output_hidden_states = None
 
             if return_logprob:
                 input_token_logprobs_val = []
@@ -1433,33 +1690,46 @@ class Scheduler:
                 input_top_logprobs_idx = []
                 output_top_logprobs_val = []
                 output_top_logprobs_idx = []
+                input_token_ids_logprobs_val = []
+                input_token_ids_logprobs_idx = []
+                output_token_ids_logprobs_val = []
+                output_token_ids_logprobs_idx = []
             else:
                 input_token_logprobs_val = input_token_logprobs_idx = (
                     output_token_logprobs_val
                 ) = output_token_logprobs_idx = input_top_logprobs_val = (
                     input_top_logprobs_idx
-                ) = output_top_logprobs_val = output_top_logprobs_idx =
+                ) = output_top_logprobs_val = output_top_logprobs_idx = (
+                    input_token_ids_logprobs_val
+                ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
+                    output_token_ids_logprobs_idx
+                ) = None
 
             for req in reqs:
                 if req is skip_req:
                     continue
 
-                #
+                # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
+                if self.model_config.is_multimodal_gen and req.to_abort:
+                    continue
+
                 if (
                     req.finished()
                     # If stream, follow the given stream_interval
                     or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                     # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-
+                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
+                    # always increase one-by-one.
+                    or (
+                        not req.stream
+                        and len(req.output_ids) % 50 == 0
+                        and not self.model_config.is_multimodal_gen
+                    )
                 ):
-                    if self.draft_worker and req.finished():
-                        self.draft_worker.finish_request(req)
-
                     rids.append(req.rid)
                     finished_reasons.append(
                         req.finished_reason.to_json() if req.finished_reason else None
                     )
-                    vids.append(req.vid)
                     decoded_texts.append(req.decoded_text)
                     decode_ids, read_offset = req.init_incremental_detokenize()
                     decode_ids_list.append(decode_ids)
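The rewritten emission condition in `stream_output` can be read in isolation; a hedged restatement as a standalone predicate (the `50` is the hard-coded non-streaming interval from the diff, and the function name is illustrative):

    def should_emit(finished: bool, stream: bool, num_output_ids: int,
                    stream_interval: int, is_multimodal_gen: bool = False) -> bool:
        return (
            finished
            or (stream and num_output_ids % stream_interval == 0)
            or (not stream and num_output_ids % 50 == 0 and not is_multimodal_gen)
        )

    print(should_emit(False, True, 16, stream_interval=8))   # True: streaming, hit the interval
    print(should_emit(False, False, 49, stream_interval=8))  # False: nothing to flush yet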
@@ -1488,16 +1758,32 @@ class Scheduler:
                         input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                         output_top_logprobs_val.append(req.output_top_logprobs_val)
                         output_top_logprobs_idx.append(req.output_top_logprobs_idx)
+                        input_token_ids_logprobs_val.append(
+                            req.input_token_ids_logprobs_val
+                        )
+                        input_token_ids_logprobs_idx.append(
+                            req.input_token_ids_logprobs_idx
+                        )
+                        output_token_ids_logprobs_val.append(
+                            req.output_token_ids_logprobs_val
+                        )
+                        output_token_ids_logprobs_idx.append(
+                            req.output_token_ids_logprobs_idx
+                        )
 
-
+                    if req.return_hidden_states:
+                        if output_hidden_states is None:
+                            output_hidden_states = []
+                        output_hidden_states.append(req.hidden_states)
 
             # Send to detokenizer
             if rids:
+                if self.model_config.is_multimodal_gen:
+                    raise NotImplementedError()
                 self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
                         rids,
                         finished_reasons,
-                        vids,
                         decoded_texts,
                         decode_ids_list,
                         read_offsets,
@@ -1517,7 +1803,11 @@ class Scheduler:
                         input_top_logprobs_idx,
                         output_top_logprobs_val,
                         output_top_logprobs_idx,
-
+                        input_token_ids_logprobs_val,
+                        input_token_ids_logprobs_idx,
+                        output_token_ids_logprobs_val,
+                        output_token_ids_logprobs_idx,
+                        output_hidden_states,
                     )
                 )
         else:  # embedding or reward model
@@ -1575,13 +1865,12 @@ class Scheduler:
         idle_batch = ScheduleBatch.init_new(
             [],
             self.req_to_token_pool,
-            self.
+            self.token_to_kv_pool_allocator,
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
-            self.server_args.return_hidden_states,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -1596,18 +1885,25 @@ class Scheduler:
             except futures._base.TimeoutError:
                 break
 
-        if self.
+        if self.server_args.enable_dp_attention:
+            tp_size = self.attn_tp_size
+            tp_group = self.attn_tp_cpu_group
+        else:
+            tp_size = self.tp_size
+            tp_group = self.tp_cpu_group
+
+        if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
             tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
             torch.distributed.all_reduce(
-                tensor, op=torch.distributed.ReduceOp.MAX, group=
+                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
             num_ready_reqs_max = tensor.item()
             for i in range(num_ready_reqs, num_ready_reqs_max):
                 self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
             num_ready_reqs = num_ready_reqs_max
 
-        self.
+        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
 
     def flush_cache_wrapped(self, recv_req: FlushCacheReq):
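The grammar-queue change picks the TP group (the attention-TP group when DP attention is enabled) and then takes a MAX all-reduce so every rank dequeues the same number of ready requests. A runnable single-process sketch of that sync primitive (gloo backend, illustrative setup values; with more than one rank the same call takes the maximum across ranks):

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=0, world_size=1)

    num_ready_reqs = 3
    tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)  # every rank adopts the max count
    num_ready_reqs = int(tensor.item())
    print(num_ready_reqs)

    dist.destroy_process_group()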
@@ -1618,21 +1914,25 @@ class Scheduler:
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
+            self.cur_batch = None
+            self.last_batch = None
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
             if self.grammar_backend:
                 self.grammar_backend.reset()
             self.req_to_token_pool.clear()
-            self.
+            self.token_to_kv_pool_allocator.clear()
 
             if not self.spec_algorithm.is_none():
                 self.draft_worker.model_runner.req_to_token_pool.clear()
-                self.draft_worker.model_runner.
+                self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
 
             self.num_generated_tokens = 0
             self.forward_ct_decode = 0
             self.spec_num_total_accepted_tokens = 0
             self.spec_num_total_forward_ct = 0
+            self.cum_spec_accept_length = 0
+            self.cum_spec_accept_count = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
@@ -1645,6 +1945,49 @@ class Scheduler:
             if_success = False
         return if_success
 
+    def get_internal_state(self, recv_req: GetInternalStateReq):
+        ret = dict(global_server_args_dict)
+        ret["last_gen_throughput"] = self.last_gen_throughput
+        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+            ret["avg_spec_accept_length"] = (
+                self.cum_spec_accept_length / self.cum_spec_accept_count
+            )
+
+        if RECORD_STEP_TIME:
+            ret["step_time_dict"] = self.step_time_dict
+        return GetInternalStateReqOutput(
+            internal_state=ret,
+        )
+
+    def set_internal_state(self, recv_req: SetInternalStateReq):
+        server_args_dict = recv_req.server_args
+        args_allow_update = set(
+            [
+                "speculative_accept_threshold_single",
+                "speculative_accept_threshold_acc",
+            ]
+        )
+        if_success = True
+        for k, v in server_args_dict.items():
+            if k not in args_allow_update:
+                logging.warning(f"Updating {k} is not supported.")
+                if_success = False
+                break
+        if if_success:
+            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+                avg_spec_accept_length = (
+                    self.cum_spec_accept_length / self.cum_spec_accept_count
+                )
+                logger.info(f"{avg_spec_accept_length=}")
+                self.cum_spec_accept_length = self.cum_spec_accept_count = 0
+            for k, v in server_args_dict.items():
+                global_server_args_dict[k] = v
+            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
+        return SetInternalStateReqOutput(
+            updated=True,
+            server_args=global_server_args_dict,
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = None
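`set_internal_state` only accepts keys from an explicit allow-list and rejects the whole update otherwise. The same pattern in standalone form (illustrative names; the real allow-list is the two speculative-accept thresholds shown above):

    import logging

    ALLOWED = {"speculative_accept_threshold_single", "speculative_accept_threshold_acc"}

    def try_update(current: dict, updates: dict) -> bool:
        for k in updates:
            if k not in ALLOWED:
                logging.warning("Updating %s is not supported.", k)
                return False  # reject the whole request
        current.update(updates)
        return True

    state = {"speculative_accept_threshold_single": 1.0}
    print(try_update(state, {"speculative_accept_threshold_single": 0.5}))  # True
    print(try_update(state, {"tp_size": 8}))                                # False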
@@ -1674,7 +2017,7 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return UpdateWeightFromDiskReqOutput(success, message)
+        return UpdateWeightFromDiskReqOutput(success, message, 0)
 
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
@@ -1699,8 +2042,9 @@ class Scheduler:
         success, message = self.tp_worker.update_weights_from_tensor(recv_req)
         # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
         if success:
-
-
+            if recv_req.flush_cache:
+                flash_cache_success = self.flush_cache()
+                assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromTensorReqOutput(success, message)
@@ -1709,7 +2053,7 @@ class Scheduler:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return GetWeightsByNameReqOutput(parameter)
 
-    def release_memory_occupation(self):
+    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
         self.stashed_model_static_state = _export_static_state(
             self.tp_worker.worker.model_runner.model
         )
@@ -1717,7 +2061,7 @@ class Scheduler:
         self.flush_cache()
         return ReleaseMemoryOccupationReqOutput()
 
-    def resume_memory_occupation(self):
+    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
         self.memory_saver_adapter.resume()
         _import_static_state(
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
@@ -1726,24 +2070,96 @@ class Scheduler:
         return ResumeMemoryOccupationReqOutput()
 
     def profile(self, recv_req: ProfileReq):
-        if recv_req ==
-            self.start_profile(
+        if recv_req.type == ProfileReqType.START_PROFILE:
+            return self.start_profile(
+                recv_req.output_dir, recv_req.num_steps, recv_req.activities
+            )
         else:
-            self.stop_profile()
+            return self.stop_profile()
+
+    def start_profile(
+        self,
+        output_dir: Optional[str],
+        num_steps: Optional[int],
+        activities: Optional[List[str]],
+    ) -> None:
+        if self.torch_profiler_activities:
+            return ProfileReqOutput(
+                success=False,
+                message="Profiling is already in progress. Call /stop_profile first.",
+            )
+
+        if output_dir is None:
+            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
+        if activities is None:
+            activities = ["CPU", "GPU"]
 
-
-
-
-
+        self.torch_profiler_output_dir = output_dir
+        self.torch_profiler_activities = activities
+        logger.info(
+            "Profiling starts. Traces will be saved to: %s",
+            self.torch_profiler_output_dir,
+        )
+
+        activity_map = {
+            "CPU": torch.profiler.ProfilerActivity.CPU,
+            "GPU": torch.profiler.ProfilerActivity.CUDA,
+        }
+        torchprof_activities = [
+            activity_map[a] for a in activities if a in activity_map
+        ]
+
+        if torchprof_activities:
+            self.torch_profiler = torch.profiler.profile(
+                activities=torchprof_activities,
+                with_stack=True,
+            )
+            self.torch_profiler.start()
+
+        if "MEM" in activities:
+            torch.cuda.memory._record_memory_history(max_entries=100000)
+
+        if num_steps:
+            self.profiler_target_forward_ct = self.forward_ct + num_steps
+            # The caller will be notified when reaching profiler_target_forward_ct
+        else:
+            self.profiler_target_forward_ct = None
+            return ProfileReqOutput(success=True, message="Succeeded")
 
     def stop_profile(self) -> None:
-        if self.
-
-
-
-
+        if self.torch_profiler_activities is None:
+            return
+
+        logger.info("Stop profiling...")
+        if self.torch_profiler is not None:
+            self.torch_profiler.stop()
+            self.torch_profiler.export_chrome_trace(
+                os.path.join(
+                    self.torch_profiler_output_dir,
+                    str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
+                )
+            )
+
+        if "MEM" in self.torch_profiler_activities:
+            memory_profile_path = os.path.join(
+                self.torch_profiler_trace_dir,
+                str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
+            )
+            torch.cuda.memory._dump_snapshot(memory_profile_path)
+            torch.cuda.memory._record_memory_history(enabled=None)
+
+        logger.info(
+            "Profiling done. Traces are saved to: %s",
+            self.torch_profiler_output_dir,
         )
-
+        self.torch_profiler = None
+        self.torch_profiler_output_dir = None
+        self.torch_profiler_activities = None
+
+        if self.profiler_target_forward_ct:
+            self.send_to_tokenizer.send_pyobj(
+                ProfileReqOutput(success=True, message="Succeeded.")
+            )
 
     def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
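The new `start_profile`/`stop_profile` pair wraps `torch.profiler`. A compact sketch of the same flow outside the scheduler (an assumption-laden example, CPU-only activity so it runs anywhere; the output directory env var mirrors `SGLANG_TORCH_PROFILER_DIR`):

    import os
    import time
    import torch

    output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
    activities = [torch.profiler.ProfilerActivity.CPU]  # add CUDA when a GPU is present

    prof = torch.profiler.profile(activities=activities, with_stack=True)
    prof.start()
    torch.randn(256, 256) @ torch.randn(256, 256)  # stand-in for scheduler forward steps
    prof.stop()
    prof.export_chrome_trace(os.path.join(output_dir, f"{time.time()}-example.trace.json.gz"))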
@@ -1752,7 +2168,7 @@ class Scheduler:
             logger.warning(f"session id {session_id} already exist, cannot open.")
             return OpenSessionReqOutput(session_id, False)
         elif session_id is None:
-            logger.warning(
+            logger.warning("session id is None, cannot open.")
             return OpenSessionReqOutput(session_id, False)
         else:
             self.sessions[session_id] = Session(
@@ -1769,6 +2185,10 @@ class Scheduler:
             del self.sessions[session_id]
 
 
+def is_health_check_generate_req(recv_req):
+    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+
+
 def _export_static_state(model):
     return dict(
         buffers=[
@@ -1791,26 +2211,28 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-
+    # Config the process
+    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
+    setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
     faulthandler.enable()
+    parent_process = psutil.Process().parent()
 
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
 
-    #
+    # Configure the logger
     if dp_rank is None:
-
+        prefix = f" TP{tp_rank}"
     else:
-
+        prefix = f" DP{dp_rank} TP{tp_rank}"
+    configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()
 
     # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
 
-    parent_process = psutil.Process().parent()
-
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)