sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +95 -49
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +72 -8
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +33 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +212 -117
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +258 -782
- sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
- sglang/srt/managers/tokenizer_manager.py +7 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +63 -34
- sglang/srt/mem_cache/memory_pool.py +112 -46
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/metrics/collector.py +8 -0
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +12 -8
- sglang/srt/model_executor/model_runner.py +153 -134
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +25 -19
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +37 -15
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +19 -20
- sglang/srt/speculative/build_eagle_tree.py +6 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
- sglang/srt/speculative/eagle_utils.py +2 -1
- sglang/srt/speculative/eagle_worker.py +109 -38
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -41,8 +41,6 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
     GetInternalStateReq,
@@ -74,7 +72,6 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    BaseFinishReason,
     ImageInputs,
     Req,
     ScheduleBatch,
@@ -85,6 +82,9 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.scheduler_output_processor_mixin import (
+    SchedulerOutputProcessorMixin,
+)
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     get_bool_env_var,
     get_zmq_socket,
+    kill_itself_when_parent_died,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -132,7 +133,7 @@ class EmbeddingBatchResult:
     bid: int


-class Scheduler:
+class Scheduler(SchedulerOutputProcessorMixin):
    """A scheduler that manages a tensor parallel GPU worker."""

     def __init__(
@@ -159,17 +160,7 @@ class Scheduler:
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
-        self.decode_mem_cache_buf_multiplier = (
-            (
-                self.server_args.speculative_num_draft_tokens
-                + (
-                    self.server_args.speculative_eagle_topk
-                    * self.server_args.speculative_num_draft_tokens
-                )
-            )
-            if not self.spec_algorithm.is_none()
-            else 1
-        )
+        self.page_size = server_args.page_size

         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -208,42 +199,12 @@ class Scheduler:
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

         # Init tokenizer
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
-        self.is_generation = self.model_config.is_generation
-
-        if server_args.skip_tokenizer_init:
-            self.tokenizer = self.processor = None
-        else:
-            if self.model_config.is_multimodal:
-                self.processor = get_processor(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                )
-                self.tokenizer = self.processor.tokenizer
-            else:
-                self.tokenizer = get_tokenizer(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                )
+        self.init_tokenizer()

         # Check whether overlap can be enabled
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
-
         if self.model_config.is_multimodal:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for multimodal models.")
@@ -274,10 +235,8 @@ class Scheduler:
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
-            self.prefill_only_one_req = True
         else:
             self.draft_worker = None
-            self.prefill_only_one_req = False

         # Get token and memory info from the model worker
         (
@@ -309,64 +268,28 @@ class Scheduler:
         )

         # Init memory pool and cache
-        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
-            self.tp_worker.get_memory_pool()
-        )
-
-        if (
-            server_args.chunked_prefill_size is not None
-            and server_args.disable_radix_cache
-        ):
-            self.tree_cache = ChunkCache(
-                req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-            )
-        else:
-            if self.enable_hierarchical_cache:
-                self.tree_cache = HiRadixCache(
-                    req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                )
-            else:
-                self.tree_cache = RadixCache(
-                    req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                    disable=server_args.disable_radix_cache,
-                )
-
-        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+        self.init_memory_pool_and_cache()

         # Init running status
         self.waiting_queue: List[Req] = []
-        self.staging_reqs = {}
         # The running decoding batch for continuous batching
-        self.running_batch:
+        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
-        # The
+        # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
+        self.num_prefill_tokens = 0
         self.last_decode_stats_tic = time.time()
+        self.last_prefill_stats_tic = time.time()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU

-        #
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
-        self.last_gen_throughput: float = 0.0
-        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-
-        # Session info
+        # Init session info
         self.sessions: Dict[str, Session] = {}

         # Init chunked prefill
@@ -387,11 +310,15 @@ class Scheduler:
         else:
             self.grammar_backend = None

-        # Init new token estimation
+        # Init schedule policy and new token estimation
+        self.policy = SchedulePolicy(
+            self.schedule_policy,
+            self.tree_cache,
+            self.enable_hierarchical_cache,
+        )
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-
         self.init_new_token_ratio = min(
             global_config.default_init_new_token_ratio
             * server_args.schedule_conservativeness,
@@ -407,11 +334,6 @@ class Scheduler:
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio

-        # Tell whether the current running batch is full so that we can skip
-        # the check of whether to prefill new requests.
-        # This is an optimization to reduce the overhead of the prefill check.
-        self.batch_is_full = False
-
         # Init watchdog thread
         self.watchdog_timeout = server_args.watchdog_timeout
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
@@ -430,14 +352,7 @@ class Scheduler:
         self.profiler_target_forward_ct: Optional[int] = None

         # Init metrics stats
-        self.stats = SchedulerStats()
-        if self.enable_metrics:
-            self.metrics_collector = SchedulerMetricsCollector(
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
-            )
+        self.init_metrics()

         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
@@ -460,39 +375,107 @@ class Scheduler:
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
+                (SetInternalStateReq, self.set_internal_state),
             ]
         )

-    def
-
-        self.watchdog_last_forward_ct = 0
-        self.watchdog_last_time = time.time()
+    def init_tokenizer(self):
+        server_args = self.server_args

-
-
-
-
-
-
-
-
-
-
-
+        self.model_config = ModelConfig(
+            server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+        )
+        self.is_generation = self.model_config.is_generation

-
-
-
-
-
-
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
+        else:
+            if self.model_config.is_multimodal:
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                )
+
+    def init_memory_pool_and_cache(self):
+        server_args = self.server_args
+
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            self.tp_worker.get_memory_pool()
+        )
+
+        if (
+            server_args.chunked_prefill_size is not None
+            and server_args.disable_radix_cache
+        ):
+            self.tree_cache = ChunkCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+        else:
+            if self.enable_hierarchical_cache:
+                self.tree_cache = HiRadixCache(
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_worker.get_tp_cpu_group(),
+                )
+            else:
+                self.tree_cache = RadixCache(
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                    page_size=self.page_size,
+                    disable=server_args.disable_radix_cache,
+                )
+
+        self.decode_mem_cache_buf_multiplier = (
+            1
+            if self.spec_algorithm.is_none()
+            else (
+                server_args.speculative_num_draft_tokens
+                + (
+                    server_args.speculative_eagle_topk
+                    * server_args.speculative_num_steps
+                )
+            )
         )
-
-
-
-
-
-        self.
+
+    def init_metrics(self):
+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+        self.last_gen_throughput: float = 0.0
+        self.last_input_throughput: float = 0.0
+        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
+        self.cum_spec_accept_length = 0
+        self.cum_spec_accept_count = 0
+        self.stats = SchedulerStats()
+        if self.enable_metrics:
+            engine_type = "unified"
+            self.metrics_collector = SchedulerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    "engine_type": engine_type,
+                },
+            )

     @torch.no_grad()
     def event_loop_normal(self):
@@ -508,7 +491,7 @@ class Scheduler:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                # When the server is idle,
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

@@ -548,7 +531,7 @@ class Scheduler:
                 )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                # When the server is idle,
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

@@ -609,7 +592,7 @@ class Scheduler:
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None or self.running_batch
+                self.chunked_req is not None or not self.running_batch.is_empty()
             ):
                 self.return_health_check_ct += 1
                 continue
@@ -789,6 +772,30 @@ class Scheduler:
         )
         req.tokenizer = self.tokenizer

+        # Handle multimodal inputs
+        if recv_req.image_inputs is not None:
+            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            # Expand a single image token into multiple dummy tokens for receiving image embeddings
+            req.origin_input_ids = self.pad_input_ids_func(
+                req.origin_input_ids, image_inputs
+            )
+            req.extend_image_inputs(image_inputs)
+
+            if len(req.origin_input_ids) >= self.max_req_input_len:
+                error_msg = (
+                    "Multimodal prompt is too long after expanding multimodal tokens. "
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
+                )
+                logger.error(error_msg)
+                req.origin_input_ids = [0]
+                req.image_inputs = None
+                req.sampling_params.max_new_tokens = 0
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+                self.waiting_queue.append(req)
+                return
+
         # Validate prompts length
         error_msg = validate_input_length(
             req,
@@ -809,6 +816,11 @@ class Scheduler:
         can_run_list: List[Req],
         running_bs: int,
     ):
+        gap_latency = time.time() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.time()
+        self.last_input_throughput = self.num_prefill_tokens / gap_latency
+        self.num_prefill_tokens = 0
+
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -844,7 +856,7 @@ class Scheduler:
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(self.running_batch.reqs)
+        num_running_reqs = len(self.running_batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -908,8 +920,10 @@ class Scheduler:
         )
         if memory_leak:
             msg = (
-                "KV cache pool leak detected!"
+                "KV cache pool leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
@@ -932,10 +946,10 @@ class Scheduler:
         ):
             # During idle time, also collect metrics every 30 seconds.
             num_used = self.max_total_num_tokens - (
-                self.
+                self.token_to_kv_pool_allocator.available_size()
                 + self.tree_cache.evictable_size()
             )
-            num_running_reqs = len(self.running_batch.reqs)
+            num_running_reqs = len(self.running_batch.reqs)
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
             self.stats.token_usage = num_used / self.max_total_num_tokens
@@ -953,14 +967,20 @@ class Scheduler:
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
                 # chunked request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.batch_is_full = False
+                self.running_batch.batch_is_full = False

+            # Filter batch
+            last_bs = self.last_batch.batch_size()
             self.last_batch.filter_batch()
+            if self.last_batch.batch_size() < last_bs:
+                self.running_batch.batch_is_full = False
+
+            # Merge the new batch into the running batch
             if not self.last_batch.is_empty():
-                if self.running_batch
+                if self.running_batch.is_empty():
                     self.running_batch = self.last_batch
                 else:
-                    #
+                    # Merge running_batch with prefill batch
                     self.running_batch.merge_batch(self.last_batch)

         new_batch = self.get_new_batch_prefill()
@@ -969,11 +989,11 @@ class Scheduler:
             ret = new_batch
         else:
             # Run decode
-            if self.running_batch
-                ret = None
-            else:
+            if not self.running_batch.is_empty():
                 self.running_batch = self.update_running_batch(self.running_batch)
-                ret = self.running_batch
+                ret = self.running_batch if not self.running_batch.is_empty() else None
+            else:
+                ret = None

         # Handle DP attention
         if self.server_args.enable_dp_attention:
@@ -988,15 +1008,20 @@ class Scheduler:

         # Handle the cases where prefill is not allowed
         if (
-            self.batch_is_full or len(self.waiting_queue) == 0
+            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
         ) and self.chunked_req is None:
             return None

-        running_bs = len(self.running_batch.reqs)
+        running_bs = len(self.running_batch.reqs)
         if running_bs >= self.max_running_requests:
-            self.batch_is_full = True
+            self.running_batch.batch_is_full = True
             return None

+        if self.enable_hierarchical_cache:
+            # check for completion of hierarchical cache activities to release memory
+            self.tree_cache.writing_check()
+            self.tree_cache.loading_check()
+
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)

@@ -1011,17 +1036,13 @@ class Scheduler:
             running_bs if self.is_mixed_chunk else 0,
         )

-
-        if is_chunked:
+        if self.chunked_req is not None:
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)

         if self.lora_paths:
-            lora_set = (
-
-                if self.running_batch is not None
-                else set([])
-            )
+            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -1033,51 +1054,33 @@ class Scheduler:
                 )
                 > self.max_loras_per_batch
             ):
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                 break

             if running_bs + len(adder.can_run_list) >= self.max_running_requests:
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                 break

-            req.init_next_round_input(
+            req.init_next_round_input(
+                None if prefix_computed else self.tree_cache,
+                self.enable_hierarchical_cache,
+            )

-
-
-
-                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                        req.last_node,
-                        req.prefix_indices,
-                        adder.rem_total_tokens,
-                    )
-                    if req.last_node.loading:
-                        # to prevent frequent cache invalidation
-                        if req.rid in self.staging_reqs:
-                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                        self.tree_cache.inc_lock_ref(req.last_node)
-                        self.staging_reqs[req.rid] = req.last_node
-                        continue
-                elif req.last_node.loading:
-                    if not self.tree_cache.loading_complete(req.last_node):
-                        continue
-
-                if req.rid in self.staging_reqs:
-                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                    del self.staging_reqs[req.rid]
-
-            res = adder.add_one_req(req, self.chunked_req)
+            res = adder.add_one_req(
+                req, self.chunked_req, self.enable_hierarchical_cache
+            )
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
                         # Set batch_is_full after making sure there are requests that can be served
-                        self.batch_is_full = len(
+                        self.running_batch.batch_is_full = len(
+                            adder.can_run_list
+                        ) > 0 or (
                             self.running_batch is not None
                             and not self.running_batch.is_empty()
                         )
                     else:
-                        self.batch_is_full = True
-                break
-                if self.prefill_only_one_req:
+                        self.running_batch.batch_is_full = True
                 break

         # Update waiting queue
@@ -1088,6 +1091,9 @@ class Scheduler:
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]

+        if self.enable_hierarchical_cache:
+            self.tree_cache.read_to_load_cache()
+
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req
@@ -1115,7 +1121,7 @@ class Scheduler:
         # Mixed-style chunked prefill
         if (
             self.is_mixed_chunk
-            and self.running_batch
+            and not self.running_batch.is_empty()
             and not (new_batch.return_logprob or self.running_batch.return_logprob)
         ):
             # TODO (lianmin): support return_logprob + mixed chunked prefill
@@ -1124,7 +1130,9 @@ class Scheduler:
             self.running_batch.prepare_for_decode()
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
-            self.running_batch =
+            self.running_batch = ScheduleBatch(
+                reqs=[], batch_is_full=self.running_batch.batch_is_full
+            )
         else:
             new_batch.decoding_reqs = None

@@ -1136,8 +1144,8 @@ class Scheduler:

         batch.filter_batch()
         if batch.is_empty():
-
-            return
+            batch.batch_is_full = False
+            return batch

         # Check if decode out of memory
         if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
@@ -1161,7 +1169,7 @@ class Scheduler:
             )

         if batch.batch_size() < initial_bs:
-
+            batch.batch_is_full = False

         # Update batch tensors
         batch.prepare_for_decode()
@@ -1180,6 +1188,7 @@ class Scheduler:
         ):
             self.stop_profile()

+        # Run forward
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
@@ -1200,6 +1209,7 @@ class Scheduler:
                 self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids
+
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
             # we can use the correct values in output processing.
@@ -1233,10 +1243,7 @@ class Scheduler:
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
     ):
         if batch.forward_mode.is_decode():
-            assert isinstance(result, GenerationBatchResult)
             self.process_batch_result_decode(batch, result)
-            if batch.is_empty():
-                self.running_batch = None
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_idle():
@@ -1258,571 +1265,6 @@ class Scheduler:
             self.return_health_check_ct -= 1
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

-    def process_batch_result_prefill(
-        self,
-        batch: ScheduleBatch,
-        result: Union[GenerationBatchResult, EmbeddingBatchResult],
-    ):
-        skip_stream_req = None
-
-        if self.is_generation:
-            (
-                logits_output,
-                next_token_ids,
-                extend_input_len_per_req,
-                extend_logprob_start_len_per_req,
-                bid,
-            ) = (
-                result.logits_output,
-                result.next_token_ids,
-                result.extend_input_len_per_req,
-                result.extend_logprob_start_len_per_req,
-                result.bid,
-            )
-
-            if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
-            else:
-                # Move next_token_ids and logprobs to cpu
-                next_token_ids = next_token_ids.tolist()
-                if batch.return_logprob:
-                    if logits_output.next_token_logprobs is not None:
-                        logits_output.next_token_logprobs = (
-                            logits_output.next_token_logprobs.tolist()
-                        )
-                    if logits_output.input_token_logprobs is not None:
-                        logits_output.input_token_logprobs = tuple(
-                            logits_output.input_token_logprobs.tolist()
-                        )
-
-            hidden_state_offset = 0
-
-            # Check finish conditions
-            logprob_pt = 0
-            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-                if req.is_retracted:
-                    continue
-
-                if self.is_mixed_chunk and self.enable_overlap and req.finished():
-                    # Free the one delayed token for the mixed decode batch
-                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
-                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
-                    continue
-
-                if req.is_chunked <= 0:
-                    # req output_ids are set here
-                    req.output_ids.append(next_token_id)
-                    req.check_finished()
-
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
-                        # This updates radix so others can match
-                        self.tree_cache.cache_unfinished_req(req)
-
-                    if req.return_logprob:
-                        assert extend_logprob_start_len_per_req is not None
-                        assert extend_input_len_per_req is not None
-                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
-                        extend_input_len = extend_input_len_per_req[i]
-                        num_input_logprobs = extend_input_len - extend_logprob_start_len
-                        self.add_logprob_return_values(
-                            i,
-                            req,
-                            logprob_pt,
-                            next_token_ids,
-                            num_input_logprobs,
-                            logits_output,
-                        )
-                        logprob_pt += num_input_logprobs
-
-                    if (
-                        req.return_hidden_states
-                        and logits_output.hidden_states is not None
-                    ):
-                        req.hidden_states.append(
-                            logits_output.hidden_states[
-                                hidden_state_offset : (
-                                    hidden_state_offset := hidden_state_offset
-                                    + len(req.origin_input_ids)
-                                )
-                            ]
-                            .cpu()
-                            .clone()
-                        )
-
-                    if req.grammar is not None:
-                        req.grammar.accept_token(next_token_id)
-                        req.grammar.finished = req.finished()
-                else:
-                    # being chunked reqs' prefill is not finished
-                    req.is_chunked -= 1
-                    # There is only at most one request being currently chunked.
-                    # Because this request does not finish prefill,
-                    # we don't want to stream the request currently being chunked.
-                    skip_stream_req = req
-
-                    # Incrementally update input logprobs.
-                    if req.return_logprob:
-                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
-                        extend_input_len = extend_input_len_per_req[i]
-                        if extend_logprob_start_len < extend_input_len:
-                            # Update input logprobs.
-                            num_input_logprobs = (
-                                extend_input_len - extend_logprob_start_len
-                            )
-                            self.add_input_logprob_return_values(
-                                i,
-                                req,
-                                logits_output,
-                                logprob_pt,
-                                num_input_logprobs,
-                                last_prefill_chunk=False,
-                            )
-                            logprob_pt += num_input_logprobs
-
-            if batch.next_batch_sampling_info:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
-
-        else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
-
-            # Check finish conditions
-            for i, req in enumerate(batch.reqs):
-                if req.is_retracted:
-                    continue
-
-                req.embedding = embeddings[i]
-                if req.is_chunked <= 0:
-                    # Dummy output token for embedding models
-                    req.output_ids.append(0)
-                    req.check_finished()
-
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    else:
-                        self.tree_cache.cache_unfinished_req(req)
-                else:
-                    # being chunked reqs' prefill is not finished
-                    req.is_chunked -= 1
-
-        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
-
-    def process_batch_result_decode(
-        self,
-        batch: ScheduleBatch,
-        result: GenerationBatchResult,
-    ):
-        logits_output, next_token_ids, bid = (
-            result.logits_output,
-            result.next_token_ids,
-            result.bid,
-        )
-        self.num_generated_tokens += len(batch.reqs)
-
-        if self.enable_overlap:
-            assert batch.spec_algorithm.is_none()
-            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
-            next_token_logprobs = logits_output.next_token_logprobs
-        elif batch.spec_algorithm.is_none():
-            # spec decoding handles output logprobs inside verify process.
-            next_token_ids = next_token_ids.tolist()
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs.tolist()
-
-        self.token_to_kv_pool_allocator.free_group_begin()
-
-        # Check finish condition
-        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
-        # We should ignore using next_token_ids for spec decoding cases.
-        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            if req.is_retracted:
-                continue
-
-            if self.enable_overlap and req.finished():
-                # Free the one delayed token
-                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
-                continue
-
-            if batch.spec_algorithm.is_none():
-                # speculative worker will solve the output_ids in speculative decoding
-                req.output_ids.append(next_token_id)
-
-            req.check_finished()
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
-
-            if req.return_logprob and batch.spec_algorithm.is_none():
-                # speculative worker handles logprob in speculative decoding
-                req.output_token_logprobs_val.append(next_token_logprobs[i])
-                req.output_token_logprobs_idx.append(next_token_id)
-                if req.top_logprobs_num > 0:
-                    req.output_top_logprobs_val.append(
-                        logits_output.next_token_top_logprobs_val[i]
-                    )
-                    req.output_top_logprobs_idx.append(
-                        logits_output.next_token_top_logprobs_idx[i]
-                    )
-                if req.token_ids_logprob is not None:
-                    req.output_token_ids_logprobs_val.append(
-                        logits_output.next_token_token_ids_logprobs_val[i]
-                    )
-                    req.output_token_ids_logprobs_idx.append(
-                        logits_output.next_token_token_ids_logprobs_idx[i]
-                    )
-
-            if req.return_hidden_states and logits_output.hidden_states is not None:
-                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
-
-            if req.grammar is not None and batch.spec_algorithm.is_none():
-                req.grammar.accept_token(next_token_id)
-                req.grammar.finished = req.finished()
-
-        if batch.next_batch_sampling_info:
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-        self.stream_output(batch.reqs, batch.return_logprob)
-
-        self.token_to_kv_pool_allocator.free_group_end()
-
-        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
-        if (
-            self.attn_tp_rank == 0
-            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
-        ):
-            self.log_decode_stats()
-
-    def add_input_logprob_return_values(
-        self,
-        i: int,
-        req: Req,
-        output: LogitsProcessorOutput,
-        logprob_pt: int,
-        num_input_logprobs: int,
-        last_prefill_chunk: bool,  # If True, it means prefill is finished.
-    ):
-        """Incrementally add input logprobs to `req`.
-
-        Args:
-            i: The request index in a batch.
-            req: The request. Input logprobs inside req are modified as a
-                consequence of the API
-            fill_ids: The prefill ids processed.
-            output: Logit processor output that's used to compute input logprobs
-            last_prefill_chunk: True if it is the last prefill (when chunked).
-                Some of input logprob operation should only happen at the last
-                prefill (e.g., computing input token logprobs).
-        """
-        assert output.input_token_logprobs is not None
-        if req.input_token_logprobs is None:
-            req.input_token_logprobs = []
-        if req.temp_input_top_logprobs_val is None:
-            req.temp_input_top_logprobs_val = []
-        if req.temp_input_top_logprobs_idx is None:
-            req.temp_input_top_logprobs_idx = []
-        if req.temp_input_token_ids_logprobs_val is None:
-            req.temp_input_token_ids_logprobs_val = []
-        if req.temp_input_token_ids_logprobs_idx is None:
-            req.temp_input_token_ids_logprobs_idx = []
-
-        if req.input_token_logprobs_val is not None:
-            # The input logprob has been already computed. It only happens
-            # upon retract.
-            if req.top_logprobs_num > 0:
-                assert req.input_token_logprobs_val is not None
-            return
-
-        # Important for the performance.
-        assert isinstance(output.input_token_logprobs, tuple)
-        input_token_logprobs: Tuple[int] = output.input_token_logprobs
-        input_token_logprobs = input_token_logprobs[
-            logprob_pt : logprob_pt + num_input_logprobs
-        ]
-        req.input_token_logprobs.extend(input_token_logprobs)
-
-        if req.top_logprobs_num > 0:
-            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
-            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
-
-        if req.token_ids_logprob is not None:
-            req.temp_input_token_ids_logprobs_val.append(
-                output.input_token_ids_logprobs_val[i]
-            )
-            req.temp_input_token_ids_logprobs_idx.append(
-                output.input_token_ids_logprobs_idx[i]
-            )
-
-        if last_prefill_chunk:
-            input_token_logprobs = req.input_token_logprobs
-            req.input_token_logprobs = None
-            assert req.input_token_logprobs_val is None
-            assert req.input_token_logprobs_idx is None
-            assert req.input_top_logprobs_val is None
-            assert req.input_top_logprobs_idx is None
-
-            # Compute input_token_logprobs_val
-            # Always pad the first one with None.
-            req.input_token_logprobs_val = [None]
-            req.input_token_logprobs_val.extend(input_token_logprobs)
-            # The last input logprob is for sampling, so just pop it out.
-            req.input_token_logprobs_val.pop()
-
-            # Compute input_token_logprobs_idx
-            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
-            # Clip the padded hash values from image tokens.
-            # Otherwise, it will lead to detokenization errors.
-            input_token_logprobs_idx = [
-                x if x < self.model_config.vocab_size - 1 else 0
-                for x in input_token_logprobs_idx
-            ]
-            req.input_token_logprobs_idx = input_token_logprobs_idx
-
-            if req.top_logprobs_num > 0:
-                req.input_top_logprobs_val = [None]
-                req.input_top_logprobs_idx = [None]
-                assert len(req.temp_input_token_ids_logprobs_val) == len(
-                    req.temp_input_token_ids_logprobs_idx
-                )
-                for val, idx in zip(
-                    req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
-                ):
-                    req.input_top_logprobs_val.extend(val)
-                    req.input_top_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_top_logprobs_val.pop()
-                req.input_top_logprobs_idx.pop()
-                req.temp_input_top_logprobs_idx = None
-                req.temp_input_top_logprobs_val = None
-
-            if req.token_ids_logprob is not None:
-                req.input_token_ids_logprobs_val = [None]
-                req.input_token_ids_logprobs_idx = [None]
-
-                for val, idx in zip(
-                    req.temp_input_token_ids_logprobs_val,
-                    req.temp_input_token_ids_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_token_ids_logprobs_val.extend(val)
-                    req.input_token_ids_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_token_ids_logprobs_val.pop()
-                req.input_token_ids_logprobs_idx.pop()
-                req.temp_input_token_ids_logprobs_idx = None
-                req.temp_input_token_ids_logprobs_val = None
-
-            if req.return_logprob:
-                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
-                assert len(req.input_token_logprobs_val) == relevant_tokens_len
-                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
-                if req.top_logprobs_num > 0:
-                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
-                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
-                if req.token_ids_logprob is not None:
-                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
-                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
-
-    def add_logprob_return_values(
-        self,
-        i: int,
-        req: Req,
-        pt: int,
-        next_token_ids: List[int],
-        num_input_logprobs: int,
-        output: LogitsProcessorOutput,
-    ):
-        """Attach logprobs to the return values."""
-        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
-        req.output_token_logprobs_idx.append(next_token_ids[i])
-
-        self.add_input_logprob_return_values(
-            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
-        )
-
-        if req.top_logprobs_num > 0:
-            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
-            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
-
-        if req.token_ids_logprob is not None:
-            req.output_token_ids_logprobs_val.append(
-                output.next_token_token_ids_logprobs_val[i]
-            )
-            req.output_token_ids_logprobs_idx.append(
-                output.next_token_token_ids_logprobs_idx[i]
-            )
-
-        return num_input_logprobs
-
-    def stream_output(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
-    ):
-        """Stream the output to detokenizer."""
-        rids = []
-        finished_reasons: List[BaseFinishReason] = []
-
-        if self.is_generation:
-            decoded_texts = []
-            decode_ids_list = []
-            read_offsets = []
-            output_ids = []
-
-            skip_special_tokens = []
-            spaces_between_special_tokens = []
-            no_stop_trim = []
-            prompt_tokens = []
-            completion_tokens = []
-            cached_tokens = []
-            spec_verify_ct = []
-            output_hidden_states = None
-
-            if return_logprob:
-                input_token_logprobs_val = []
-                input_token_logprobs_idx = []
-                output_token_logprobs_val = []
-                output_token_logprobs_idx = []
-                input_top_logprobs_val = []
-                input_top_logprobs_idx = []
-                output_top_logprobs_val = []
-                output_top_logprobs_idx = []
-                input_token_ids_logprobs_val = []
-                input_token_ids_logprobs_idx = []
-                output_token_ids_logprobs_val = []
-                output_token_ids_logprobs_idx = []
-            else:
-                input_token_logprobs_val = input_token_logprobs_idx = (
-                    output_token_logprobs_val
-                ) = output_token_logprobs_idx = input_top_logprobs_val = (
-                    input_top_logprobs_idx
-                ) = output_top_logprobs_val = output_top_logprobs_idx = (
-                    input_token_ids_logprobs_val
-                ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
-                    output_token_ids_logprobs_idx
-                ) = None
-
-            for req in reqs:
-                if req is skip_req:
-                    continue
-
-                # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
-                if self.model_config.is_multimodal_gen and req.to_abort:
-                    continue
-
-                if (
-                    req.finished()
-                    # If stream, follow the given stream_interval
-                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
-                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
-                    # always increase one-by-one.
-                    or (
-                        not req.stream
-                        and len(req.output_ids) % 50 == 0
-                        and not self.model_config.is_multimodal_gen
-                    )
-                ):
-                    rids.append(req.rid)
-                    finished_reasons.append(
-                        req.finished_reason.to_json() if req.finished_reason else None
-                    )
-                    decoded_texts.append(req.decoded_text)
-                    decode_ids, read_offset = req.init_incremental_detokenize()
-                    decode_ids_list.append(decode_ids)
-                    read_offsets.append(read_offset)
-                    if self.skip_tokenizer_init:
-                        output_ids.append(req.output_ids)
-                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
-                    spaces_between_special_tokens.append(
-                        req.sampling_params.spaces_between_special_tokens
-                    )
-                    no_stop_trim.append(req.sampling_params.no_stop_trim)
-
-                    prompt_tokens.append(len(req.origin_input_ids))
-                    completion_tokens.append(len(req.output_ids))
-                    cached_tokens.append(req.cached_tokens)
-
-                    if not self.spec_algorithm.is_none():
-                        spec_verify_ct.append(req.spec_verify_ct)
-
-                    if return_logprob:
-                        input_token_logprobs_val.append(req.input_token_logprobs_val)
-                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
-                        output_token_logprobs_val.append(req.output_token_logprobs_val)
-                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
-                        input_top_logprobs_val.append(req.input_top_logprobs_val)
-                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
-                        output_top_logprobs_val.append(req.output_top_logprobs_val)
-                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                        input_token_ids_logprobs_val.append(
-                            req.input_token_ids_logprobs_val
-                        )
-                        input_token_ids_logprobs_idx.append(
-                            req.input_token_ids_logprobs_idx
-                        )
-                        output_token_ids_logprobs_val.append(
-                            req.output_token_ids_logprobs_val
-                        )
-                        output_token_ids_logprobs_idx.append(
-                            req.output_token_ids_logprobs_idx
-                        )
-
-                    if req.return_hidden_states:
-                        if output_hidden_states is None:
-                            output_hidden_states = []
-                        output_hidden_states.append(req.hidden_states)
-
-            # Send to detokenizer
-            if rids:
-                if self.model_config.is_multimodal_gen:
-                    raise NotImplementedError()
-                self.send_to_detokenizer.send_pyobj(
-                    BatchTokenIDOut(
-                        rids,
-                        finished_reasons,
-                        decoded_texts,
-                        decode_ids_list,
-                        read_offsets,
-                        output_ids,
-                        skip_special_tokens,
-                        spaces_between_special_tokens,
-                        no_stop_trim,
-                        prompt_tokens,
-                        completion_tokens,
-                        cached_tokens,
-                        spec_verify_ct,
-                        input_token_logprobs_val,
-                        input_token_logprobs_idx,
-                        output_token_logprobs_val,
-                        output_token_logprobs_idx,
-                        input_top_logprobs_val,
-                        input_top_logprobs_idx,
-                        output_top_logprobs_val,
-                        output_top_logprobs_idx,
-                        input_token_ids_logprobs_val,
-                        input_token_ids_logprobs_idx,
-                        output_token_ids_logprobs_val,
-                        output_token_ids_logprobs_idx,
-                        output_hidden_states,
-                    )
-                )
-        else:  # embedding or reward model
-            embeddings = []
-            prompt_tokens = []
-            for req in reqs:
-                if req.finished():
-                    rids.append(req.rid)
-                    finished_reasons.append(req.finished_reason.to_json())
-                    embeddings.append(req.embedding)
-                    prompt_tokens.append(len(req.origin_input_ids))
-            self.send_to_detokenizer.send_pyobj(
-                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
|
1824
|
-
)
|
1825
|
-
|
1826
1268
|
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
|
1827
1269
|
# Check if other DP workers have running batches
|
1828
1270
|
if local_batch is None:
|
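Note: the `stream_output` method removed above gathers per-request results into parallel lists (index i of every list refers to the same request) and ships them to the detokenizer as a single `BatchTokenIDOut` message instead of one message per request; in 0.4.4 this logic is presumably relocated elsewhere in the package rather than dropped. A minimal sketch of that struct-of-lists batching pattern, using hypothetical `BatchOut`/`collect_stream_batch` names rather than the sglang API:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class BatchOut:
    # Index i of every list refers to the same request.
    rids: List[str] = field(default_factory=list)
    finish_reasons: List[Optional[str]] = field(default_factory=list)
    decoded_texts: List[str] = field(default_factory=list)

def collect_stream_batch(reqs) -> Optional[BatchOut]:
    """Gather requests that are finished or due for streaming into one message."""
    out = BatchOut()
    for req in reqs:
        if req.finished or req.ready_to_stream:
            out.rids.append(req.rid)
            out.finish_reasons.append(req.finish_reason)
            out.decoded_texts.append(req.decoded_text)
    return out if out.rids else None  # caller sends a single pyobj over the IPC socket

Batching this way keeps the scheduler-to-detokenizer channel to one serialization per step regardless of batch size.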
@@ -1906,18 +1348,46 @@ class Scheduler:
         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
+        self.watchdog_last_forward_ct = 0
+        self.watchdog_last_time = time.time()
+
+        while True:
+            current = time.time()
+            if self.cur_batch is not None:
+                if self.watchdog_last_forward_ct == self.forward_ct:
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
+                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+                        break
+                else:
+                    self.watchdog_last_forward_ct = self.forward_ct
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)
+
+        # Print batch size and memory pool info to check whether there are de-sync issues.
+        logger.error(
+            f"{self.cur_batch.batch_size()=}, "
+            f"{self.cur_batch.reqs=}, "
+            f"{self.token_to_kv_pool_allocator.available_size()=}, "
+            f"{self.tree_cache.evictable_size()=}, "
+        )
+        # Wait for some time so that the parent process can print the error.
+        pyspy_dump_schedulers()
+        print(file=sys.stderr, flush=True)
+        print(file=sys.stdout, flush=True)
+        time.sleep(5)
+        self.parent_process.send_signal(signal.SIGQUIT)
+
     def flush_cache_wrapped(self, recv_req: FlushCacheReq):
         self.flush_cache()

     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if len(self.waiting_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
+        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
             if self.grammar_backend:
                 self.grammar_backend.reset()
             self.req_to_token_pool.clear()
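The `watchdog_thread` added in this hunk is a heartbeat check: it remembers the last observed `forward_ct` and, if a batch is in flight but the counter has not advanced within `watchdog_timeout` seconds, it logs batch and memory-pool state and signals the parent process. A stripped-down sketch of the same pattern, with hypothetical names (`get_step`, `on_hang`) rather than the sglang API:

import threading
import time

def start_watchdog(get_step, timeout_s: float, on_hang):
    """Run on_hang() if get_step() fails to advance for timeout_s seconds."""
    def _run():
        last_step, last_time = get_step(), time.time()
        while True:
            step = get_step()
            if step != last_step:
                last_step, last_time = step, time.time()  # progress observed, reset timer
            elif time.time() - last_time > timeout_s:
                on_hang()  # e.g. dump diagnostics, then signal the parent process
                return
            time.sleep(timeout_s / 2)

    t = threading.Thread(target=_run, daemon=True)
    t.start()
    return t

The thread in the diff additionally skips the timeout check while no batch is scheduled (`self.cur_batch is None`) and dumps py-spy traces before sending SIGQUIT to the parent.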
@@ -1940,7 +1410,7 @@ class Scheduler:
             logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
-                f"#running-req: {
+                f"#running-req: {len(self.running_batch.reqs)}"
             )
             if_success = False
         return if_success
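Both of the preceding hunks follow from the same refactor: `self.running_batch` is no longer treated as possibly `None`, so the None-or-empty checks collapse into a single `is_empty()` call and the warning's f-string no longer needs a conditional. A small illustration of the simplification, assuming `is_empty()` has the obvious semantics (the method body below is a guess, not copied from sglang):

class ScheduleBatch:
    def __init__(self, reqs=None):
        self.reqs = reqs or []

    def is_empty(self) -> bool:
        # Assumed semantics: an always-present batch that currently holds no requests.
        return len(self.reqs) == 0

# Old call sites: if batch is None or len(batch.reqs) == 0: ...
# New call sites: if batch.is_empty(): ...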
@@ -1990,24 +1460,27 @@ class Scheduler:

     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
-        to_del =
+        to_del = []
         for i, req in enumerate(self.waiting_queue):
-            if req.rid
-                to_del
+            if req.rid.startswith(recv_req.rid):
+                to_del.append(i)
                 break

-
-
+        # Sort in reverse order to avoid index issues when deleting
+        for i in sorted(to_del, reverse=True):
+            req = self.waiting_queue.pop(i)
             logger.debug(f"Abort queued request. {req.rid=}")
             return

         # Delete requests in the running batch
-
-
-
-
-
-
+        for req in self.running_batch.reqs:
+            if req.rid.startswith(recv_req.rid) and not req.finished():
+                logger.debug(f"Abort running request. {req.rid=}")
+                req.to_abort = True
+                return
+
+    def _pause_engine(self) -> Tuple[List[Req], int]:
+        raise NotImplementedError()

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         """In-place update of the weights from disk."""
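`abort_request` now matches request IDs by prefix (`req.rid.startswith(recv_req.rid)`) rather than exact equality, and the new comment explains why deletion walks the collected indices in reverse: popping lower indices first would shift every later index. A self-contained illustration of that deletion pattern (the hunk itself still breaks after the first queued match):

queue = ["req-a-0", "req-b-0", "req-a-1", "req-c-0"]

# Collect the indices of all entries whose id starts with the abort prefix.
to_del = [i for i, rid in enumerate(queue) if rid.startswith("req-a")]

# Pop from the highest index down so earlier pops do not invalidate later indices.
for i in sorted(to_del, reverse=True):
    queue.pop(i)

assert queue == ["req-b-0", "req-c-0"]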
@@ -2211,9 +1684,16 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
+
+    # Generate the prefix
+    if dp_rank is None:
+        prefix = f" TP{tp_rank}"
+    else:
+        prefix = f" DP{dp_rank} TP{tp_rank}"
+
     # Config the process
     # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
-    setproctitle.setproctitle(f"sglang::
+    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()

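The prefix computation is hoisted to the top of `run_scheduler_process` so the OS-visible process title and the log prefix are built from the same string; the hunk that follows removes the now-redundant copy from the logger-configuration block. A short sketch of the resulting naming, with illustrative rank values:

def scheduler_title(tp_rank: int, dp_rank=None) -> str:
    prefix = f" TP{tp_rank}" if dp_rank is None else f" DP{dp_rank} TP{tp_rank}"
    return f"sglang::scheduler{prefix.replace(' ', '_')}"

assert scheduler_title(0) == "sglang::scheduler_TP0"
assert scheduler_title(1, dp_rank=2) == "sglang::scheduler_DP2_TP1"

# The diff then passes the same `prefix` (leading space intact) to configure_logger().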
@@ -2222,10 +1702,6 @@ def run_scheduler_process(
         dp_rank = int(os.environ["SGLANG_DP_RANK"])

     # Configure the logger
-    if dp_rank is None:
-        prefix = f" TP{tp_rank}"
-    else:
-        prefix = f" DP{dp_rank} TP{tp_rank}"
     configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()
