sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +3 -2
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/decode.py +43 -0
- sglang/srt/disaggregation/mini_lb.py +69 -8
- sglang/srt/disaggregation/mooncake/conn.py +1 -1
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +100 -16
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +781 -150
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +19 -4
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/rotary_embedding.py +6 -6
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/io_struct.py +14 -3
- sglang/srt/managers/schedule_batch.py +13 -0
- sglang/srt/managers/scheduler.py +16 -6
- sglang/srt/managers/tokenizer_manager.py +115 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +31 -13
- sglang/srt/model_executor/cuda_graph_runner.py +13 -8
- sglang/srt/model_executor/model_runner.py +19 -4
- sglang/srt/models/deepseek_v2.py +9 -6
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +52 -40
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/utils.py +46 -5
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -96,8 +96,8 @@ class GenerateReqInput:
     return_hidden_states: bool = False
 
     # For disaggregated inference
-    bootstrap_host: Optional[str] = None
-    bootstrap_room: Optional[int] = None
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None
 
     def normalize_batch_and_arguments(self):
         """
@@ -397,6 +397,12 @@ class GenerateReqInput:
                 else None
             ),
             return_hidden_states=self.return_hidden_states,
+            bootstrap_host=(
+                self.bootstrap_host[i] if self.bootstrap_host is not None else None
+            ),
+            bootstrap_room=(
+                self.bootstrap_room[i] if self.bootstrap_room is not None else None
+            ),
         )
 
 
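The two hunks above widen `bootstrap_host`/`bootstrap_room` to accept a per-request list as well as a single value, and index into them when a batched request is split into sub-requests. A minimal standalone sketch of that scalar-or-list handling (the dataclass and helper below are illustrative stand-ins, not sglang code):

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class BatchedReq:
    # Mirrors the Optional[Union[List[...], ...]] typing from the diff.
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    def item(self, i: int):
        # Per-sub-request split, analogous to normalize_batch_and_arguments.
        host = self.bootstrap_host[i] if self.bootstrap_host is not None else None
        room = self.bootstrap_room[i] if self.bootstrap_room is not None else None
        return host, room


batched = BatchedReq(bootstrap_host=["10.0.0.1", "10.0.0.2"], bootstrap_room=[7, 8])
print(batched.item(1))  # ('10.0.0.2', 8)
```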
@@ -665,10 +671,15 @@ class BatchEmbeddingOut:
 
 
 @dataclass
-class
+class FlushCacheReqInput:
     pass
 
 
+@dataclass
+class FlushCacheReqOutput:
+    success: bool
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
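The hunk above renames the flush-cache request to `FlushCacheReqInput` and adds a `FlushCacheReqOutput` carrying a success flag. A toy sketch of the request/response round trip; the handler is only a stand-in for the scheduler-side `flush_cache_wrapped` shown later in this diff:

```python
from dataclasses import dataclass


@dataclass
class FlushCacheReqInput:
    pass


@dataclass
class FlushCacheReqOutput:
    success: bool


def handle_flush_cache(_req: FlushCacheReqInput) -> FlushCacheReqOutput:
    # Stand-in for Scheduler.flush_cache_wrapped: flush, then report the outcome.
    flushed = True  # pretend the memory pool and cache were cleared
    return FlushCacheReqOutput(success=flushed)


print(handle_flush_cache(FlushCacheReqInput()))  # FlushCacheReqOutput(success=True)
```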
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -539,6 +539,11 @@ class Req:
         # The first output_id transferred from prefill instance.
         self.transferred_output_id: Optional[int] = None
 
+        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
+        # This is because kv is not ready in `process_prefill_chunk`.
+        # We use `tmp_end_idx` to store the end index of the kv cache to send.
+        self.tmp_end_idx: int = -1
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -571,6 +576,14 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
+        elif enable_hierarchical_cache:
+            # in case last_node is evicted during scheduling, we need to update the prefix_indices
+            while self.last_node.evicted:
+                self.prefix_indices = self.prefix_indices[
+                    : -len(self.last_node.host_value)
+                ]
+                self.last_node = self.last_node.parent
+
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
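The new hierarchical-cache branch trims the cached prefix whenever the matched tree node was evicted after scheduling: the prefix is shortened by that node's host segment and the pointer walks up to its parent. A self-contained sketch of the same loop on a toy node type (not the real TreeNode):

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyTreeNode:
    host_value: List[int] = field(default_factory=list)
    evicted: bool = False
    parent: Optional["ToyTreeNode"] = None


root = ToyTreeNode(evicted=False)
mid = ToyTreeNode(host_value=[3, 4], evicted=True, parent=root)
leaf = ToyTreeNode(host_value=[5, 6, 7], evicted=True, parent=mid)

prefix_indices = [0, 1, 2, 3, 4, 5, 6, 7]
last_node = leaf

# Same shape as the added `elif enable_hierarchical_cache:` branch.
while last_node.evicted:
    prefix_indices = prefix_indices[: -len(last_node.host_value)]
    last_node = last_node.parent

print(prefix_indices)  # [0, 1, 2] -- only the still-resident prefix remains
```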
sglang/srt/managers/scheduler.py
CHANGED
@@ -60,7 +60,8 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
@@ -402,7 +403,7 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
-                (
+                (FlushCacheReqInput, self.flush_cache_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
@@ -488,6 +489,8 @@ class Scheduler(
                 tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
+                hicache_size=server_args.hicache_size,
+                hicache_write_policy=server_args.hicache_write_policy,
             )
         else:
             self.tree_cache = RadixCache(
@@ -1596,8 +1599,9 @@ class Scheduler(
             time.sleep(5)
             self.parent_process.send_signal(signal.SIGQUIT)
 
-    def flush_cache_wrapped(self, recv_req:
-        self.flush_cache()
+    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
+        success = self.flush_cache()
+        return FlushCacheReqOutput(success=success)
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
@@ -2010,9 +2014,15 @@ def run_scheduler_process(
             else:
                 scheduler.event_loop_normal()
         elif disaggregation_mode == DisaggregationMode.PREFILL:
-            scheduler.
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap_disagg_prefill()
+            else:
+                scheduler.event_loop_normal_disagg_prefill()
         elif disaggregation_mode == DisaggregationMode.DECODE:
-            scheduler.
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap_disagg_decode()
+            else:
+                scheduler.event_loop_normal_disagg_decode()
 
     except Exception:
         traceback = get_exception_traceback()
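The prefill and decode disaggregation roles now also branch on the overlap schedule when picking their event loop. A hedged sketch of the same selection logic with stand-in names (the enum values and the loop-name strings below are illustrative, not the sglang API):

```python
from enum import Enum


class DisaggregationMode(Enum):
    NULL = "null"
    PREFILL = "prefill"
    DECODE = "decode"


def pick_event_loop(mode: DisaggregationMode, enable_overlap: bool) -> str:
    # Mirrors the branching added in run_scheduler_process.
    if mode == DisaggregationMode.PREFILL:
        return (
            "event_loop_overlap_disagg_prefill"
            if enable_overlap
            else "event_loop_normal_disagg_prefill"
        )
    if mode == DisaggregationMode.DECODE:
        return (
            "event_loop_overlap_disagg_decode"
            if enable_overlap
            else "event_loop_normal_disagg_decode"
        )
    return "event_loop_overlap" if enable_overlap else "event_loop_normal"


print(pick_event_loop(DisaggregationMode.PREFILL, True))
```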
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -66,7 +66,8 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GenerateReqInput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -264,6 +265,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.flush_cache_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -314,6 +318,10 @@ class TokenizerManager:
                 ResumeMemoryOccupationReqOutput,
                 self.resume_memory_occupation_communicator.handle_recv,
             ),
+            (
+                FlushCacheReqOutput,
+                self.flush_cache_communicator.handle_recv,
+            ),
             (
                 ProfileReqOutput,
                 self.start_profile_communicator.handle_recv,
@@ -415,38 +423,60 @@ class TokenizerManager:
         )
         if image_inputs and "input_ids" in image_inputs:
             input_ids = image_inputs["input_ids"]
-
-
-
-
-
-
-
-
+
+        self._validate_token_len(obj, input_ids)
+        return self._create_tokenized_object(
+            obj, input_text, input_ids, input_embeds, image_inputs
+        )
+
+    def _validate_token_len(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
+    ) -> None:
+        """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
 
         input_token_num = len(input_ids) if input_ids is not None else 0
+        # Check if input alone exceeds context length
         if input_token_num >= self.context_len:
             raise ValueError(
                 f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )
 
+        # Check total tokens (input + max_new_tokens)
+        max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
-
-            and
-            >= self.context_len
+            max_new_tokens is not None
+            and (max_new_tokens + input_token_num) >= self.context_len
         ):
-
+            total_tokens = max_new_tokens + input_token_num
+            error_msg = (
                 f"Requested token count exceeds the model's maximum context length "
-                f"of {self.context_len} tokens. You requested a total of "
-                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                 f"tokens: {input_token_num} tokens from the input messages and "
-                f"{
-                f"
-
+                f"{max_new_tokens} tokens for the completion. Please reduce the number "
+                f"of tokens in the input messages or the completion to fit within the limit."
+            )
+            raise ValueError(error_msg)
+
+    def _create_tokenized_object(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        input_text: str,
+        input_ids: List[int],
+        input_embeds: Optional[Union[List[float], None]] = None,
+        image_inputs: Optional[Dict] = None,
+    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
+        """Create a tokenized request object from common parameters."""
+
+        if self.is_generation:
+            return_logprob = obj.return_logprob
+            logprob_start_len = obj.logprob_start_len
+            top_logprobs_num = obj.top_logprobs_num
+            token_ids_logprob = obj.token_ids_logprob
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
             )
 
-        # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
         sampling_params.verify()
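The validation logic factored into `_validate_token_len` rejects a request when the prompt alone, or prompt plus `max_new_tokens`, reaches the context length. A worked, self-contained sketch of the same arithmetic with a toy context length:

```python
from typing import List, Optional


def validate_token_len(
    input_ids: List[int], max_new_tokens: Optional[int], context_len: int
) -> None:
    input_token_num = len(input_ids) if input_ids is not None else 0
    if input_token_num >= context_len:
        raise ValueError(
            f"The input ({input_token_num} tokens) is longer than the "
            f"model's context length ({context_len} tokens)."
        )
    if max_new_tokens is not None and (max_new_tokens + input_token_num) >= context_len:
        raise ValueError(
            f"Requested {max_new_tokens + input_token_num} tokens in total, "
            f"which exceeds the context length of {context_len}."
        )


validate_token_len(list(range(100)), max_new_tokens=20, context_len=128)  # ok: 120 < 128
try:
    validate_token_len(list(range(100)), max_new_tokens=30, context_len=128)
except ValueError as e:
    print(e)  # 130 tokens in total exceeds 128
```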
@@ -483,6 +513,50 @@ class TokenizerManager:
 
         return tokenized_obj
 
+    async def _batch_tokenize_and_process(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
+        """Handle batch tokenization for text inputs only."""
+        logger.debug(f"Starting batch tokenization for {batch_size} text requests")
+
+        # Collect requests and texts
+        requests = [obj[i] for i in range(batch_size)]
+        texts = [req.text for req in requests]
+
+        # Batch tokenize all texts
+        encoded = self.tokenizer(texts)
+        input_ids_list = encoded["input_ids"]
+
+        # Process all requests
+        tokenized_objs = []
+        for i, req in enumerate(requests):
+            self._validate_token_len(obj[i], input_ids_list[i])
+            tokenized_objs.append(
+                self._create_tokenized_object(
+                    req, req.text, input_ids_list[i], None, None
+                )
+            )
+        logger.debug(f"Completed batch processing for {batch_size} requests")
+        return tokenized_objs
+
+    def _validate_batch_tokenization_constraints(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> None:
+        """Validate constraints for batch tokenization processing."""
+        for i in range(batch_size):
+            if self.is_generation and obj[i].image_data:
+                raise ValueError(
+                    "For image input processing do not set `enable_tokenizer_batch_encode`."
+                )
+            if obj[i].input_ids is not None:
+                raise ValueError(
+                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
+                )
+            if obj[i].input_embeds is not None:
+                raise ValueError(
+                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
+                )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
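The new `_batch_tokenize_and_process` path calls the tokenizer once on all texts instead of once per request. A standalone sketch of that idea with Hugging Face `AutoTokenizer` (assumed dependency; requires the tokenizer to be available locally or downloadable, and the model name is only an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example model only

texts = ["the quick brown fox", "jumps over the lazy dog"]

# One tokenizer call for the whole batch, as in _batch_tokenize_and_process...
encoded = tokenizer(texts)
input_ids_list = encoded["input_ids"]

# ...versus one call per request, as in the sequential fallback path.
per_request = [tokenizer(t)["input_ids"] for t in texts]

assert input_ids_list == per_request
print([len(ids) for ids in input_ids_list])
```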
@@ -560,14 +634,27 @@ class TokenizerManager:
 
         generators = []
         rids = []
+
         if getattr(obj, "parallel_sample_num", 1) == 1:
-
-
-
-
-            self.
-
-
+            if self.server_args.enable_tokenizer_batch_encode:
+                # Validate batch tokenization constraints
+                self._validate_batch_tokenization_constraints(batch_size, obj)
+
+                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
+
+                for i, tokenized_obj in enumerate(tokenized_objs):
+                    tmp_obj = obj[i]
+                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
+            else:
+                # Sequential tokenization and processing
+                for i in range(batch_size):
+                    tmp_obj = obj[i]
+                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
             if batch_size > 128:
@@ -628,9 +715,8 @@ class TokenizerManager:
         except StopAsyncIteration:
             pass
 
-    def flush_cache(self):
-
-        self.send_to_scheduler.send_pyobj(req)
+    async def flush_cache(self) -> FlushCacheReqOutput:
+        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
 
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
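`flush_cache` on the tokenizer manager becomes an awaited round trip through a communicator instead of a fire-and-forget send. A minimal asyncio sketch of that request/response pattern; the queue-based communicator below is a stand-in for sglang's `_Communicator`, not its real implementation:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class FlushCacheReqOutput:
    success: bool


class ToyCommunicator:
    """Stand-in: send a request, await the matching response."""

    def __init__(self):
        self._resp: asyncio.Queue = asyncio.Queue()

    async def __call__(self, req) -> list:
        # Pretend the scheduler handled `req` and replied with one output per rank.
        await self._resp.put(FlushCacheReqOutput(success=True))
        return [await self._resp.get()]


async def main():
    flush_cache_communicator = ToyCommunicator()
    result = (await flush_cache_communicator(object()))[0]
    print(result.success)  # True once the scheduler acknowledges the flush


asyncio.run(main())
```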
sglang/srt/managers/tp_worker.py
CHANGED
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
             self.token_to_kv_pool_host,
             page_size,
             load_cache_event=self.load_cache_event,
+            write_policy=hicache_write_policy,
         )
 
         # record the nodes with ongoing write through
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
         # todo: dynamically adjust the threshold
-        self.write_through_threshold =
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
             height += 1
         return height
 
-    def write_backup(self, node: TreeNode):
+    def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             node_id=node.id,
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
         if host_indices is not None:
             node.host_value = host_indices
             self.ongoing_write_through[node.id] = node
-
+            if not write_back:
+                # no need to lock nodes if write back
+                self.inc_lock_ref(node)
         else:
             return 0
 
         return len(host_indices)
 
     def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy
+        if node.backuped or self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.
+        if node.hit_count >= self.write_through_threshold:
             self.write_backup(node)
             node.hit_count = 0
 
-    def writing_check(self):
+    def writing_check(self, write_back=False):
+        if write_back:
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                ack_id = self.cache_controller.ack_write_queue.get()
+                del self.ongoing_write_through[ack_id]
+            return
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
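A sketch of the hit-count policy wired up above: with a `write_through` policy every not-yet-backed-up node is written to host on its first hit (threshold 1), while the selective policy waits for repeated hits (threshold 3 in this diff). The node class below is a toy stand-in, not the real TreeNode:

```python
class ToyNode:
    def __init__(self):
        self.hit_count = 0
        self.backuped = False


def write_through_threshold(policy: str) -> int:
    # Same expression as in HiRadixCache.__init__.
    return 1 if policy == "write_through" else 3


def inc_hit_count(node: ToyNode, policy: str) -> None:
    if node.backuped or policy == "write_back":
        return
    node.hit_count += 1
    if node.hit_count >= write_through_threshold(policy):
        node.backuped = True  # stand-in for write_backup(node)
        node.hit_count = 0


n = ToyNode()
for _ in range(3):
    inc_hit_count(n, "write_through_selective")
print(n.backuped)  # True only after the third hit
```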
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
         heapq.heapify(leaves)
 
         num_evicted = 0
-
+        write_back_nodes = []
         while num_evicted < num_tokens and len(leaves):
             x = heapq.heappop(leaves)
 
             if x.lock_ref > 0:
                 continue
 
-            if x.
+            if not x.backuped:
                 if self.cache_controller.write_policy == "write_back":
-
-
-
-                    num_evicted += self._evict_write_through_selective(x)
+                    # write to host if the node is not backuped
+                    num_evicted += self.write_backup(x, write_back=True)
+                    write_back_nodes.append(x)
                 else:
-
-                    self.cache_controller.write_policy != "write_through"
-                    ), "write_through should be inclusive"
-                    raise NotImplementedError
+                    num_evicted += self._evict_regular(x)
             else:
-                num_evicted += self.
+                num_evicted += self._evict_backuped(x)
 
             for child in x.parent.children.values():
-                if child in
+                if child in write_back_nodes:
                     continue
                 if not child.evicted:
                     break
@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):
             heapq.heappush(leaves, x.parent)
 
         if self.cache_controller.write_policy == "write_back":
-
-
-
-
-            for node in pending_nodes:
-                assert node.host_value is not None
-                self._evict_write_through(node)
+            self.writing_check(write_back=True)
+            for node in write_back_nodes:
+                assert node.backuped
+                self._evict_backuped(node)
 
-    def
+    def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
         num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
         assert num_evicted > 0
@@ -190,7 +196,7 @@ class HiRadixCache(RadixCache):
         node.value = None
         return num_evicted
 
-    def
+    def _evict_regular(self, node: TreeNode):
         # evict a node not initiated write to host
         self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
+                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child
@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
         else:
             new_node.value = child.value[:split_len]
             child.value = child.value[split_len:]
-        if child.
+        if child.backuped:
             new_node.host_value = child.host_value[:split_len]
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
         node.children[child_key] = new_node
         self.evictable_size_ += len(value)
 
-        if self.cache_controller.write_policy
-            self.
+        if self.cache_controller.write_policy != "write_back":
+            self.inc_hit_count(new_node)
         return total_prefix_length
 
     def _collect_leaves_device(self):
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -446,13 +446,28 @@ class MLATokenToKVPool(KVCache):
         ]
 
         self.layer_transfer_counter = None
+        self.page_size = page_size
+
+        kv_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+        )
+
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "kv_buffer")
+        kv_size_bytes = 0
+        for kv_cache in self.kv_buffer:
+            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+        return kv_size_bytes
 
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
         kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
         kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [
+        kv_item_lens = [
+            self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def get_key_buffer(self, layer_id: int):
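`get_kv_size_bytes` simply sums `prod(shape) * itemsize` over the per-layer KV buffers. A standalone numpy sketch of the same accounting; the buffer shapes and dtype below are made up for illustration:

```python
import numpy as np

# Pretend per-layer KV buffers: (num_tokens, kv_lora_rank + rope_dim) in an 8-bit dtype.
kv_buffer = [np.zeros((1024, 576), dtype=np.uint8) for _ in range(4)]

kv_size_bytes = 0
for kv_cache in kv_buffer:
    kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize

GB = 1024**3
print(f"KV size: {kv_size_bytes / GB:.6f} GB")  # 4 layers * 1024 * 576 bytes
```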
@@ -621,26 +636,27 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         pin_memory: bool,
         device: str,
         page_size: int,
     ):
-        assert (
-            host_to_device_ratio >= 1
-        ), "The host memory should be larger than the device memory with the current protocol"
-        # todo, other ways of configuring the size
-
         self.device_pool = device_pool
-        self.
+        self.dtype = device_pool.store_dtype
         self.pin_memory = pin_memory
         self.device = device
         self.page_size = page_size
-
-
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
-
-
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"
 
         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
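The host pool size can now be given directly in GB (`host_size`) instead of only as a ratio of the device pool. A small arithmetic sketch of the sizing logic above, including the page alignment; all numbers are illustrative:

```python
def host_pool_tokens(
    host_size_gb: int,
    host_to_device_ratio: float,
    device_pool_tokens: int,
    size_per_token_bytes: int,
    page_size: int,
) -> int:
    # Same precedence as HostKVCache.__init__: an explicit GB budget wins over the ratio.
    if host_size_gb > 0:
        size = int(host_size_gb * 1e9 // size_per_token_bytes)
    else:
        size = int(device_pool_tokens * host_to_device_ratio)
    # Align the host memory pool size to the page size.
    return size - (size % page_size)


# e.g. 50 GB of host memory, ~4.6 KB per token, page size 64
print(host_pool_tokens(50, 3.0, 200_000, 4608, 64))  # explicit-size path
print(host_pool_tokens(0, 3.0, 200_000, 4608, 64))   # ratio fallback path
```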
@@ -792,12 +808,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )
 
     def get_size_per_token(self):
@@ -866,12 +883,13 @@ class MLATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
        )
 
     def get_size_per_token(self):
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -35,7 +35,11 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    get_device_memory_capacity,
+    is_hip,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -129,7 +133,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
         )
 
-
+    gpu_mem = get_device_memory_capacity()
+    if gpu_mem is not None and gpu_mem > 81920:
         capture_bs += list(range(160, 257, 8))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -140,12 +145,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         ]
 
     capture_bs = list(sorted(set(capture_bs)))
-
-
-
-
-
-    ]
+
+    assert len(capture_bs) > 0 and capture_bs[0] > 0
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    if server_args.cuda_graph_max_bs:
+        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
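The batch-size list for CUDA graph capture is now clamped twice: to the request pool size and, when set, to `cuda_graph_max_bs`. A pure-Python sketch of that filtering; the input list mirrors the defaults in the hunk above, and the two limits are example values only:

```python
capture_bs = list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))

req_to_token_pool_size = 64  # example limit from the request pool
cuda_graph_max_bs = 24       # example value of the cuda_graph_max_bs server arg

capture_bs = sorted(set(capture_bs))
assert len(capture_bs) > 0 and capture_bs[0] > 0
capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
if cuda_graph_max_bs:
    capture_bs = [bs for bs in capture_bs if bs <= cuda_graph_max_bs]

print(capture_bs)  # [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 22, 24]
```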
@@ -186,6 +190,7 @@ class CudaGraphRunner:
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1