sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hiradix_cache.py

@@ -7,11 +7,12 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.memory_pool_host import (
     MHATokenToKVPoolHost,
@@ -27,7 +28,7 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
@@ -283,39 +284,44 @@ class HiRadixCache(RadixCache):
     def init_load_back(
         self,
         last_node: TreeNode,
-        prefix_indices: torch.Tensor,
+        host_hit_length: int,
         mem_quota: Optional[int] = None,
     ):
-        assert (
-            len(prefix_indices) == 0 or prefix_indices.is_cuda
-        ), "indices of device kV caches should be on GPU"
+        _ = host_hit_length # unused, but kept for compatibility
         if last_node.evicted:
             loading_values = self.load_back(last_node, mem_quota)
             if loading_values is not None:
-                prefix_indices = (
-                    loading_values
-                    if len(prefix_indices) == 0
-                    else torch.cat([prefix_indices, loading_values])
-                )
                 logger.debug(
                     f"loading back {len(loading_values)} tokens for node {last_node.id}"
                 )
+                return loading_values, last_node
 
             while last_node.evicted:
                 last_node = last_node.parent
 
-        return prefix_indices, last_node
+        return (
+            torch.empty((0,), dtype=torch.int64, device=self.device),
+            last_node,
+        )
 
-    def
+    def ready_to_load_host_cache(self):
+        producer_index = self.cache_controller.layer_done_counter.next_producer()
         self.load_cache_event.set()
+        return producer_index
 
-    def
+    def check_hicache_events(self):
+        self.writing_check()
+        self.loading_check()
+
+    def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
         if self.disable or len(key) == 0:
-
-
-
-
+            return MatchResult(
+                device_indices=empty_value,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
+                host_hit_length=0,
+            )
 
         if self.page_size != 1:
             page_aligned_len = len(key) // self.page_size * self.page_size
@@ -327,14 +333,18 @@ class HiRadixCache(RadixCache):
         else:
             value = empty_value
 
-
+        host_hit_length = 0
+        last_host_node = last_node
         while last_node.evicted:
+            host_hit_length += len(last_node.host_value)
             last_node = last_node.parent
 
-
-
-
-
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_host_node,
+            host_hit_length=host_hit_length,
+        )
 
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
@@ -372,6 +382,7 @@ class HiRadixCache(RadixCache):
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.loading = child.loading
+        new_node.hit_count = child.hit_count
 
         # split value and host value if exists
         if child.evicted:
sglang/srt/mem_cache/memory_pool.py

@@ -26,6 +26,7 @@ KVCache actually holds the physical kv cache.
 
 import abc
 import logging
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -33,8 +34,9 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
+from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
@@ -52,6 +54,7 @@ class ReqToTokenPool:
         device: str,
         enable_memory_saver: bool,
     ):
+
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -59,7 +62,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size, max_context_len), dtype=torch.int32, device=device
             )
@@ -119,6 +122,9 @@ class KVCache(abc.ABC):
             enable=enable_memory_saver
         )
 
+        # used for chunked cpu-offloading
+        self.cpu_offloading_chunk_size = 8192
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -153,83 +159,11 @@ class KVCache(abc.ABC):
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter
 
-
-class TokenToKVPoolAllocator:
-    """An allocator managing the indices to kv cache data."""
-
-    def __init__(
-        self,
-        size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-    ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = 1
-
-        self.free_slots = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
-
-        self._kvcache = kvcache
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    def debug_print(self) -> str:
-        return ""
-
-    def get_kvcache(self):
-        return self._kvcache
-
-    def alloc(self, need_size: int):
-        if need_size > len(self.free_slots):
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-        return select_index
-
-    def free(self, free_index: torch.Tensor):
-        if free_index.numel() == 0:
-            return
-
-        if self.is_not_in_free_group:
-            self.free_slots = torch.cat((self.free_slots, free_index))
-        else:
-            self.free_group.append(free_index)
-
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_slots
-
-    def restore_state(self, free_slots):
-        self.free_slots = free_slots
-
-    def clear(self):
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(
-            1, self.size + 1, dtype=torch.int64, device=self.device
-        )
-        self.is_not_in_free_group = True
-        self.free_group = []
-
     def get_cpu_copy(self, indices):
-
+        raise NotImplementedError()
 
     def load_cpu_copy(self, kv_cache_cpu, indices):
-
+        raise NotImplementedError()
 
 
 class MHATokenToKVPool(KVCache):
@@ -260,10 +194,22 @@ class MHATokenToKVPool(KVCache):
 
         self.head_num = head_num
         self.head_dim = head_dim
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self._create_buffers()
 
-        # used for chunked cpu-offloading
-        self.chunk_size = 8192
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -274,25 +220,30 @@ class MHATokenToKVPool(KVCache):
         )
 
     def _create_buffers(self):
-        with self.memory_saver_adapter.region():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.k_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
 
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
@@ -349,13 +300,17 @@ class MHATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
             kv_cache_cpu.append([])
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                     "cpu", non_blocking=True
                 )
@@ -368,12 +323,13 @@ class MHATokenToKVPool(KVCache):
 
     def load_cpu_copy(self, kv_cache_cpu, indices):
         torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu, v_cpu = (
-                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
-                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                    kv_cache_cpu[layer_id][i // chunk_size][0],
+                    kv_cache_cpu[layer_id][i // chunk_size][1],
                 )
                 assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                 k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
@@ -569,16 +525,34 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
 
-        with self.memory_saver_adapter.region():
-
-
-
-
-
-
-
-
-
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.kv_buffer = [
+                    torch.zeros(
+                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        dtype=self.store_dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
         self.layer_transfer_counter = None
 
@@ -604,6 +578,9 @@ class MLATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
@@ -677,6 +654,33 @@ class MLATokenToKVPool(KVCache):
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append(kv_cpu)
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
+                assert kv_cpu.shape[0] == len(chunk_indices)
+                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
+                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
+        torch.cuda.synchronize()
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
@@ -704,7 +708,7 @@ class DoubleSparseTokenToKVPool(KVCache):
             end_layer,
         )
 
-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
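
The memory_pool.py hunks above repeat one pattern in both MHATokenToKVPool and MLATokenToKVPool: KV buffers are allocated inside an optional custom CUDA memory pool (enabled by the SGLANG_MOONCAKE_CUSTOM_MEM_POOL environment variable for NVLink-based disaggregation), with contextlib.nullcontext() as the fallback. The following is a minimal standalone sketch of that allocation pattern only; the helper name alloc_kv_buffers is illustrative and not part of sglang.

    from contextlib import nullcontext

    import torch


    def alloc_kv_buffers(layer_num, size, page_size, head_num, head_dim,
                         dtype, device, custom_mem_pool=None):
        # Route allocations through a caller-supplied CUDA MemPool when present
        # (e.g. one backed by mooncake's NVLinkAllocator, as in the hunks above);
        # otherwise nullcontext() keeps the default caching allocator.
        pool_ctx = (
            torch.cuda.use_mem_pool(custom_mem_pool)
            if custom_mem_pool is not None
            else nullcontext()
        )
        with pool_ctx:
            # The extra page_size rows mirror the padded slot 0 reserved for
            # dummy writes from padded tokens.
            return [
                torch.zeros((size + page_size, head_num, head_dim),
                            dtype=dtype, device=device)
                for _ in range(layer_num)
            ]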
sglang/srt/mem_cache/radix_cache.py

@@ -23,7 +23,7 @@ import heapq
 import time
 from collections import defaultdict
 from functools import partial
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
@@ -31,11 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
     AllBlocksCleared,
     BlockRemoved,
     BlockStored,
-    KVCacheEvent,
 )
-from sglang.srt.
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -47,9 +46,9 @@ class TreeNode:
 
     def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
-        self.parent = None
-        self.key = None
-        self.value = None
+        self.parent: TreeNode = None
+        self.key: List[int] = None
+        self.value: Optional[torch.Tensor] = None
         self.lock_ref = 0
         self.last_access_time = time.monotonic()
 
@@ -57,7 +56,7 @@ class TreeNode:
         # indicating the node is loading KV cache from host
         self.loading = False
         # store the host indices of KV cache
-        self.host_value = None
+        self.host_value: Optional[torch.Tensor] = None
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -99,7 +98,7 @@ class RadixCache(BasePrefixCache):
     def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
-       token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+       token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
        enable_kv_cache_events: bool = False,
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
         self.protected_size_ = 0
         self._record_all_cleared_event()
 
-    def match_prefix(self, key: List[int], **kwargs) ->
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         """Find the matching prefix from the radix tree.
         Args:
             key: A list of token IDs to find a matching prefix.
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
             than the last node's value.
         """
         if self.disable or len(key) == 0:
-            return (
-                torch.empty(
+            return MatchResult(
+                device_indices=torch.empty(
                     (0,),
                     dtype=torch.int64,
                     device=self.device,
                 ),
-                self.root_node,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
             )
 
         if self.page_size != 1:
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
             value = torch.cat(value)
         else:
             value = torch.empty((0,), dtype=torch.int64, device=self.device)
-        return value, last_node
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_node,
+        )
 
     def insert(self, key: List, value=None):
         if self.disable:
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
         )
 
         # The prefix indices could be updated, reuse it
-        new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
+        new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
         self.req_to_token_pool.write(
             (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
             new_indices[len(req.prefix_indices) :],