sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/decode.py +4 -0
- sglang/srt/disaggregation/prefill.py +4 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/openai/protocol.py +27 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/entrypoints/tool.py +7 -7
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +16 -7
- sglang/srt/layers/attention/ascend_backend.py +218 -111
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
- sglang/srt/layers/attention/utils.py +15 -94
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/moe/cutlass_moe.py +0 -15
- sglang/srt/layers/moe/ep_moe/layer.py +1 -7
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/mxfp4.py +16 -23
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/lora_manager.py +29 -12
- sglang/srt/managers/cache_controller.py +223 -156
- sglang/srt/managers/detokenizer_manager.py +5 -0
- sglang/srt/managers/io_struct.py +30 -0
- sglang/srt/managers/scheduler.py +58 -7
- sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
- sglang/srt/managers/tokenizer_manager.py +36 -3
- sglang/srt/mem_cache/hicache_storage.py +31 -20
- sglang/srt/mem_cache/hiradix_cache.py +12 -3
- sglang/srt/mem_cache/memory_pool.py +73 -14
- sglang/srt/mem_cache/memory_pool_host.py +3 -2
- sglang/srt/mem_cache/radix_cache.py +1 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
- sglang/srt/metrics/collector.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +12 -3
- sglang/srt/models/gpt_oss.py +2 -1
- sglang/srt/models/qwen2_5_vl.py +1 -0
- sglang/srt/offloader.py +115 -0
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/server_args.py +10 -5
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +59 -12
- sglang/test/test_cutlass_moe.py +33 -28
- sglang/version.py +1 -1
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -67,6 +67,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
@@ -510,6 +512,8 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
+                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
@@ -623,6 +627,8 @@ class Scheduler(
             hicache_mem_layout=server_args.hicache_mem_layout,
             hicache_storage_backend=server_args.hicache_storage_backend,
             hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
+            model_name=server_args.served_model_name,
+            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
         )
         self.tp_worker.register_hicache_layer_transfer_counter(
             self.tree_cache.cache_controller.layer_done_counter
@@ -1018,14 +1024,26 @@ class Scheduler(
                 req
                 for req in recv_reqs
                 if isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
             control_reqs = [
                 req
                 for req in recv_reqs
                 if not isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
         else:
@@ -1253,6 +1271,17 @@ class Scheduler(
         else:
             self._add_request_to_queue(req)

+    def handle_batch_generate_request(
+        self,
+        recv_req: BatchTokenizedGenerateReqInput,
+    ):
+        """Handle optimized batch generate request."""
+        logger.debug(f"Processing batch generate request with {len(recv_req)} requests")
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_generate_request(tokenized_req)
+
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1269,10 +1298,11 @@ class Scheduler(
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
             req.init_next_round_input(self.tree_cache)
-            last_hash = req.last_host_node.get_last_hash_value()
-            matched_len = len(req.prefix_indices) + req.host_hit_length
-            …
-            …
+            if req.last_node.backuped:
+                # only to initiate the prefetch if the last node is backuped
+                # otherwise, the allocated GPU memory must be locked for integrity
+                last_hash = req.last_host_node.get_last_hash_value()
+                matched_len = len(req.prefix_indices) + req.host_hit_length
             new_input_tokens = req.fill_ids[matched_len:]
             self.tree_cache.prefetch_from_storage(
                 req.rid, req.last_host_node, new_input_tokens, last_hash
@@ -1335,6 +1365,19 @@ class Scheduler(
         req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)

+    def handle_batch_embedding_request(
+        self,
+        recv_req: BatchTokenizedEmbeddingReqInput,
+    ):
+        """Handle optimized batch embedding request."""
+        logger.debug(
+            f"Processing batch embedding request with {len(recv_req)} requests"
+        )
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_embedding_request(tokenized_req)
+
     def self_check_during_idle(self):
         self.check_memory()
         self.check_tree_cache()
@@ -2513,7 +2556,15 @@ def is_health_check_generate_req(recv_req):


 def is_work_request(recv_req):
-    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+    return isinstance(
+        recv_req,
+        (
+            TokenizedGenerateReqInput,
+            TokenizedEmbeddingReqInput,
+            BatchTokenizedGenerateReqInput,
+            BatchTokenizedEmbeddingReqInput,
+        ),
+    )


 def run_scheduler_process(
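Note: the new handlers above call `len(recv_req)` and iterate over `recv_req` directly, which implies the batch wrappers defined in `sglang/srt/managers/io_struct.py` (not shown in this excerpt) expose `__len__` and `__iter__` over their `batch` field. A minimal sketch of what such a wrapper could look like, assuming the dataclass shape suggested by `BatchTokenizedGenerateReqInput(batch=tokenized_objs)` in the tokenizer-manager hunk below:

    from dataclasses import dataclass
    from typing import Iterator, List


    @dataclass
    class BatchTokenizedGenerateReqInput:
        # Hypothetical sketch; the real definition lives in io_struct.py.
        batch: List["TokenizedGenerateReqInput"]

        def __len__(self) -> int:
            return len(self.batch)

        def __iter__(self) -> Iterator["TokenizedGenerateReqInput"]:
            return iter(self.batch)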
sglang/srt/managers/scheduler_metrics_mixin.py
CHANGED
@@ -125,6 +125,14 @@ class SchedulerMetricsMixin:
             total_queue_latency += req.queue_time_end - req.queue_time_start
         self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq

+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.num_prefill_prealloc_queue_reqs = len(
+                self.disagg_prefill_bootstrap_queue.queue
+            )
+            self.stats.num_prefill_inflight_queue_reqs = len(
+                self.disagg_prefill_inflight_queue
+            )
+
         self.metrics_collector.log_stats(self.stats)
         self._emit_kv_metrics()
         self._publish_kv_events()
@@ -202,6 +210,13 @@ class SchedulerMetricsMixin:
         self.stats.spec_accept_length = spec_accept_length
         self.stats.total_retracted_reqs = self.total_retracted_reqs
         self.metrics_collector.log_stats(self.stats)
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.stats.num_decode_prealloc_queue_reqs = len(
+                self.disagg_decode_prealloc_queue.queue
+            )
+            self.stats.num_decode_transfer_queue_reqs = len(
+                self.disagg_decode_transfer_queue.queue
+            )
         self._emit_kv_metrics()
         self._publish_kv_events()

sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -71,6 +71,8 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -768,6 +770,30 @@ class TokenizerManager:
         self.rid_to_state[obj.rid] = state
         return state

+    def _send_batch_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        tokenized_objs: List[
+            Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
+        ],
+        created_time: Optional[float] = None,
+    ):
+        """Send a batch of tokenized requests as a single batched request to the scheduler."""
+        if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
+            batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
+        else:
+            batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)
+
+        self.send_to_scheduler.send_pyobj(batch_req)
+
+        # Create states for each individual request in the batch
+        for i, tokenized_obj in enumerate(tokenized_objs):
+            tmp_obj = obj[i]
+            state = ReqState(
+                [], False, asyncio.Event(), tmp_obj, created_time=created_time
+            )
+            self.rid_to_state[tmp_obj.rid] = state
+
     async def _wait_one_response(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -870,10 +896,17 @@ class TokenizerManager:

             tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)

-            for i, tokenized_obj in enumerate(tokenized_objs):
+            # Send as a single batched request
+            self._send_batch_request(obj, tokenized_objs, created_time)
+
+            # Set up generators for each request in the batch
+            for i in range(batch_size):
                 tmp_obj = obj[i]
-                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                generators.append(self._wait_one_response(tmp_obj, state, request))
+                generators.append(
+                    self._wait_one_response(
+                        tmp_obj, self.rid_to_state[tmp_obj.rid], request
+                    )
+                )
                 rids.append(tmp_obj.rid)
         else:
             # Sequential tokenization and processing
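The batch path now issues one `send_pyobj` for the whole batch and then awaits each request independently through its `ReqState` in `rid_to_state`. A minimal sketch of that fan-out pattern, with `asyncio.Event` standing in for the real per-request state:

    import asyncio


    async def fan_out(rid_to_event: dict) -> list:
        # One awaiter per request id; all requests were enqueued by a single
        # batched send, but each completes independently.
        async def wait_one(rid: str) -> str:
            await rid_to_event[rid].wait()
            return rid

        return await asyncio.gather(*(wait_one(rid) for rid in rid_to_event))


    async def main():
        events = {"rid-0": asyncio.Event(), "rid-1": asyncio.Event()}
        for e in events.values():
            e.set()  # pretend both responses arrived
        print(await fan_out(events))


    asyncio.run(main())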
sglang/srt/mem_cache/hicache_storage.py
CHANGED
@@ -2,6 +2,7 @@ import hashlib
 import logging
 import os
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Any, List, Optional

 import torch
@@ -9,17 +10,6 @@ import torch
 logger = logging.getLogger(__name__)


-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    is_dp_attention_enabled,
-)
-
-
 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     hasher = hashlib.sha256()

@@ -32,6 +22,15 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     return hasher.hexdigest()


+@dataclass
+class HiCacheStorageConfig:
+    tp_rank: int
+    tp_size: int
+    is_mla_model: bool
+    model_name: Optional[str]
+    extra_config: Optional[dict] = None
+
+
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
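With this dataclass, backend-specific context travels in one explicit object instead of being read from global tensor-parallel state at import time. An illustrative instantiation (all values are example placeholders):

    config = HiCacheStorageConfig(
        tp_rank=0,
        tp_size=8,
        is_mla_model=False,
        model_name="example-model",  # illustrative value
        extra_config={"endpoint": "localhost:9200"},  # backend-specific knobs
    )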
@@ -60,7 +59,7 @@ class HiCacheStorage(ABC):
         keys: List[str],
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> List[torch.Tensor | None]:
+    ) -> List[torch.Tensor | None] | int:
         """
         Retrieve values for multiple keys.
         Returns a list of tensors or None for each key.
@@ -96,25 +95,37 @@ class HiCacheStorage(ABC):
         pass

     @abstractmethod
-    def exists(self, key: str) -> bool
+    def exists(self, key: str) -> bool:
         """
         Check if the key exists in the storage.
         Returns True if the key exists, False otherwise.
         """
         pass

+    def batch_exists(self, keys: List[str]) -> int:
+        """
+        Check if the keys exist in the storage.
+        return the number of consecutive existing keys from the start.
+        Can be overridden by subclasses for more efficient implementation.
+        """
+        for i in range(len(keys)):
+            if not self.exists(keys[i]):
+                return i
+        return len(keys)
+

 class HiCacheFile(HiCacheStorage):

-    def __init__(self, file_path: str = "/tmp/hicache"):
+    def __init__(
+        self, storage_config: HiCacheStorageConfig, file_path: str = "/tmp/hicache"
+    ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-        if is_dp_attention_enabled():
-            tp_rank = get_attention_tp_rank()
-            tp_size = get_attention_tp_size()
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            tp_size = get_tensor_model_parallel_world_size()

+        tp_rank, tp_size, is_mla = (
+            storage_config.tp_rank,
+            storage_config.tp_size,
+            storage_config.is_mla_model,
+        )
         self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
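`batch_exists` returns the length of the existing prefix rather than a total count: lookup stops at the first missing key, which matches how page-indexed cache keys are consumed in order. A toy backend illustrating the contract (the dict-backed class is invented for this example):

    from typing import List


    class DictStorage:
        # Toy backend, only to illustrate batch_exists semantics.
        def __init__(self, present: set):
            self.present = present

        def exists(self, key: str) -> bool:
            return key in self.present

        def batch_exists(self, keys: List[str]) -> int:
            # Default contract: count consecutive hits from the start.
            for i, key in enumerate(keys):
                if not self.exists(key):
                    return i
            return len(keys)


    storage = DictStorage(present={"p0", "p1", "p3"})
    assert storage.batch_exists(["p0", "p1", "p2", "p3"]) == 2  # stops at "p2"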
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -39,6 +39,8 @@ class HiRadixCache(RadixCache):
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
         hicache_storage_prefetch_policy: Optional[str] = "best_effort",
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):

         if hicache_io_backend == "direct":
@@ -87,6 +89,8 @@ class HiRadixCache(RadixCache):
             io_backend=hicache_io_backend,
             storage_backend=hicache_storage_backend,
             prefetch_threshold=self.prefetch_threshold,
+            model_name=model_name,
+            storage_backend_extra_config=storage_backend_extra_config,
         )

         # record the nodes with ongoing write through
@@ -430,9 +434,12 @@ class HiRadixCache(RadixCache):
         if self.prefetch_stop_policy == "best_effort":
             return can_terminate

-        completed = (
-            operation.completed_tokens == len(operation.hash_value) * self.page_size
-        )
+        if len(operation.hash_value) == 0:
+            completed = False
+        else:
+            completed = (
+                operation.completed_tokens == len(operation.hash_value) * self.page_size
+            )

         if self.prefetch_stop_policy == "wait_complete":
             can_terminate = completed
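For example, with `self.page_size == 64` and five entries in `operation.hash_value`, the prefetch counts as complete only once `operation.completed_tokens == 320`; the new guard additionally makes an empty `hash_value` list count as incomplete instead of vacuously complete.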
@@ -536,6 +543,8 @@ class HiRadixCache(RadixCache):
         while last_node.evicted:
             host_hit_length += len(last_node.host_value)
             last_node = last_node.parent
+        while not last_host_node.backuped:
+            last_host_node = last_host_node.parent

         return MatchResult(
             device_indices=value,
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -36,12 +36,15 @@ import triton.language as tl

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu


 class ReqToTokenPool:
@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         cache_k = cache_k.view(self.store_dtype)
         cache_v = cache_v.view(self.store_dtype)

-        import torch_npu
-
         torch_npu._npu_reshape_and_cache(
             key=cache_k,
             value=cache_v,
@@ -912,12 +913,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):

         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = torch.zeros(
+            self.k_buffer = torch.zeros(
                 (
                     layer_num,
                     self.size // self.page_size + 1,
                     self.page_size,
-                    self.kv_lora_rank + self.qk_rope_head_dim,
+                    self.kv_lora_rank,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
+            self.v_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    self.qk_rope_head_dim,
                 ),
                 dtype=self.store_dtype,
                 device=self.device,
@@ -931,12 +942,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             )
         self.mem_usage = kv_size / GB

+    def get_kv_size_bytes(self):
+        assert hasattr(self, "k_buffer")
+        assert hasattr(self, "v_buffer")
+        kv_size_bytes = 0
+        for k_cache in self.k_buffer:
+            kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+        for v_cache in self.v_buffer:
+            kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+        return kv_size_bytes
+
+    def get_kv_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return (
+            self.k_buffer[layer_id - self.start_layer],
+            self.v_buffer[layer_id - self.start_layer],
+        )
+
+    def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]
+
+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]
+
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
-        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
-        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
+            self.v_buffer[i].data_ptr() for i in range(self.layer_num)
+        ]
+        kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
+            self.v_buffer[i].nbytes for i in range(self.layer_num)
+        ]
+        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
+            self.v_buffer[i][0].nbytes for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

     def set_kv_buffer(
@@ -953,14 +1004,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)

-        …
+        if cache_v is None:
+            cache_k, cache_v = cache_k.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )

-        torch_npu.…
-        …
-        …
-        …
+        torch_npu.npu_scatter_nd_update_(
+            self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
+            loc.view(-1, 1),
+            cache_k.view(-1, 1, self.kv_lora_rank),
+        )
+        torch_npu.npu_scatter_nd_update_(
+            self.v_buffer[layer_id - self.start_layer].view(
+                -1, 1, self.qk_rope_head_dim
             ),
-        …
+            loc.view(-1, 1),
+            cache_v.view(-1, 1, self.qk_rope_head_dim),
         )

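The Ascend MLA pool now keeps the compressed KV (`kv_lora_rank` wide) and the rotary part (`qk_rope_head_dim` wide) in separate `k_buffer`/`v_buffer` tensors; when the caller passes a fused tensor (`cache_v is None`), it is split along the last dimension before the two scatter updates. A shape-level sketch of that split (dimensions illustrative):

    import torch

    kv_lora_rank, qk_rope_head_dim = 512, 64  # illustrative MLA dims
    fused = torch.randn(16, kv_lora_rank + qk_rope_head_dim)  # 16 tokens

    # Mirrors the `cache_v is None` branch above: one fused tensor becomes
    # the compressed-KV part and the rope part, stored separately.
    cache_k, cache_v = fused.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
    assert cache_k.shape == (16, 512) and cache_v.shape == (16, 64)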
sglang/srt/mem_cache/memory_pool_host.py
CHANGED
@@ -465,6 +465,7 @@ class MHATokenToKVPoolHost(HostKVCache):
             raise ValueError(f"Unsupported layout: {self.layout}")

     def get_buffer_meta(self, keys, indices):
+        local_rank = get_tensor_model_parallel_rank()
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
@@ -488,8 +489,8 @@ class MHATokenToKVPoolHost(HostKVCache):
             ptr_list.append(k_ptr)
             ptr_list.append(v_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{…
-            key_list.append(f"{key_}_{…
+            key_list.append(f"{key_}_{local_rank}_k")
+            key_list.append(f"{key_}_{local_rank}_v")
         element_size = (
             self.layer_num
             * self.dtype.itemsize
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
CHANGED
@@ -11,12 +11,7 @@ from typing import Any, List, Optional, Tuple

 import torch

-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    is_dp_attention_enabled,
-)
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient

 logger = logging.getLogger(__name__)
@@ -172,19 +167,16 @@ class HiCacheHF3FS(HiCacheStorage):

     @staticmethod
     def from_env_config(
-        bytes_per_page: int,
+        bytes_per_page: int,
+        dtype: torch.dtype,
+        storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )

-        if …
-        rank = (
-            get_attention_tp_rank()
-            if is_dp_attention_enabled()
-            else get_tensor_model_parallel_rank()
-        )
+        rank = storage_config.tp_rank if storage_config is not None else 0

         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path: