sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -1058,12 +1058,7 @@ class TokenizerManager:
|
|
1058
1058
|
"lora_path",
|
1059
1059
|
]
|
1060
1060
|
)
|
1061
|
-
out_skip_names = set(
|
1062
|
-
[
|
1063
|
-
"text",
|
1064
|
-
"output_ids",
|
1065
|
-
]
|
1066
|
-
)
|
1061
|
+
out_skip_names = set(["text", "output_ids", "embedding"])
|
1067
1062
|
elif self.log_requests_level == 1:
|
1068
1063
|
max_length = 2048
|
1069
1064
|
elif self.log_requests_level == 2:
|
@@ -1140,13 +1135,21 @@ class TokenizerManager:
|
|
1140
1135
|
remain_num_req = len(self.rid_to_state)
|
1141
1136
|
|
1142
1137
|
if self.health_check_failed:
|
1143
|
-
# if health check failed,
|
1138
|
+
# if health check failed, exit immediately
|
1144
1139
|
logger.error(
|
1145
1140
|
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
1146
1141
|
remain_num_req,
|
1147
1142
|
)
|
1148
1143
|
break
|
1149
1144
|
|
1145
|
+
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
|
1146
|
+
# if force shutdown flag set, exit immediately
|
1147
|
+
logger.error(
|
1148
|
+
"Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
|
1149
|
+
remain_num_req,
|
1150
|
+
)
|
1151
|
+
break
|
1152
|
+
|
1150
1153
|
logger.info(
|
1151
1154
|
f"Gracefully exiting... remaining number of requests {remain_num_req}"
|
1152
1155
|
)
|
@@ -1223,7 +1226,7 @@ class TokenizerManager:
|
|
1223
1226
|
state.last_output_offset = len(state.output_ids)
|
1224
1227
|
else:
|
1225
1228
|
state.output_ids.extend(recv_obj.output_ids[i])
|
1226
|
-
output_token_ids = state.output_ids
|
1229
|
+
output_token_ids = state.output_ids.copy()
|
1227
1230
|
|
1228
1231
|
out_dict = {
|
1229
1232
|
"output_ids": output_token_ids,
|
sglang/srt/managers/tp_worker.py
CHANGED
@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
|
|
35
35
|
UpdateWeightsFromTensorReqInput,
|
36
36
|
)
|
37
37
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
38
|
-
from sglang.srt.mem_cache.
|
38
|
+
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
39
|
+
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
39
40
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
40
41
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
41
42
|
from sglang.srt.server_args import ServerArgs
|
@@ -57,7 +58,7 @@ class TpModelWorker:
|
|
57
58
|
nccl_port: int,
|
58
59
|
is_draft_worker: bool = False,
|
59
60
|
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
60
|
-
token_to_kv_pool_allocator: Optional[
|
61
|
+
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
|
61
62
|
):
|
62
63
|
# Parse args
|
63
64
|
self.tp_size = server_args.tp_size
|
@@ -147,6 +148,15 @@ class TpModelWorker:
|
|
147
148
|
# A reference make this class has the same member as TpModelWorkerClient
|
148
149
|
self.worker = self
|
149
150
|
|
151
|
+
self.hicache_layer_transfer_counter = None
|
152
|
+
|
153
|
+
def register_hicache_layer_transfer_counter(self, counter):
|
154
|
+
self.hicache_layer_transfer_counter = counter
|
155
|
+
|
156
|
+
def set_hicache_consumer(self, consumer_index):
|
157
|
+
if self.hicache_layer_transfer_counter is not None:
|
158
|
+
self.hicache_layer_transfer_counter.set_consumer(consumer_index)
|
159
|
+
|
150
160
|
def get_worker_info(self):
|
151
161
|
return (
|
152
162
|
self.max_total_num_tokens,
|
@@ -88,6 +88,15 @@ class TpModelWorkerClient:
|
|
88
88
|
if self.device == "cpu":
|
89
89
|
self.scheduler_stream.synchronize = lambda: None # No-op for CPU
|
90
90
|
|
91
|
+
self.hicache_layer_transfer_counter = None
|
92
|
+
|
93
|
+
def register_hicache_layer_transfer_counter(self, counter):
|
94
|
+
self.hicache_layer_transfer_counter = counter
|
95
|
+
|
96
|
+
def set_hicache_consumer(self, consumer_index):
|
97
|
+
if self.hicache_layer_transfer_counter is not None:
|
98
|
+
self.hicache_layer_transfer_counter.set_consumer(consumer_index)
|
99
|
+
|
91
100
|
def get_worker_info(self):
|
92
101
|
return self.worker.get_worker_info()
|
93
102
|
|
@@ -146,6 +155,8 @@ class TpModelWorkerClient:
|
|
146
155
|
input_ids = model_worker_batch.input_ids
|
147
156
|
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
148
157
|
|
158
|
+
# update the consumer index of hicache to the running batch
|
159
|
+
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
149
160
|
# Run forward
|
150
161
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
151
162
|
self.worker.forward_batch_generation(
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
"""
|
2
4
|
Copyright 2025 SGLang Team
|
3
5
|
Licensed under the Apache License, Version 2.0 (the "License");
|
@@ -17,13 +19,132 @@ limitations under the License.
|
|
17
19
|
Page-aligned memory pool.
|
18
20
|
"""
|
19
21
|
|
22
|
+
import abc
|
23
|
+
from typing import TYPE_CHECKING
|
24
|
+
|
20
25
|
import torch
|
21
26
|
import triton
|
22
27
|
import triton.language as tl
|
23
28
|
|
24
|
-
from sglang.srt.mem_cache.memory_pool import KVCache
|
25
29
|
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
26
30
|
|
31
|
+
if TYPE_CHECKING:
|
32
|
+
from sglang.srt.mem_cache.memory_pool import KVCache
|
33
|
+
|
34
|
+
|
35
|
+
class BaseTokenToKVPoolAllocator(abc.ABC):
|
36
|
+
@abc.abstractmethod
|
37
|
+
def __init__(
|
38
|
+
self,
|
39
|
+
size: int,
|
40
|
+
page_size: int,
|
41
|
+
dtype: torch.dtype,
|
42
|
+
device: str,
|
43
|
+
kvcache: KVCache,
|
44
|
+
):
|
45
|
+
self.size = size
|
46
|
+
self.page_size = page_size
|
47
|
+
self.dtype = dtype
|
48
|
+
self.device = device
|
49
|
+
self._kvcache = kvcache
|
50
|
+
|
51
|
+
self.free_pages = None
|
52
|
+
self.is_not_in_free_group = True
|
53
|
+
self.free_group = []
|
54
|
+
|
55
|
+
def debug_print(self) -> str:
|
56
|
+
return ""
|
57
|
+
|
58
|
+
def available_size(self):
|
59
|
+
return len(self.free_pages) * self.page_size
|
60
|
+
|
61
|
+
def get_kvcache(self):
|
62
|
+
return self._kvcache
|
63
|
+
|
64
|
+
def restore_state(self, free_pages):
|
65
|
+
self.free_pages = free_pages
|
66
|
+
|
67
|
+
def backup_state(self):
|
68
|
+
return self.free_pages
|
69
|
+
|
70
|
+
def free_group_begin(self):
|
71
|
+
self.is_not_in_free_group = False
|
72
|
+
self.free_group = []
|
73
|
+
|
74
|
+
def free_group_end(self):
|
75
|
+
self.is_not_in_free_group = True
|
76
|
+
if self.free_group:
|
77
|
+
self.free(torch.cat(self.free_group))
|
78
|
+
|
79
|
+
def get_cpu_copy(self, *args, **kwargs):
|
80
|
+
# FIXME: reuse the get_cpu_copy after paged allocator is implemented
|
81
|
+
raise NotImplementedError()
|
82
|
+
|
83
|
+
def load_cpu_copy(self, *args, **kwargs):
|
84
|
+
# FIXME: reuse the load_cpu_copy after paged allocator is implemented
|
85
|
+
raise NotImplementedError()
|
86
|
+
|
87
|
+
def alloc_extend(self, *args, **kwargs):
|
88
|
+
raise NotImplementedError("alloc_extend is only for paged allocator")
|
89
|
+
|
90
|
+
def alloc_decode(self, *args, **kwargs):
|
91
|
+
raise NotImplementedError("alloc_decode is only for paged allocator")
|
92
|
+
|
93
|
+
@abc.abstractmethod
|
94
|
+
def clear(self):
|
95
|
+
raise NotImplementedError()
|
96
|
+
|
97
|
+
@abc.abstractmethod
|
98
|
+
def alloc(self, need_size: int):
|
99
|
+
raise NotImplementedError()
|
100
|
+
|
101
|
+
@abc.abstractmethod
|
102
|
+
def free(self, free_index: torch.Tensor):
|
103
|
+
raise NotImplementedError()
|
104
|
+
|
105
|
+
|
106
|
+
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
107
|
+
"""An allocator managing the indices to kv cache data."""
|
108
|
+
|
109
|
+
def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
|
110
|
+
super().__init__(size, 1, dtype, device, kvcache)
|
111
|
+
self.clear()
|
112
|
+
|
113
|
+
def clear(self):
|
114
|
+
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
115
|
+
self.free_pages = torch.arange(
|
116
|
+
1, self.size + 1, dtype=torch.int64, device=self.device
|
117
|
+
)
|
118
|
+
self.is_not_in_free_group = True
|
119
|
+
self.free_group = []
|
120
|
+
|
121
|
+
def available_size(self):
|
122
|
+
# To avoid minor "len(free_pages) * 1" overhead
|
123
|
+
return len(self.free_pages)
|
124
|
+
|
125
|
+
def alloc(self, need_size: int):
|
126
|
+
if need_size > len(self.free_pages):
|
127
|
+
return None
|
128
|
+
|
129
|
+
select_index = self.free_pages[:need_size]
|
130
|
+
self.free_pages = self.free_pages[need_size:]
|
131
|
+
return select_index
|
132
|
+
|
133
|
+
def free(self, free_index: torch.Tensor):
|
134
|
+
if free_index.numel() == 0:
|
135
|
+
return
|
136
|
+
|
137
|
+
if self.is_not_in_free_group:
|
138
|
+
self.free_pages = torch.cat((self.free_pages, free_index))
|
139
|
+
else:
|
140
|
+
self.free_group.append(free_index)
|
141
|
+
|
142
|
+
def get_cpu_copy(self, indices):
|
143
|
+
return self._kvcache.get_cpu_copy(indices)
|
144
|
+
|
145
|
+
def load_cpu_copy(self, kv_cache_cpu, indices):
|
146
|
+
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
147
|
+
|
27
148
|
|
28
149
|
@triton.jit
|
29
150
|
def alloc_extend_kernel(
|
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
|
|
154
275
|
tl.store(out_indices + pid, page * page_size)
|
155
276
|
|
156
277
|
|
157
|
-
class PagedTokenToKVPoolAllocator:
|
278
|
+
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
158
279
|
"""
|
159
280
|
An allocator managing the indices to kv cache data.
|
160
281
|
|
@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
|
|
172
293
|
device: str,
|
173
294
|
kvcache: KVCache,
|
174
295
|
):
|
175
|
-
|
176
|
-
self.dtype = dtype
|
177
|
-
self.device = device
|
178
|
-
self.page_size = page_size
|
296
|
+
super().__init__(size, page_size, dtype, device, kvcache)
|
179
297
|
self.num_pages = size // page_size
|
180
|
-
|
181
|
-
self.free_pages = None
|
182
|
-
self.is_not_in_free_group = True
|
183
|
-
self.free_group = []
|
184
|
-
self.clear()
|
185
298
|
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
186
|
-
|
187
|
-
self._kvcache = kvcache
|
188
299
|
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
189
|
-
|
190
|
-
def available_size(self):
|
191
|
-
return len(self.free_pages) * self.page_size
|
192
|
-
|
193
|
-
def get_kvcache(self):
|
194
|
-
return self._kvcache
|
300
|
+
self.clear()
|
195
301
|
|
196
302
|
def alloc(self, need_size: int):
|
197
303
|
# page-aligned allocation, returning contiguous indices of pages
|
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
|
|
298
404
|
if self.debug_mode:
|
299
405
|
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
|
300
406
|
|
301
|
-
def free_group_begin(self):
|
302
|
-
self.is_not_in_free_group = False
|
303
|
-
self.free_group = []
|
304
|
-
|
305
|
-
def free_group_end(self):
|
306
|
-
self.is_not_in_free_group = True
|
307
|
-
if self.free_group:
|
308
|
-
self.free(torch.cat(self.free_group))
|
309
|
-
|
310
|
-
def backup_state(self):
|
311
|
-
return self.free_pages
|
312
|
-
|
313
|
-
def restore_state(self, free_pages):
|
314
|
-
self.free_pages = free_pages
|
315
|
-
|
316
407
|
def clear(self):
|
317
408
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
318
409
|
self.free_pages = torch.arange(
|
@@ -1,5 +1,31 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
-
from typing import Any, List, Tuple
|
2
|
+
from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
|
3
|
+
|
4
|
+
import torch
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from sglang.srt.managers.schedule_batch import Req
|
8
|
+
else:
|
9
|
+
Req = Any # Placeholder for Req type when not type checking
|
10
|
+
|
11
|
+
|
12
|
+
class MatchResult(NamedTuple):
|
13
|
+
"""Result of a prefix match operation.
|
14
|
+
|
15
|
+
Attributes:
|
16
|
+
device_indices : Indices of the KV cache on the device matched by common prefix.
|
17
|
+
last_device_node: The last TreeNode on the device that was matched.
|
18
|
+
last_host_node : The last TreeNode on the host that was matched.
|
19
|
+
Note that if HiCache is not enabled,
|
20
|
+
this **must** be the same as `last_device_node`.
|
21
|
+
host_hit_length : Length of the KV cache hit on the host, if applicable.
|
22
|
+
0 if HiCache is not enabled.
|
23
|
+
"""
|
24
|
+
|
25
|
+
device_indices: torch.Tensor
|
26
|
+
last_device_node: Any
|
27
|
+
last_host_node: Any
|
28
|
+
host_hit_length: int = 0
|
3
29
|
|
4
30
|
|
5
31
|
class BasePrefixCache(ABC):
|
@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
|
|
10
36
|
pass
|
11
37
|
|
12
38
|
@abstractmethod
|
13
|
-
def match_prefix(self,
|
39
|
+
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
|
14
40
|
pass
|
15
41
|
|
16
42
|
@abstractmethod
|
17
|
-
def
|
43
|
+
def cache_finished_req(self, req: Req, **kwargs):
|
18
44
|
pass
|
19
45
|
|
20
46
|
@abstractmethod
|
21
|
-
def
|
22
|
-
pass
|
23
|
-
|
24
|
-
@abstractmethod
|
25
|
-
def cache_unfinished_req(self, **kwargs):
|
47
|
+
def cache_unfinished_req(self, req: Req, **kwargs):
|
26
48
|
pass
|
27
49
|
|
28
50
|
@abstractmethod
|
@@ -49,5 +71,27 @@ class BasePrefixCache(ABC):
|
|
49
71
|
def pretty_print(self):
|
50
72
|
raise NotImplementedError()
|
51
73
|
|
74
|
+
def init_load_back(
|
75
|
+
self,
|
76
|
+
last_host_node: Any,
|
77
|
+
host_hit_length: int,
|
78
|
+
) -> Tuple[torch.Tensor, Any]:
|
79
|
+
"""
|
80
|
+
Preparing KV cache loading from host to device.
|
81
|
+
"""
|
82
|
+
raise NotImplementedError()
|
83
|
+
|
84
|
+
def ready_to_load_host_cache(self) -> Any:
|
85
|
+
"""
|
86
|
+
Notify the cache controller to start the KV cache loading
|
87
|
+
"""
|
88
|
+
raise NotImplementedError()
|
89
|
+
|
90
|
+
def check_hicache_events(self) -> Any:
|
91
|
+
"""
|
92
|
+
Check HiCache related activities to update radix tree and synchronize across TP workers if needed
|
93
|
+
"""
|
94
|
+
raise NotImplementedError()
|
95
|
+
|
52
96
|
def take_events(self):
|
53
97
|
return []
|
@@ -2,40 +2,38 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
4
4
|
|
5
|
-
from typing import TYPE_CHECKING, Any
|
5
|
+
from typing import TYPE_CHECKING, Any
|
6
6
|
|
7
7
|
import torch
|
8
8
|
|
9
|
-
from sglang.srt.mem_cache.
|
10
|
-
from sglang.srt.mem_cache.
|
9
|
+
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
10
|
+
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
11
|
+
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
11
12
|
|
12
13
|
if TYPE_CHECKING:
|
13
14
|
from sglang.srt.managers.schedule_batch import Req
|
14
15
|
|
15
16
|
|
16
|
-
class ChunkCacheEntry:
|
17
|
-
def __init__(self, rid: str, value: torch.Tensor):
|
18
|
-
self.rid = rid
|
19
|
-
self.value = value
|
20
|
-
|
21
|
-
|
22
17
|
class ChunkCache(BasePrefixCache):
|
23
18
|
def __init__(
|
24
19
|
self,
|
25
20
|
req_to_token_pool: ReqToTokenPool,
|
26
|
-
token_to_kv_pool_allocator:
|
21
|
+
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
27
22
|
page_size: int,
|
28
23
|
):
|
29
24
|
self.req_to_token_pool = req_to_token_pool
|
30
25
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
31
26
|
self.page_size = page_size
|
32
|
-
self.disable = True
|
33
27
|
|
34
28
|
def reset(self):
|
35
29
|
pass
|
36
30
|
|
37
|
-
def match_prefix(self, **unused_kwargs) ->
|
38
|
-
return
|
31
|
+
def match_prefix(self, **unused_kwargs) -> MatchResult:
|
32
|
+
return MatchResult(
|
33
|
+
device_indices=torch.empty((0,), dtype=torch.int64),
|
34
|
+
last_device_node=None,
|
35
|
+
last_host_node=None,
|
36
|
+
)
|
39
37
|
|
40
38
|
def cache_finished_req(self, req: Req):
|
41
39
|
kv_indices = self.req_to_token_pool.req_to_token[
|
@@ -54,9 +52,6 @@ class ChunkCache(BasePrefixCache):
|
|
54
52
|
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
55
53
|
req.prefix_indices = kv_indices
|
56
54
|
|
57
|
-
def insert(self):
|
58
|
-
raise NotImplementedError()
|
59
|
-
|
60
55
|
def evict(self, num_tokens: int):
|
61
56
|
pass
|
62
57
|
|
@@ -7,11 +7,12 @@ from typing import List, Optional
|
|
7
7
|
import torch
|
8
8
|
|
9
9
|
from sglang.srt.managers.cache_controller import HiCacheController
|
10
|
+
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
11
|
+
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
10
12
|
from sglang.srt.mem_cache.memory_pool import (
|
11
13
|
MHATokenToKVPool,
|
12
14
|
MLATokenToKVPool,
|
13
15
|
ReqToTokenPool,
|
14
|
-
TokenToKVPoolAllocator,
|
15
16
|
)
|
16
17
|
from sglang.srt.mem_cache.memory_pool_host import (
|
17
18
|
MHATokenToKVPoolHost,
|
@@ -27,7 +28,7 @@ class HiRadixCache(RadixCache):
|
|
27
28
|
def __init__(
|
28
29
|
self,
|
29
30
|
req_to_token_pool: ReqToTokenPool,
|
30
|
-
token_to_kv_pool_allocator:
|
31
|
+
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
31
32
|
tp_cache_group: torch.distributed.ProcessGroup,
|
32
33
|
page_size: int,
|
33
34
|
hicache_ratio: float,
|
@@ -283,39 +284,44 @@ class HiRadixCache(RadixCache):
|
|
283
284
|
def init_load_back(
|
284
285
|
self,
|
285
286
|
last_node: TreeNode,
|
286
|
-
|
287
|
+
host_hit_length: int,
|
287
288
|
mem_quota: Optional[int] = None,
|
288
289
|
):
|
289
|
-
|
290
|
-
len(prefix_indices) == 0 or prefix_indices.is_cuda
|
291
|
-
), "indices of device kV caches should be on GPU"
|
290
|
+
_ = host_hit_length # unused, but kept for compatibility
|
292
291
|
if last_node.evicted:
|
293
292
|
loading_values = self.load_back(last_node, mem_quota)
|
294
293
|
if loading_values is not None:
|
295
|
-
prefix_indices = (
|
296
|
-
loading_values
|
297
|
-
if len(prefix_indices) == 0
|
298
|
-
else torch.cat([prefix_indices, loading_values])
|
299
|
-
)
|
300
294
|
logger.debug(
|
301
295
|
f"loading back {len(loading_values)} tokens for node {last_node.id}"
|
302
296
|
)
|
297
|
+
return loading_values, last_node
|
303
298
|
|
304
299
|
while last_node.evicted:
|
305
300
|
last_node = last_node.parent
|
306
301
|
|
307
|
-
return
|
302
|
+
return (
|
303
|
+
torch.empty((0,), dtype=torch.int64, device=self.device),
|
304
|
+
last_node,
|
305
|
+
)
|
308
306
|
|
309
|
-
def
|
307
|
+
def ready_to_load_host_cache(self):
|
308
|
+
producer_index = self.cache_controller.layer_done_counter.next_producer()
|
310
309
|
self.load_cache_event.set()
|
310
|
+
return producer_index
|
311
311
|
|
312
|
-
def
|
312
|
+
def check_hicache_events(self):
|
313
|
+
self.writing_check()
|
314
|
+
self.loading_check()
|
315
|
+
|
316
|
+
def match_prefix(self, key: List[int], **kwargs):
|
313
317
|
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
314
318
|
if self.disable or len(key) == 0:
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
+
return MatchResult(
|
320
|
+
device_indices=empty_value,
|
321
|
+
last_device_node=self.root_node,
|
322
|
+
last_host_node=self.root_node,
|
323
|
+
host_hit_length=0,
|
324
|
+
)
|
319
325
|
|
320
326
|
if self.page_size != 1:
|
321
327
|
page_aligned_len = len(key) // self.page_size * self.page_size
|
@@ -327,14 +333,18 @@ class HiRadixCache(RadixCache):
|
|
327
333
|
else:
|
328
334
|
value = empty_value
|
329
335
|
|
330
|
-
|
336
|
+
host_hit_length = 0
|
337
|
+
last_host_node = last_node
|
331
338
|
while last_node.evicted:
|
339
|
+
host_hit_length += len(last_node.host_value)
|
332
340
|
last_node = last_node.parent
|
333
341
|
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
342
|
+
return MatchResult(
|
343
|
+
device_indices=value,
|
344
|
+
last_device_node=last_node,
|
345
|
+
last_host_node=last_host_node,
|
346
|
+
host_hit_length=host_hit_length,
|
347
|
+
)
|
338
348
|
|
339
349
|
def _match_prefix_helper(self, node: TreeNode, key: List):
|
340
350
|
node.last_access_time = time.monotonic()
|
@@ -372,6 +382,7 @@ class HiRadixCache(RadixCache):
|
|
372
382
|
new_node.lock_ref = child.lock_ref
|
373
383
|
new_node.key = child.key[:split_len]
|
374
384
|
new_node.loading = child.loading
|
385
|
+
new_node.hit_count = child.hit_count
|
375
386
|
|
376
387
|
# split value and host value if exists
|
377
388
|
if child.evicted:
|