sglang 0.4.3.post3__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_serving.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +94 -48
- sglang/srt/layers/attention/triton_backend.py +4 -2
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/scheduler.py +144 -127
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +34 -29
- sglang/srt/metrics/collector.py +8 -0
- sglang/srt/model_executor/cuda_graph_runner.py +1 -7
- sglang/srt/model_executor/model_runner.py +97 -78
- sglang/srt/server_args.py +3 -12
- sglang/srt/speculative/build_eagle_tree.py +6 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
- sglang/srt/speculative/eagle_utils.py +2 -1
- sglang/srt/speculative/eagle_worker.py +67 -32
- sglang/version.py +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +2 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +21 -21
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -159,17 +159,6 @@ class Scheduler:
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
-        self.decode_mem_cache_buf_multiplier = (
-            (
-                self.server_args.speculative_num_draft_tokens
-                + (
-                    self.server_args.speculative_eagle_topk
-                    * self.server_args.speculative_num_draft_tokens
-                )
-            )
-            if not self.spec_algorithm.is_none()
-            else 1
-        )

         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -208,42 +197,12 @@ class Scheduler:
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

         # Init tokenizer
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
-        self.is_generation = self.model_config.is_generation
-
-        if server_args.skip_tokenizer_init:
-            self.tokenizer = self.processor = None
-        else:
-            if self.model_config.is_multimodal:
-                self.processor = get_processor(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                )
-                self.tokenizer = self.processor.tokenizer
-            else:
-                self.tokenizer = get_tokenizer(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                )
+        self.init_tokenizer()

         # Check whether overlap can be enabled
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
-
         if self.model_config.is_multimodal:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for multimodal models.")
@@ -274,10 +233,8 @@ class Scheduler:
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
-            self.prefill_only_one_req = True
         else:
             self.draft_worker = None
-            self.prefill_only_one_req = False

         # Get token and memory info from the model worker
         (
@@ -309,32 +266,7 @@ class Scheduler:
         )

         # Init memory pool and cache
-        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
-            self.tp_worker.get_memory_pool()
-        )
-
-        if (
-            server_args.chunked_prefill_size is not None
-            and server_args.disable_radix_cache
-        ):
-            self.tree_cache = ChunkCache(
-                req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-            )
-        else:
-            if self.enable_hierarchical_cache:
-                self.tree_cache = HiRadixCache(
-                    req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                )
-            else:
-                self.tree_cache = RadixCache(
-                    req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                    disable=server_args.disable_radix_cache,
-                )
-
-        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+        self.init_memory_pool_and_cache()

         # Init running status
         self.waiting_queue: List[Req] = []
@@ -348,25 +280,13 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
         self.last_decode_stats_tic = time.time()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU

-        #
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
-        self.last_gen_throughput: float = 0.0
-        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-
-        # Session info
+        # Init session info
         self.sessions: Dict[str, Session] = {}

         # Init chunked prefill
@@ -387,11 +307,11 @@ class Scheduler:
         else:
             self.grammar_backend = None

-        # Init new token estimation
+        # Init schedule policy and new token estimation
+        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-
         self.init_new_token_ratio = min(
             global_config.default_init_new_token_ratio
             * server_args.schedule_conservativeness,
@@ -430,14 +350,7 @@ class Scheduler:
         self.profiler_target_forward_ct: Optional[int] = None

         # Init metrics stats
-        self.stats = SchedulerStats()
-        if self.enable_metrics:
-            self.metrics_collector = SchedulerMetricsCollector(
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
-            )
+        self.init_metrics()

         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
@@ -460,39 +373,104 @@ class Scheduler:
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
+                (SetInternalStateReq, self.set_internal_state),
             ]
         )

-    def watchdog_thread(self):
-
-        self.watchdog_last_forward_ct = 0
-        self.watchdog_last_time = time.time()
+    def init_tokenizer(self):
+        server_args = self.server_args

-
-
-
-
-
-
-
-
-
-
-
+        self.model_config = ModelConfig(
+            server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+        )
+        self.is_generation = self.model_config.is_generation

-
-
-
-
-
-
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
+        else:
+            if self.model_config.is_multimodal:
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                )
+
+    def init_memory_pool_and_cache(self):
+        server_args = self.server_args
+
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            self.tp_worker.get_memory_pool()
+        )
+
+        if (
+            server_args.chunked_prefill_size is not None
+            and server_args.disable_radix_cache
+        ):
+            self.tree_cache = ChunkCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+        else:
+            if self.enable_hierarchical_cache:
+                self.tree_cache = HiRadixCache(
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                )
+            else:
+                self.tree_cache = RadixCache(
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                    disable=server_args.disable_radix_cache,
+                )
+
+        self.decode_mem_cache_buf_multiplier = (
+            1
+            if self.spec_algorithm.is_none()
+            else (
+                server_args.speculative_num_draft_tokens
+                + (
+                    server_args.speculative_eagle_topk
+                    * server_args.speculative_num_steps
+                )
+            )
         )
-
-
-
-
-
-        self.
+
+    def init_metrics(self):
+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+        self.last_gen_throughput: float = 0.0
+        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
+        self.cum_spec_accept_length = 0
+        self.cum_spec_accept_count = 0
+        self.stats = SchedulerStats()
+        if self.enable_metrics:
+            engine_type = "unified"
+            self.metrics_collector = SchedulerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    "engine_type": engine_type,
+                },
+            )

     @torch.no_grad()
     def event_loop_normal(self):
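Note that while the speculative-decoding buffer multiplier moved into init_memory_pool_and_cache(), its formula also changed: the product term now uses speculative_num_steps instead of speculative_num_draft_tokens. A quick illustration of the difference, using made-up values (not sglang defaults):

# Hypothetical values for illustration only (not sglang defaults):
# speculative_num_draft_tokens = 8, speculative_eagle_topk = 4, speculative_num_steps = 5
old_multiplier = 8 + 4 * 8  # 0.4.3.post3: draft_tokens + topk * draft_tokens = 40
new_multiplier = 8 + 4 * 5  # 0.4.3.post4: draft_tokens + topk * num_steps    = 28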
@@ -932,7 +910,7 @@ class Scheduler:
         ):
             # During idle time, also collect metrics every 30 seconds.
             num_used = self.max_total_num_tokens - (
-                self.
+                self.token_to_kv_pool_allocator.available_size()
                 + self.tree_cache.evictable_size()
             )
             num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -1077,8 +1055,6 @@ class Scheduler:
                 else:
                     self.batch_is_full = True
                 break
-            if self.prefill_only_one_req:
-                break

         # Update waiting queue
         can_run_list: List[Req] = adder.can_run_list
@@ -1180,6 +1156,7 @@ class Scheduler:
         ):
             self.stop_profile()

+        # Run forward
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
@@ -1200,6 +1177,7 @@ class Scheduler:
                 self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids
+
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
             # we can use the correct values in output processing.
|
|
1233
1211
|
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
1234
1212
|
):
|
1235
1213
|
if batch.forward_mode.is_decode():
|
1236
|
-
assert isinstance(result, GenerationBatchResult)
|
1237
1214
|
self.process_batch_result_decode(batch, result)
|
1238
1215
|
if batch.is_empty():
|
1239
1216
|
self.running_batch = None
|
@@ -1485,6 +1462,7 @@ class Scheduler:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
+
         self.stream_output(batch.reqs, batch.return_logprob)

         self.token_to_kv_pool_allocator.free_group_end()
@@ -1588,7 +1566,9 @@ class Scheduler:
                     req.temp_input_token_ids_logprobs_idx
                 )
                 for val, idx in zip(
-                    req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
+                    req.temp_input_top_logprobs_val,
+                    req.temp_input_top_logprobs_idx,
+                    strict=True,
                 ):
                     req.input_top_logprobs_val.extend(val)
                     req.input_top_logprobs_idx.extend(idx)
@@ -1813,14 +1793,18 @@ class Scheduler:
         else:  # embedding or reward model
             embeddings = []
             prompt_tokens = []
+            cached_tokens = []
             for req in reqs:
                 if req.finished():
                     rids.append(req.rid)
                     finished_reasons.append(req.finished_reason.to_json())
                     embeddings.append(req.embedding)
                     prompt_tokens.append(len(req.origin_input_ids))
+                    cached_tokens.append(req.cached_tokens)
             self.send_to_detokenizer.send_pyobj(
-                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
+                BatchEmbeddingOut(
+                    rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
+                )
             )

     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
@@ -1906,6 +1890,37 @@ class Scheduler:
             self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
             self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
+        self.watchdog_last_forward_ct = 0
+        self.watchdog_last_time = time.time()
+
+        while True:
+            current = time.time()
+            if self.cur_batch is not None:
+                if self.watchdog_last_forward_ct == self.forward_ct:
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
+                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+                        break
+                else:
+                    self.watchdog_last_forward_ct = self.forward_ct
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)
+
+        # Print batch size and memory pool info to check whether there are de-sync issues.
+        logger.error(
+            f"{self.cur_batch.batch_size()=}, "
+            f"{self.cur_batch.reqs=}, "
+            f"{self.token_to_kv_pool_allocator.available_size()=}, "
+            f"{self.tree_cache.evictable_size()=}, "
+        )
+        # Wait for some time so that the parent process can print the error.
+        pyspy_dump_schedulers()
+        print(file=sys.stderr, flush=True)
+        print(file=sys.stdout, flush=True)
+        time.sleep(5)
+        self.parent_process.send_signal(signal.SIGQUIT)
+
     def flush_cache_wrapped(self, recv_req: FlushCacheReq):
         self.flush_cache()

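The relocated watchdog_thread is readable straight from the hunk above; the pattern it implements is a daemon thread that only fires when the forward counter stops advancing while a batch is in flight. A minimal standalone sketch of that pattern (hypothetical names, not sglang code):

import threading
import time


class ToyWorker:
    """Illustrative only: a worker whose watchdog reports when forward progress stalls."""

    def __init__(self, watchdog_timeout: float = 2.0):
        self.forward_ct = 0       # incremented by the main loop after each forward step
        self.cur_batch = None     # set while a batch is being processed
        self.watchdog_timeout = watchdog_timeout
        threading.Thread(target=self._watchdog, daemon=True).start()

    def _watchdog(self):
        last_ct, last_time = 0, time.time()
        while True:
            now = time.time()
            if self.cur_batch is not None:
                if last_ct == self.forward_ct:
                    # No forward progress since the last check.
                    if now > last_time + self.watchdog_timeout:
                        print("watchdog timeout: no forward progress", flush=True)
                        return
                else:
                    last_ct, last_time = self.forward_ct, now
            time.sleep(self.watchdog_timeout / 2)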
@@ -1917,7 +1932,6 @@ class Scheduler:
         self.cur_batch = None
         self.last_batch = None
         self.tree_cache.reset()
-        self.tree_cache_metrics = {"total": 0, "hit": 0}
         if self.grammar_backend:
             self.grammar_backend.reset()
         self.req_to_token_pool.clear()
@@ -2009,6 +2023,9 @@ class Scheduler:
                 req.to_abort = True
                 break

+    def _pause_engine(self) -> Tuple[List[Req], int]:
+        raise NotImplementedError()
+
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         """In-place update of the weights from disk."""
         success, message = self.tp_worker.update_weights_from_disk(recv_req)
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -20,9 +20,8 @@ Memory pool.

 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-TokenToKVPoolAllocator
-KVCache actually holds the physical kv cache.
-by TokenToKVPoolAllocator
+TokenToKVPoolAllocator manages the indices to kv cache data.
+KVCache actually holds the physical kv cache.
 """

 import abc
@@ -92,14 +91,40 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+
 class TokenToKVPoolAllocator:
-    """
+    """An allocator managing the indices to kv cache data."""

     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
         device: str,
+        kvcache: KVCache,
     ):
         self.size = size
         self.dtype = dtype
@@ -110,9 +135,14 @@ class TokenToKVPoolAllocator:
         self.free_group = []
         self.clear()

+        self._kvcache = kvcache
+
     def available_size(self):
         return len(self.free_slots)

+    def get_kvcache(self):
+        return self._kvcache
+
     def alloc(self, need_size: int):
         if need_size > len(self.free_slots):
             return None
@@ -147,31 +177,6 @@ class TokenToKVPoolAllocator:
         self.free_group = []


-class KVCache(abc.ABC):
-
-    @abc.abstractmethod
-    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def set_kv_buffer(
-        self,
-        layer: RadixAttention,
-        loc: torch.Tensor,
-        cache_k: torch.Tensor,
-        cache_v: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError()
-
-
 class MHATokenToKVPool(KVCache):

     def __init__(
sglang/srt/metrics/collector.py
CHANGED
@@ -121,6 +121,12 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )

+        self.cached_tokens_total = Counter(
+            name="sglang:cached_tokens_total",
+            documentation="Number of cached prompt tokens.",
+            labelnames=labels.keys(),
+        )
+
         self.num_requests_total = Counter(
             name="sglang:num_requests_total",
             documentation="Number of requests processed.",
@@ -245,10 +251,12 @@ class TokenizerMetricsCollector:
         self,
         prompt_tokens: int,
         generation_tokens: int,
+        cached_tokens: int,
         e2e_latency: float,
     ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
         self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
         if generation_tokens >= 1:
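The new sglang:cached_tokens_total counter is incremented once per finished request alongside the existing token counters, which makes prefix-cache effectiveness observable from Prometheus. A minimal, self-contained sketch using prometheus_client directly (metric names follow the collector, but the labels, documentation strings, and values here are demo assumptions, not sglang's collector object):

from prometheus_client import Counter

labels = {"model_name": "demo-model"}  # demo label set
prompt_tokens_total = Counter(
    "sglang:prompt_tokens_total",
    "Number of prompt tokens processed (demo).",
    labelnames=labels.keys(),
)
cached_tokens_total = Counter(
    "sglang:cached_tokens_total",
    "Number of cached prompt tokens (demo).",
    labelnames=labels.keys(),
)


def observe_one_finished_request(prompt_tokens: int, cached_tokens: int) -> None:
    # Mirrors the collector change above: count total and cached prompt tokens per request.
    prompt_tokens_total.labels(**labels).inc(prompt_tokens)
    cached_tokens_total.labels(**labels).inc(cached_tokens)


observe_one_finished_request(prompt_tokens=512, cached_tokens=384)
# At query time a cache-hit ratio can then be derived, e.g.
#   rate(sglang:cached_tokens_total[5m]) / rate(sglang:prompt_tokens_total[5m])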
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -396,16 +396,10 @@ class CudaGraphRunner:

         run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         global global_graph_memory_pool
         with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
             out = run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         global_graph_memory_pool = graph.pool()
         return graph, out

@@ -427,7 +421,7 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()

-    def replay(self, forward_batch: ForwardBatch):
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
         self.recapture_if_needed(forward_batch)

         raw_bs = forward_batch.batch_size