sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +87 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +26 -7
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +374 -136
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +13 -13
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +25 -27
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +237 -204
- sglang/srt/managers/detokenizer_manager.py +48 -2
- sglang/srt/managers/io_struct.py +57 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +94 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +122 -42
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +51 -23
- sglang/srt/mem_cache/hiradix_cache.py +87 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +77 -14
- sglang/srt/mem_cache/memory_pool_host.py +4 -5
- sglang/srt/mem_cache/radix_cache.py +6 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -5
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +38 -13
- sglang/srt/models/gpt_oss.py +2 -15
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +66 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +122 -56
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +73 -5
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ import hashlib
|
|
2
2
|
import logging
|
3
3
|
import os
|
4
4
|
from abc import ABC, abstractmethod
|
5
|
+
from dataclasses import dataclass
|
5
6
|
from typing import Any, List, Optional
|
6
7
|
|
7
8
|
import torch
|
@@ -9,17 +10,6 @@ import torch
|
|
9
10
|
logger = logging.getLogger(__name__)
|
10
11
|
|
11
12
|
|
12
|
-
from sglang.srt.distributed import (
|
13
|
-
get_tensor_model_parallel_rank,
|
14
|
-
get_tensor_model_parallel_world_size,
|
15
|
-
)
|
16
|
-
from sglang.srt.layers.dp_attention import (
|
17
|
-
get_attention_tp_rank,
|
18
|
-
get_attention_tp_size,
|
19
|
-
is_dp_attention_enabled,
|
20
|
-
)
|
21
|
-
|
22
|
-
|
23
13
|
def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
|
24
14
|
hasher = hashlib.sha256()
|
25
15
|
|
@@ -32,6 +22,15 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
|
|
32
22
|
return hasher.hexdigest()
|
33
23
|
|
34
24
|
|
25
|
+
@dataclass
|
26
|
+
class HiCacheStorageConfig:
|
27
|
+
tp_rank: int
|
28
|
+
tp_size: int
|
29
|
+
is_mla_model: bool
|
30
|
+
model_name: Optional[str]
|
31
|
+
extra_config: Optional[dict] = None
|
32
|
+
|
33
|
+
|
35
34
|
class HiCacheStorage(ABC):
|
36
35
|
"""
|
37
36
|
HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
|
@@ -60,7 +59,7 @@ class HiCacheStorage(ABC):
|
|
60
59
|
keys: List[str],
|
61
60
|
target_locations: Optional[Any] = None,
|
62
61
|
target_sizes: Optional[Any] = None,
|
63
|
-
) -> List[torch.Tensor | None]:
|
62
|
+
) -> List[torch.Tensor | None] | int:
|
64
63
|
"""
|
65
64
|
Retrieve values for multiple keys.
|
66
65
|
Returns a list of tensors or None for each key.
|
@@ -96,25 +95,51 @@ class HiCacheStorage(ABC):
|
|
96
95
|
pass
|
97
96
|
|
98
97
|
@abstractmethod
|
99
|
-
def exists(self, key: str) -> bool
|
98
|
+
def exists(self, key: str) -> bool:
|
100
99
|
"""
|
101
100
|
Check if the key exists in the storage.
|
102
101
|
Returns True if the key exists, False otherwise.
|
103
102
|
"""
|
104
103
|
pass
|
105
104
|
|
105
|
+
@abstractmethod
|
106
|
+
def delete(self, key: str) -> bool:
|
107
|
+
"""
|
108
|
+
Delete the entry associated with the given key.
|
109
|
+
"""
|
110
|
+
pass
|
111
|
+
|
112
|
+
@abstractmethod
|
113
|
+
def clear(self) -> bool:
|
114
|
+
"""
|
115
|
+
Clear all entries in the storage.
|
116
|
+
"""
|
117
|
+
pass
|
118
|
+
|
119
|
+
def batch_exists(self, keys: List[str]) -> int:
|
120
|
+
"""
|
121
|
+
Check if the keys exist in the storage.
|
122
|
+
return the number of consecutive existing keys from the start.
|
123
|
+
Can be overridden by subclasses for more efficient implementation.
|
124
|
+
"""
|
125
|
+
for i in range(len(keys)):
|
126
|
+
if not self.exists(keys[i]):
|
127
|
+
return i
|
128
|
+
return len(keys)
|
129
|
+
|
106
130
|
|
107
131
|
class HiCacheFile(HiCacheStorage):
|
108
132
|
|
109
|
-
def __init__(
|
133
|
+
def __init__(
|
134
|
+
self, storage_config: HiCacheStorageConfig, file_path: str = "/tmp/hicache"
|
135
|
+
):
|
110
136
|
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
|
111
|
-
if is_dp_attention_enabled():
|
112
|
-
tp_rank = get_attention_tp_rank()
|
113
|
-
tp_size = get_attention_tp_size()
|
114
|
-
else:
|
115
|
-
tp_rank = get_tensor_model_parallel_rank()
|
116
|
-
tp_size = get_tensor_model_parallel_world_size()
|
117
137
|
|
138
|
+
tp_rank, tp_size, is_mla = (
|
139
|
+
storage_config.tp_rank,
|
140
|
+
storage_config.tp_size,
|
141
|
+
storage_config.is_mla_model,
|
142
|
+
)
|
118
143
|
self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
|
119
144
|
if not os.path.exists(self.file_path) and tp_rank == 0:
|
120
145
|
os.makedirs(self.file_path)
|
@@ -164,11 +189,12 @@ class HiCacheFile(HiCacheStorage):
|
|
164
189
|
target_location: Optional[Any] = None,
|
165
190
|
target_sizes: Optional[Any] = None,
|
166
191
|
) -> bool:
|
167
|
-
key = self._get_suffixed_key(key)
|
168
|
-
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
169
192
|
if self.exists(key):
|
170
193
|
logger.debug(f"Key {key} already exists. Skipped.")
|
171
194
|
return True
|
195
|
+
|
196
|
+
key = self._get_suffixed_key(key)
|
197
|
+
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
172
198
|
try:
|
173
199
|
value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
|
174
200
|
return True
|
@@ -202,12 +228,14 @@ class HiCacheFile(HiCacheStorage):
|
|
202
228
|
logger.warning(f"Key {key} does not exist. Cannot delete.")
|
203
229
|
return
|
204
230
|
|
205
|
-
def clear(self) ->
|
231
|
+
def clear(self) -> bool:
|
206
232
|
try:
|
207
233
|
for filename in os.listdir(self.file_path):
|
208
234
|
file_path = os.path.join(self.file_path, filename)
|
209
235
|
if os.path.isfile(file_path):
|
210
236
|
os.remove(file_path)
|
211
237
|
logger.info("Cleared all entries in HiCacheFile storage.")
|
238
|
+
return True
|
212
239
|
except Exception as e:
|
213
240
|
logger.error(f"Failed to clear HiCacheFile storage: {e}")
|
241
|
+
return False
|
@@ -39,6 +39,8 @@ class HiRadixCache(RadixCache):
|
|
39
39
|
hicache_mem_layout: str,
|
40
40
|
hicache_storage_backend: Optional[str] = None,
|
41
41
|
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
42
|
+
model_name: Optional[str] = None,
|
43
|
+
storage_backend_extra_config: Optional[str] = None,
|
42
44
|
):
|
43
45
|
|
44
46
|
if hicache_io_backend == "direct":
|
@@ -87,6 +89,8 @@ class HiRadixCache(RadixCache):
|
|
87
89
|
io_backend=hicache_io_backend,
|
88
90
|
storage_backend=hicache_storage_backend,
|
89
91
|
prefetch_threshold=self.prefetch_threshold,
|
92
|
+
model_name=model_name,
|
93
|
+
storage_backend_extra_config=storage_backend_extra_config,
|
90
94
|
)
|
91
95
|
|
92
96
|
# record the nodes with ongoing write through
|
@@ -98,10 +102,7 @@ class HiRadixCache(RadixCache):
|
|
98
102
|
self.ongoing_backup = {}
|
99
103
|
# todo: dynamically adjust the threshold
|
100
104
|
self.write_through_threshold = (
|
101
|
-
1 if hicache_write_policy == "write_through" else
|
102
|
-
)
|
103
|
-
self.write_through_threshold_storage = (
|
104
|
-
1 if hicache_write_policy == "write_through" else 3
|
105
|
+
1 if hicache_write_policy == "write_through" else 2
|
105
106
|
)
|
106
107
|
self.load_back_threshold = 10
|
107
108
|
super().__init__(
|
@@ -121,6 +122,15 @@ class HiRadixCache(RadixCache):
|
|
121
122
|
height += 1
|
122
123
|
return height
|
123
124
|
|
125
|
+
def clear_storage_backend(self):
|
126
|
+
if self.enable_storage:
|
127
|
+
self.cache_controller.storage_backend.clear()
|
128
|
+
logger.info("Hierarchical cache storage backend cleared successfully!")
|
129
|
+
return True
|
130
|
+
else:
|
131
|
+
logger.warning("Hierarchical cache storage backend is not enabled.")
|
132
|
+
return False
|
133
|
+
|
124
134
|
def write_backup(self, node: TreeNode, write_back=False):
|
125
135
|
host_indices = self.cache_controller.write(
|
126
136
|
device_indices=node.value,
|
@@ -151,8 +161,9 @@ class HiRadixCache(RadixCache):
|
|
151
161
|
self.ongoing_backup[operation_id] = node
|
152
162
|
node.protect_host()
|
153
163
|
|
154
|
-
def
|
155
|
-
|
164
|
+
def _inc_hit_count(self, node: TreeNode, chunked=False):
|
165
|
+
# skip the hit count update for chunked requests
|
166
|
+
if self.cache_controller.write_policy == "write_back" or chunked:
|
156
167
|
return
|
157
168
|
node.hit_count += 1
|
158
169
|
|
@@ -160,14 +171,6 @@ class HiRadixCache(RadixCache):
|
|
160
171
|
if node.hit_count >= self.write_through_threshold:
|
161
172
|
# write to host if the node is not backuped
|
162
173
|
self.write_backup(node)
|
163
|
-
else:
|
164
|
-
if (
|
165
|
-
self.enable_storage
|
166
|
-
and (not node.backuped_storage)
|
167
|
-
and node.hit_count >= self.write_through_threshold_storage
|
168
|
-
):
|
169
|
-
# if the node is backuped on host memory but not on storage
|
170
|
-
self.write_backup_storage(node)
|
171
174
|
|
172
175
|
def writing_check(self, write_back=False):
|
173
176
|
if write_back:
|
@@ -188,8 +191,11 @@ class HiRadixCache(RadixCache):
|
|
188
191
|
)
|
189
192
|
for _ in range(queue_size.item()):
|
190
193
|
ack_id = self.cache_controller.ack_write_queue.get()
|
191
|
-
self.
|
194
|
+
backuped_node = self.ongoing_write_through[ack_id]
|
195
|
+
self.dec_lock_ref(backuped_node)
|
192
196
|
del self.ongoing_write_through[ack_id]
|
197
|
+
if self.enable_storage:
|
198
|
+
self.write_backup_storage(backuped_node)
|
193
199
|
|
194
200
|
def loading_check(self):
|
195
201
|
while not self.cache_controller.ack_load_queue.empty():
|
@@ -372,57 +378,54 @@ class HiRadixCache(RadixCache):
|
|
372
378
|
self.writing_check()
|
373
379
|
self.loading_check()
|
374
380
|
if self.enable_storage:
|
375
|
-
self.
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
+
self.drain_storage_control_queues()
|
382
|
+
|
383
|
+
def drain_storage_control_queues(self):
|
384
|
+
"""
|
385
|
+
Combine prefetch revoke, backup ack, and host mem release checks
|
386
|
+
to minimize TP synchronization and Python overhead.
|
387
|
+
"""
|
388
|
+
cc = self.cache_controller
|
389
|
+
|
390
|
+
qsizes = torch.tensor(
|
391
|
+
[
|
392
|
+
cc.prefetch_revoke_queue.qsize(),
|
393
|
+
cc.ack_backup_queue.qsize(),
|
394
|
+
cc.host_mem_release_queue.qsize(),
|
395
|
+
],
|
396
|
+
dtype=torch.int,
|
381
397
|
)
|
382
398
|
if self.tp_world_size > 1:
|
383
|
-
# synchrnoize TP workers to make the same update to hiradix cache
|
384
399
|
torch.distributed.all_reduce(
|
385
|
-
|
386
|
-
op=torch.distributed.ReduceOp.MIN,
|
387
|
-
group=self.tp_group,
|
400
|
+
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
388
401
|
)
|
389
|
-
for _ in range(queue_size.item()):
|
390
|
-
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
391
|
-
if req_id in self.ongoing_prefetch:
|
392
|
-
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
|
393
|
-
last_host_node.release_host()
|
394
|
-
del self.ongoing_prefetch[req_id]
|
395
|
-
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
396
|
-
else:
|
397
|
-
# the revoked operation already got terminated
|
398
|
-
pass
|
399
402
|
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
)
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
403
|
+
n_revoke, n_backup, n_release = map(int, qsizes.tolist())
|
404
|
+
|
405
|
+
# process prefetch revokes
|
406
|
+
for _ in range(n_revoke):
|
407
|
+
req_id = cc.prefetch_revoke_queue.get()
|
408
|
+
info = self.ongoing_prefetch.pop(req_id, None)
|
409
|
+
if info is not None:
|
410
|
+
last_host_node, token_ids, _, _ = info
|
411
|
+
last_host_node.release_host()
|
412
|
+
cc.prefetch_tokens_occupied -= len(token_ids)
|
413
|
+
# else: the revoked operation already got terminated, nothing to do
|
414
|
+
|
415
|
+
# process backup acks
|
416
|
+
for _ in range(n_backup):
|
417
|
+
ack_id = cc.ack_backup_queue.get()
|
418
|
+
entry = self.ongoing_backup.pop(ack_id, None)
|
419
|
+
if entry is not None:
|
420
|
+
entry.release_host()
|
421
|
+
|
422
|
+
# release host memory
|
423
|
+
host_indices_list = []
|
424
|
+
for _ in range(n_release):
|
425
|
+
host_indices_list.append(cc.host_mem_release_queue.get())
|
426
|
+
if host_indices_list:
|
427
|
+
host_indices = torch.cat(host_indices_list, dim=0)
|
428
|
+
cc.mem_pool_host.free(host_indices)
|
426
429
|
|
427
430
|
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
428
431
|
can_terminate = True
|
@@ -430,9 +433,12 @@ class HiRadixCache(RadixCache):
|
|
430
433
|
if self.prefetch_stop_policy == "best_effort":
|
431
434
|
return can_terminate
|
432
435
|
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
+
if len(operation.hash_value) == 0:
|
437
|
+
completed = False
|
438
|
+
else:
|
439
|
+
completed = (
|
440
|
+
operation.completed_tokens == len(operation.hash_value) * self.page_size
|
441
|
+
)
|
436
442
|
|
437
443
|
if self.prefetch_stop_policy == "wait_complete":
|
438
444
|
can_terminate = completed
|
@@ -502,7 +508,7 @@ class HiRadixCache(RadixCache):
|
|
502
508
|
self.cache_controller.mem_pool_host.update_prefetch(written_indices)
|
503
509
|
|
504
510
|
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
505
|
-
self.cache_controller.
|
511
|
+
self.cache_controller.append_host_mem_release(
|
506
512
|
host_indices[min_completed_tokens:completed_tokens]
|
507
513
|
)
|
508
514
|
last_host_node.release_host()
|
@@ -536,6 +542,8 @@ class HiRadixCache(RadixCache):
|
|
536
542
|
while last_node.evicted:
|
537
543
|
host_hit_length += len(last_node.host_value)
|
538
544
|
last_node = last_node.parent
|
545
|
+
while not last_host_node.backuped:
|
546
|
+
last_host_node = last_host_node.parent
|
539
547
|
|
540
548
|
return MatchResult(
|
541
549
|
device_indices=value,
|
@@ -556,7 +564,11 @@ class HiRadixCache(RadixCache):
|
|
556
564
|
len(new_input_tokens) % self.page_size
|
557
565
|
)
|
558
566
|
new_input_tokens = new_input_tokens[:prefetch_length]
|
559
|
-
if
|
567
|
+
if (
|
568
|
+
not self.enable_storage
|
569
|
+
or prefetch_length < self.prefetch_threshold
|
570
|
+
or self.cache_controller.prefetch_rate_limited()
|
571
|
+
):
|
560
572
|
return
|
561
573
|
|
562
574
|
last_host_node.protect_host()
|
@@ -564,6 +576,10 @@ class HiRadixCache(RadixCache):
|
|
564
576
|
if host_indices is None:
|
565
577
|
self.evict_host(prefetch_length)
|
566
578
|
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
579
|
+
if host_indices is None:
|
580
|
+
last_host_node.release_host()
|
581
|
+
# no sufficient host memory for prefetch
|
582
|
+
return
|
567
583
|
operation = self.cache_controller.prefetch(
|
568
584
|
req_id, host_indices, new_input_tokens, last_hash
|
569
585
|
)
|
@@ -663,11 +679,11 @@ class HiRadixCache(RadixCache):
|
|
663
679
|
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
664
680
|
return new_node
|
665
681
|
|
666
|
-
def
|
667
|
-
node.last_access_time = time.monotonic()
|
682
|
+
def insert(self, key: List, value, chunked=False):
|
668
683
|
if len(key) == 0:
|
669
684
|
return 0
|
670
685
|
|
686
|
+
node = self.root_node
|
671
687
|
child_key = self.get_child_key_fn(key)
|
672
688
|
total_prefix_length = 0
|
673
689
|
|
@@ -684,7 +700,7 @@ class HiRadixCache(RadixCache):
|
|
684
700
|
self.token_to_kv_pool_host.update_synced(node.host_value)
|
685
701
|
self.evictable_size_ += len(node.value)
|
686
702
|
else:
|
687
|
-
self.
|
703
|
+
self._inc_hit_count(node, chunked)
|
688
704
|
total_prefix_length += prefix_len
|
689
705
|
else:
|
690
706
|
# partial match, split the node
|
@@ -694,7 +710,7 @@ class HiRadixCache(RadixCache):
|
|
694
710
|
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
695
711
|
self.evictable_size_ += len(new_node.value)
|
696
712
|
else:
|
697
|
-
self.
|
713
|
+
self._inc_hit_count(new_node, chunked)
|
698
714
|
total_prefix_length += prefix_len
|
699
715
|
node = new_node
|
700
716
|
|
@@ -728,7 +744,7 @@ class HiRadixCache(RadixCache):
|
|
728
744
|
last_hash = new_node.hash_value[-1]
|
729
745
|
|
730
746
|
if self.cache_controller.write_policy != "write_back":
|
731
|
-
self.
|
747
|
+
self._inc_hit_count(new_node, chunked)
|
732
748
|
return total_prefix_length
|
733
749
|
|
734
750
|
def _collect_leaves_device(self):
|
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
|
|
183
183
|
self.req_to_token_pool.free(req.req_pool_idx)
|
184
184
|
self.dec_lock_ref(req.last_node)
|
185
185
|
|
186
|
-
def cache_unfinished_req(self, req: Req):
|
186
|
+
def cache_unfinished_req(self, req: Req, chunked=False):
|
187
187
|
"""Cache request when it is unfinished."""
|
188
188
|
if self.disable:
|
189
189
|
return
|
@@ -36,12 +36,15 @@ import triton.language as tl
|
|
36
36
|
|
37
37
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
38
38
|
from sglang.srt.layers.radix_attention import RadixAttention
|
39
|
-
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
|
39
|
+
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
|
40
40
|
|
41
41
|
logger = logging.getLogger(__name__)
|
42
42
|
|
43
43
|
GB = 1024 * 1024 * 1024
|
44
44
|
_is_cuda = is_cuda()
|
45
|
+
_is_npu = is_npu()
|
46
|
+
if _is_npu:
|
47
|
+
import torch_npu
|
45
48
|
|
46
49
|
|
47
50
|
class ReqToTokenPool:
|
@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
|
|
624
627
|
cache_k = cache_k.view(self.store_dtype)
|
625
628
|
cache_v = cache_v.view(self.store_dtype)
|
626
629
|
|
627
|
-
import torch_npu
|
628
|
-
|
629
630
|
torch_npu._npu_reshape_and_cache(
|
630
631
|
key=cache_k,
|
631
632
|
value=cache_v,
|
@@ -912,12 +913,24 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|
912
913
|
|
913
914
|
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
914
915
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
915
|
-
self.
|
916
|
+
self.k_buffer = torch.zeros(
|
916
917
|
(
|
917
918
|
layer_num,
|
918
919
|
self.size // self.page_size + 1,
|
919
920
|
self.page_size,
|
920
|
-
|
921
|
+
1,
|
922
|
+
self.kv_lora_rank,
|
923
|
+
),
|
924
|
+
dtype=self.store_dtype,
|
925
|
+
device=self.device,
|
926
|
+
)
|
927
|
+
self.v_buffer = torch.zeros(
|
928
|
+
(
|
929
|
+
layer_num,
|
930
|
+
self.size // self.page_size + 1,
|
931
|
+
self.page_size,
|
932
|
+
1,
|
933
|
+
self.qk_rope_head_dim,
|
921
934
|
),
|
922
935
|
dtype=self.store_dtype,
|
923
936
|
device=self.device,
|
@@ -931,12 +944,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|
931
944
|
)
|
932
945
|
self.mem_usage = kv_size / GB
|
933
946
|
|
947
|
+
def get_kv_size_bytes(self):
|
948
|
+
assert hasattr(self, "k_buffer")
|
949
|
+
assert hasattr(self, "v_buffer")
|
950
|
+
kv_size_bytes = 0
|
951
|
+
for k_cache in self.k_buffer:
|
952
|
+
kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
|
953
|
+
for v_cache in self.v_buffer:
|
954
|
+
kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
|
955
|
+
return kv_size_bytes
|
956
|
+
|
957
|
+
def get_kv_buffer(self, layer_id: int):
|
958
|
+
if self.layer_transfer_counter is not None:
|
959
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
960
|
+
return (
|
961
|
+
self.k_buffer[layer_id - self.start_layer],
|
962
|
+
self.v_buffer[layer_id - self.start_layer],
|
963
|
+
)
|
964
|
+
|
965
|
+
def get_key_buffer(self, layer_id: int):
|
966
|
+
if self.layer_transfer_counter is not None:
|
967
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
968
|
+
|
969
|
+
if self.store_dtype != self.dtype:
|
970
|
+
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
|
971
|
+
return self.k_buffer[layer_id - self.start_layer]
|
972
|
+
|
973
|
+
def get_value_buffer(self, layer_id: int):
|
974
|
+
if self.layer_transfer_counter is not None:
|
975
|
+
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
976
|
+
|
977
|
+
if self.store_dtype != self.dtype:
|
978
|
+
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
|
979
|
+
return self.v_buffer[layer_id - self.start_layer]
|
980
|
+
|
934
981
|
# for disagg
|
935
982
|
def get_contiguous_buf_infos(self):
|
936
983
|
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
|
937
|
-
kv_data_ptrs = [self.
|
938
|
-
|
939
|
-
|
984
|
+
kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
|
985
|
+
self.v_buffer[i].data_ptr() for i in range(self.layer_num)
|
986
|
+
]
|
987
|
+
kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
|
988
|
+
self.v_buffer[i].nbytes for i in range(self.layer_num)
|
989
|
+
]
|
990
|
+
kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
|
991
|
+
self.v_buffer[i][0].nbytes for i in range(self.layer_num)
|
992
|
+
]
|
940
993
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
941
994
|
|
942
995
|
def set_kv_buffer(
|
@@ -949,18 +1002,28 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|
949
1002
|
layer_id = layer.layer_id
|
950
1003
|
if cache_k.dtype != self.dtype:
|
951
1004
|
cache_k = cache_k.to(self.dtype)
|
1005
|
+
cache_v = cache_v.to(self.dtype)
|
952
1006
|
|
953
1007
|
if self.store_dtype != self.dtype:
|
954
1008
|
cache_k = cache_k.view(self.store_dtype)
|
1009
|
+
cache_v = cache_v.view(self.store_dtype)
|
955
1010
|
|
956
|
-
|
1011
|
+
if cache_v is None:
|
1012
|
+
cache_k, cache_v = cache_k.split(
|
1013
|
+
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
1014
|
+
)
|
957
1015
|
|
958
|
-
torch_npu.
|
959
|
-
|
960
|
-
|
961
|
-
|
1016
|
+
torch_npu.npu_scatter_nd_update_(
|
1017
|
+
self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
|
1018
|
+
loc.view(-1, 1),
|
1019
|
+
cache_k.view(-1, 1, self.kv_lora_rank),
|
1020
|
+
)
|
1021
|
+
torch_npu.npu_scatter_nd_update_(
|
1022
|
+
self.v_buffer[layer_id - self.start_layer].view(
|
1023
|
+
-1, 1, self.qk_rope_head_dim
|
962
1024
|
),
|
963
|
-
|
1025
|
+
loc.view(-1, 1),
|
1026
|
+
cache_v.view(-1, 1, self.qk_rope_head_dim),
|
964
1027
|
)
|
965
1028
|
|
966
1029
|
|
@@ -7,7 +7,6 @@ from functools import wraps
|
|
7
7
|
import psutil
|
8
8
|
import torch
|
9
9
|
|
10
|
-
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
11
10
|
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
|
12
11
|
from sglang.srt.utils import is_npu
|
13
12
|
|
@@ -464,7 +463,7 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
464
463
|
else:
|
465
464
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
466
465
|
|
467
|
-
def get_buffer_meta(self, keys, indices):
|
466
|
+
def get_buffer_meta(self, keys, indices, local_rank):
|
468
467
|
ptr_list = []
|
469
468
|
key_list = []
|
470
469
|
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
@@ -488,8 +487,8 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
488
487
|
ptr_list.append(k_ptr)
|
489
488
|
ptr_list.append(v_ptr)
|
490
489
|
key_ = keys[index // self.page_size]
|
491
|
-
key_list.append(f"{key_}_{
|
492
|
-
key_list.append(f"{key_}_{
|
490
|
+
key_list.append(f"{key_}_{local_rank}_k")
|
491
|
+
key_list.append(f"{key_}_{local_rank}_v")
|
493
492
|
element_size = (
|
494
493
|
self.layer_num
|
495
494
|
* self.dtype.itemsize
|
@@ -703,7 +702,7 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|
703
702
|
else:
|
704
703
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
705
704
|
|
706
|
-
def get_buffer_meta(self, keys, indices):
|
705
|
+
def get_buffer_meta(self, keys, indices, local_rank):
|
707
706
|
ptr_list = []
|
708
707
|
key_list = []
|
709
708
|
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
@@ -62,7 +62,6 @@ class TreeNode:
|
|
62
62
|
self.host_value: Optional[torch.Tensor] = None
|
63
63
|
# store hash values of each pages
|
64
64
|
self.hash_value: Optional[List[str]] = None
|
65
|
-
self.backuped_storage = False
|
66
65
|
|
67
66
|
self.id = TreeNode.counter if id is None else id
|
68
67
|
TreeNode.counter += 1
|
@@ -152,6 +151,7 @@ class RadixCache(BasePrefixCache):
|
|
152
151
|
self.root_node = TreeNode()
|
153
152
|
self.root_node.key = []
|
154
153
|
self.root_node.value = []
|
154
|
+
self.root_node.host_value = []
|
155
155
|
self.root_node.lock_ref = 1
|
156
156
|
self.evictable_size_ = 0
|
157
157
|
self.protected_size_ = 0
|
@@ -194,7 +194,7 @@ class RadixCache(BasePrefixCache):
|
|
194
194
|
last_host_node=last_node,
|
195
195
|
)
|
196
196
|
|
197
|
-
def insert(self, key: List, value=None):
|
197
|
+
def insert(self, key: List, value=None, chunked=False):
|
198
198
|
if self.disable:
|
199
199
|
return 0
|
200
200
|
|
@@ -239,7 +239,7 @@ class RadixCache(BasePrefixCache):
|
|
239
239
|
self.req_to_token_pool.free(req.req_pool_idx)
|
240
240
|
self.dec_lock_ref(req.last_node)
|
241
241
|
|
242
|
-
def cache_unfinished_req(self, req: Req):
|
242
|
+
def cache_unfinished_req(self, req: Req, chunked=False):
|
243
243
|
"""Cache request when it is unfinished."""
|
244
244
|
if self.disable:
|
245
245
|
return
|
@@ -260,7 +260,9 @@ class RadixCache(BasePrefixCache):
|
|
260
260
|
page_aligned_token_ids = token_ids[:page_aligned_len]
|
261
261
|
|
262
262
|
# Radix Cache takes one ref in memory pool
|
263
|
-
new_prefix_len = self.insert(
|
263
|
+
new_prefix_len = self.insert(
|
264
|
+
page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
|
265
|
+
)
|
264
266
|
self.token_to_kv_pool_allocator.free(
|
265
267
|
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
266
268
|
)
|
@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
|
|
181
181
|
self.dec_lock_ref(req.last_node)
|
182
182
|
self.req_to_token_pool.free(req.req_pool_idx)
|
183
183
|
|
184
|
-
def cache_unfinished_req(self, req: Req):
|
184
|
+
def cache_unfinished_req(self, req: Req, chunked=False):
|
185
185
|
"""Cache request when it is unfinished."""
|
186
186
|
assert req.req_pool_idx is not None
|
187
187
|
token_ids = req.fill_ids
|