sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/model_config.py +2 -1
- sglang/srt/disaggregation/mini_lb.py +2 -2
- sglang/srt/distributed/parallel_state.py +46 -41
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +5 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +3 -3
- sglang/srt/entrypoints/openai/serving_completions.py +3 -1
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
- sglang/srt/entrypoints/openai/serving_responses.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/layer.py +2 -7
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/utils.py +0 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
- sglang/srt/layers/quantization/modelopt_quant.py +35 -2
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/managers/cache_controller.py +42 -39
- sglang/srt/managers/detokenizer_manager.py +0 -34
- sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
- sglang/srt/managers/schedule_policy.py +3 -2
- sglang/srt/managers/scheduler.py +7 -100
- sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +15 -10
- sglang/srt/mem_cache/hiradix_cache.py +16 -0
- sglang/srt/mem_cache/memory_pool_host.py +18 -11
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/metrics/collector.py +12 -4
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/forward_batch_info.py +16 -17
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +245 -36
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/gpt_oss.py +5 -4
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/longcat_flash.py +26 -15
- sglang/srt/models/longcat_flash_nextn.py +23 -15
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/qwen2_moe.py +4 -1
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/server_args.py +79 -2
- sglang/srt/speculative/eagle_worker.py +158 -112
- sglang/srt/utils.py +12 -10
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/allocator.py
@@ -283,7 +283,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self.swa_attn_allocator.clear()
         self.full_attn_allocator.clear()
         self.full_to_swa_index_mapping.fill_(0)
-        self.
+        self.is_not_in_free_group = True
         self.free_group = []


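Note: the one-line allocator fix resets `is_not_in_free_group` on `clear()`. In sglang's token-to-KV-pool allocators this flag gates a deferred-free pattern: while a free group is open, freed pages are parked in `free_group` and released together when the group closes, so a `clear()` that forgets to restore the flag would leave later frees deferred forever. The sketch below is a minimal illustration of that pattern, not the shipped allocator; the `free_group_begin`/`free_group_end` names follow the surrounding allocator API, everything else is assumed.

```python
class DeferredFreeAllocator:
    """Minimal sketch of the free-group pattern (illustrative, not sglang's class)."""

    def __init__(self):
        self.free_pages: list[int] = []
        self.free_group: list[int] = []
        self.is_not_in_free_group = True  # True => frees take effect immediately

    def free(self, pages: list[int]) -> None:
        if self.is_not_in_free_group:
            self.free_pages.extend(pages)  # release right away
        else:
            self.free_group.extend(pages)  # defer until the group closes

    def free_group_begin(self) -> None:
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self) -> None:
        self.is_not_in_free_group = True
        if self.free_group:
            self.free_pages.extend(self.free_group)
            self.free_group = []
```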
sglang/srt/mem_cache/hicache_storage.py
@@ -27,6 +27,7 @@ class HiCacheStorageConfig:
     tp_rank: int
     tp_size: int
     is_mla_model: bool
+    is_page_first_layout: bool
     model_name: Optional[str]
     extra_config: Optional[dict] = None

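Note: `is_page_first_layout` is added without a default, so every constructor call must now supply it; the field sits before `model_name`, which also has no default, so the dataclass field ordering stays valid. A hedged construction example, with placeholder values:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class HiCacheStorageConfig:  # as it reads after this patch
    tp_rank: int
    tp_size: int
    is_mla_model: bool
    is_page_first_layout: bool
    model_name: Optional[str]
    extra_config: Optional[dict] = None

# Call sites must now pass the new flag explicitly (values are illustrative):
cfg = HiCacheStorageConfig(
    tp_rank=0,
    tp_size=1,
    is_mla_model=False,
    is_page_first_layout=True,
    model_name="org/model",
)
```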
@@ -135,18 +136,24 @@ class HiCacheFile(HiCacheStorage):
     ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)

-        tp_rank, tp_size,
+        tp_rank, tp_size, model_name, is_mla_model = (
             storage_config.tp_rank,
             storage_config.tp_size,
+            storage_config.model_name,
             storage_config.is_mla_model,
         )
-
+        model_name = "-".join(model_name.split("/")) if model_name else ""
+        if is_mla_model:
+            self.config_suffix = f"_{model_name}"
+        else:
+            self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
+
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")

     def _get_suffixed_key(self, key: str) -> str:
-        return key + self.
+        return key + self.config_suffix

     def get(
         self,
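Note: the constructor now precomputes a per-model key suffix: MLA models share one KV namespace across tensor-parallel ranks, so their suffix drops the rank and world size. A small self-contained restatement of that logic (the function name is ours; the behavior is copied from the hunk above):

```python
def config_suffix(model_name, tp_rank, tp_size, is_mla_model) -> str:
    # "/" in HF-style model names would otherwise create subdirectories.
    model_name = "-".join(model_name.split("/")) if model_name else ""
    if is_mla_model:
        return f"_{model_name}"                   # shared across TP ranks
    return f"_{model_name}_{tp_rank}_{tp_size}"   # rank-private

assert config_suffix("org/model", 1, 4, is_mla_model=False) == "_org-model_1_4"
assert config_suffix("org/model", 1, 4, is_mla_model=True) == "_org-model"
# A key "abc" is then stored as "abc_org-model_1_4.bin" under file_path.
```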
@@ -157,13 +164,11 @@ class HiCacheFile(HiCacheStorage):
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-
-            with open(tensor_path, "rb") as f:
-                target_location.
-
-
-                .untyped_storage()
-            )
+            expected = target_location.numel() * target_location.element_size()
+            with open(tensor_path, "rb", buffering=0) as f:
+                buf = memoryview(target_location.view(torch.uint8).contiguous().numpy())
+                if f.readinto(buf) != expected:
+                    raise IOError(f"Short read for {key}")
             return target_location
         except FileNotFoundError:
             logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
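Note: the rewritten `get()` reads file bytes straight into the caller-provided tensor with `readinto()` instead of materializing an intermediate `bytes` object, and it now detects short reads. A standalone sketch of the same technique, assuming a contiguous CPU tensor (the function name is ours):

```python
import torch

def read_into_tensor(path: str, target: torch.Tensor) -> torch.Tensor:
    assert target.device.type == "cpu" and target.is_contiguous()
    expected = target.numel() * target.element_size()
    # view(torch.uint8) reinterprets the same memory; .numpy() shares it,
    # so readinto() fills the tensor's storage with no intermediate copy.
    buf = memoryview(target.view(torch.uint8).numpy())
    with open(path, "rb", buffering=0) as f:
        if f.readinto(buf) != expected:
            raise IOError(f"Short read from {path}")
    return target
```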
sglang/srt/mem_cache/hiradix_cache.py
@@ -771,3 +771,19 @@ class HiRadixCache(RadixCache):
             if not cur_child.evicted:
                 stack.append(cur_child)
         return ret_list
+
+    def release_aborted_request(self, rid: str):
+        if rid not in self.ongoing_prefetch:
+            return
+
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
+        if operation.host_indices is None:
+            return
+
+        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
+        if self.tp_world_size > 1:
+            torch.distributed.barrier(group=self.tp_group)
+        last_host_node.release_host()
+        del self.ongoing_prefetch[rid]
+        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
sglang/srt/mem_cache/memory_pool_host.py
@@ -467,6 +467,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         v_offset = (
             self.layer_num
             * self.size
|
@@ -499,20 +500,23 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
499
500
|
element_size_list = [element_size] * len(key_list)
|
500
501
|
return key_list, ptr_list, element_size_list
|
501
502
|
|
502
|
-
def get_buffer_with_hash(self, keys, indices):
|
503
|
+
def get_buffer_with_hash(self, keys, indices=None):
|
503
504
|
assert self.layout == "page_first"
|
504
|
-
assert len(keys) == (len(indices) // self.page_size)
|
505
|
+
assert indices is None or (len(keys) == (len(indices) // self.page_size))
|
505
506
|
|
506
507
|
key_list = []
|
507
508
|
buf_list = []
|
508
509
|
|
509
|
-
for
|
510
|
+
for i in range(len(keys)):
|
511
|
+
key = keys[i]
|
510
512
|
key_list.append(f"{key}-k")
|
511
|
-
buf_list.append(self.k_buffer[i : i + self.page_size])
|
512
513
|
key_list.append(f"{key}-v")
|
513
|
-
|
514
|
+
if indices is not None:
|
515
|
+
index = indices[i * self.page_size]
|
516
|
+
buf_list.append(self.k_buffer[index : index + self.page_size])
|
517
|
+
buf_list.append(self.v_buffer[index : index + self.page_size])
|
514
518
|
|
515
|
-
return key_list, buf_list
|
519
|
+
return key_list, buf_list, 2
|
516
520
|
|
517
521
|
|
518
522
|
class MLATokenToKVPoolHost(HostKVCache):
|
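Note: in the `page_first` layout one content hash covers a whole page, and a page's tokens occupy consecutive host-pool slots, so `indices[i * page_size]` is the start slot of page `i`. MHA publishes two buffers per page (`-k` and `-v`), which is why the method now also returns the key multiplier `2`; the MLA variant below returns `1`. A worked example of the indexing, with made-up slot numbers:

```python
page_size = 4
indices = [8, 9, 10, 11, 20, 21, 22, 23]  # two pages, each contiguous in the pool
keys = ["h0", "h1"]                        # one hash per page

key_list, page_starts = [], []
for i, key in enumerate(keys):
    key_list += [f"{key}-k", f"{key}-v"]   # MHA: two storage objects per page
    page_starts.append(indices[i * page_size])

assert key_list == ["h0-k", "h0-v", "h1-k", "h1-v"]
assert page_starts == [8, 20]  # buffers are sliced as [start : start + page_size]
```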
@@ -706,6 +710,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         for index in range(0, len(indices), self.page_size):
             k_ptr = (
                 kv_buffer_data_ptr
@@ -726,13 +731,15 @@ class MLATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))

         buf_list = []

-
-
+        if indices is not None:
+            for i in range(len(keys)):
+                index = indices[i * self.page_size]
+                buf_list.append(self.kv_buffer[index : index + self.page_size])

-        return keys, buf_list
+        return keys, buf_list, 1
sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py
@@ -4,10 +4,12 @@ import json
 import logging
 import threading
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, OrderedDict, Tuple

+import orjson
 import requests
-from fastapi import FastAPI, HTTPException, Request,
+from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi.responses import ORJSONResponse
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry

@@ -24,10 +26,10 @@ class RankMetadata:
     """Holds all metadata for a single rank."""

     def __init__(self, num_pages: int):
-        self.lock = threading.
+        self.lock = threading.Lock()
         self.num_pages = num_pages
         self.free_pages: List[int] = list(range(num_pages))
-        self.key_to_index:
+        self.key_to_index: OrderedDict[str, int] = OrderedDict()
         # Todo: Support multi files for HF3FS

     def exists_keys(self, keys: List[str]) -> List[bool]:
@@ -46,16 +48,18 @@ class RankMetadata:
             for i, (key, prefix_key) in enumerate(keys):
                 if key in self.key_to_index:
                     results[i] = (True, self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
                 else:
                     new_keys_to_process.append((i, key, prefix_key))

             # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
             for i, key, prefix_key in new_keys_to_process:
                 if len(self.free_pages) > 0:
-
-                    results[i] = (False, page_idx)
+                    page_index = self.free_pages.pop()
                 else:
-
+                    page_index = self.key_to_index.popitem(last=False)[1]
+
+                results[i] = (False, page_index)

             return results

@@ -68,6 +72,7 @@ class RankMetadata:
         with self.lock:
             for key, page_index in written_keys_to_confirm:
                 self.key_to_index[key] = page_index
+                self.key_to_index.move_to_end(key)

             for page_index in pages_to_release:
                 if page_index not in self.free_pages:
@@ -94,7 +99,14 @@ class RankMetadata:
     def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
         """Get page indices for keys."""
         with self.lock:
-
+            results = []
+            for key in keys:
+                if key in self.key_to_index:
+                    results.append(self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
+                else:
+                    results.append(None)
+            return results


 class GlobalMetadataState:
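Note: taken together, the `RankMetadata` changes turn `key_to_index` into an LRU map: lookups and confirmed writes call `move_to_end`, and when `free_pages` is exhausted, `popitem(last=False)` evicts the least recently used key and recycles its page. A minimal sketch of the pattern in isolation (the class name is ours):

```python
from collections import OrderedDict

class LruPageTable:
    """Minimal sketch of the OrderedDict-as-LRU pattern used above."""

    def __init__(self, num_pages: int):
        self.free_pages = list(range(num_pages))
        self.key_to_index: OrderedDict[str, int] = OrderedDict()

    def touch(self, key: str) -> int:
        index = self.key_to_index[key]
        self.key_to_index.move_to_end(key)   # mark as most recently used
        return index

    def allocate(self, key: str) -> int:
        if key in self.key_to_index:
            return self.touch(key)
        if self.free_pages:
            index = self.free_pages.pop()
        else:
            # Evict the least recently used entry and reuse its page.
            _, index = self.key_to_index.popitem(last=False)
        self.key_to_index[key] = index
        return index
```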
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:

     def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
         self.state = GlobalMetadataState(persistence_path, save_interval)
-        self.app = FastAPI()
+        self.app = FastAPI(default_response_class=ORJSONResponse)
+
         self._setup_routes()

     def _setup_routes(self):
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:

     def get_rank_metadata(self, rank: int) -> RankMetadata:
         """Get rank metadata with proper error handling."""
-
-
-
-
-
-
-
+        if rank not in self.state.ranks:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
+            )
+        return self.state.ranks[rank]
+
+    async def _read_json(self, request: Request) -> dict:
+        """Parse request JSON using orjson if available."""
+        body = await request.body()
+        return orjson.loads(body)
+
+    def _json_response(self, content: dict):
+        """Return ORJSONResponse when available to bypass jsonable_encoder."""
+        return ORJSONResponse(content)

     async def initialize(self, rank: int, request: Request):
         """Initialize a rank with specified number of pages."""
-        data = await
+        data = await self._read_json(request)
         num_pages = data["num_pages"]
         with self.state.global_lock:
             if rank in self.state.ranks:
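Note: the server-side changes swap FastAPI's default JSON path for orjson in both directions: request bodies are parsed with `orjson.loads` on the raw bytes, and responses go out as `ORJSONResponse`, which also bypasses `jsonable_encoder`. A minimal sketch of the same setup, assuming `fastapi` and `orjson` are installed (the route and field names are illustrative):

```python
import orjson
from fastapi import FastAPI, Request
from fastapi.responses import ORJSONResponse

app = FastAPI(default_response_class=ORJSONResponse)

@app.post("/echo")
async def echo(request: Request):
    # orjson.loads on the raw body avoids request.json()'s slower stdlib path.
    data = orjson.loads(await request.body())
    # ORJSONResponse serializes with orjson and skips jsonable_encoder.
    return ORJSONResponse({"keys": data.get("keys", [])})
```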
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
         else:
             logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
             self.state.ranks[rank] = RankMetadata(num_pages)
-        return
+        return Response(status_code=204)

     async def exists(self, rank: int, request: Request):
         """Check if keys exist in metadata."""
-        data = await
+        data = await self._read_json(request)
         keys = data["keys"]
         metadata = self.get_rank_metadata(rank)
         results = metadata.exists_keys(keys)
-        return {"exists": results}
+        return self._json_response({"exists": results})

     async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
         """Reserve and allocate page indices for keys."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.reserve_and_allocate_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})

     async def confirm_write(self, rank: int, request: Request):
         """Confirm write operations and release pages."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         success_written_keys = data.get("written_keys_to_confirm", [])
         released_pages = data.get("pages_to_release", [])

         metadata.confirm_write(success_written_keys, released_pages)

-        return
-            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
-        }
+        return Response(status_code=204)

     async def delete_keys(self, rank: int, request: Request):
         """Delete keys from metadata."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         count = metadata.delete_keys(data["keys"])
-        return
+        return Response(status_code=204)

     async def clear(self, rank: int):
         """Clear all metadata for a rank."""
         metadata = self.get_rank_metadata(rank)
         metadata.clear_all()
-        return
+        return Response(status_code=204)

     async def get_page_indices(self, rank: int, request: Request):
         """Get page indices for keys."""
-        data = await
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.get_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})

     def run(self, host: str = "0.0.0.0", port: int = 18000):
         """Run the metadata server."""
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
             status_forcelist=[500, 502, 503, 504],
             allowed_methods=["GET", "POST"],
         )
-        adapter = HTTPAdapter(
+        adapter = HTTPAdapter(
+            max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
+        )
         self._session.mount("http://", adapter)

     def _post(self, endpoint: str, json_data: dict) -> dict:
         try:
-
+            url = f"{self.base_url}/{endpoint}"
+            headers = {"Content-Type": "application/json"}
+            payload = orjson.dumps(json_data)  # type: ignore[union-attr]
+            response = self._session.post(url, data=payload, headers=headers)
             response.raise_for_status()
-
+
+            if response.status_code == 204 or not response.content:
+                return {}
+            return orjson.loads(response.content)  # type: ignore[union-attr]
         except requests.exceptions.RequestException as e:
             logging.error(f"Failed to POST to {endpoint} after retries: {e}")
             raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
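Note: on the client side, the connection pool is sized to match many concurrent callers and the retry strategy is attached to the same adapter; `_post` now sends orjson-encoded bytes and treats 204/empty replies (the new mutation responses above) as an empty dict. A condensed sketch of that client pattern (the function name is ours):

```python
import orjson
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

session = requests.Session()
retry = Retry(total=3, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504])
# Pool sizing must cover concurrent callers, or connections get created per request.
session.mount("http://", HTTPAdapter(max_retries=retry,
                                     pool_connections=256, pool_maxsize=256))

def post_json(url: str, payload: dict) -> dict:
    resp = session.post(url, data=orjson.dumps(payload),
                        headers={"Content-Type": "application/json"})
    resp.raise_for_status()
    # 204 No Content (and empty bodies) carry no JSON.
    return orjson.loads(resp.content) if resp.content else {}
```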
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -113,6 +113,8 @@ def synchronized():


 class HiCacheHF3FS(HiCacheStorage):
+    """HiCache backend that stores KV cache pages in HF3FS files."""
+
     default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"

     def __init__(
@@ -126,6 +128,7 @@ class HiCacheHF3FS(HiCacheStorage):
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
         is_mla_model: bool = False,
+        is_page_first_layout: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
@@ -136,6 +139,7 @@ class HiCacheHF3FS(HiCacheStorage):
         self.dtype = dtype
         self.metadata_client = metadata_client
         self.is_mla_model = is_mla_model
+        self.is_page_first_layout = is_page_first_layout
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
         self.skip_backup = False
@@ -176,15 +180,36 @@ class HiCacheHF3FS(HiCacheStorage):
         dtype: torch.dtype,
         storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
+        """Create a HiCacheHF3FS instance from environment configuration.
+
+        Environment:
+        - Uses env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config.
+        - Falls back to a local single-machine config when the env var is not set.
+
+        Raises:
+            ValueError: If MLA Model is requested without global metadata server or required keys are missing.
+        """
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )

-
+        if storage_config is not None:
+            rank, is_mla_model, is_page_first_layout = (
+                storage_config.tp_rank,
+                storage_config.is_mla_model,
+                storage_config.is_page_first_layout,
+            )
+        else:
+            rank, is_mla_model, is_page_first_layout = 0, False, False
+
+        mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"

         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             return HiCacheHF3FS(
                 rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
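Note: `from_env_config` locates a JSON config through `SGLANG_HICACHE_HF3FS_CONFIG_PATH`. Only the config keys visible in this diff are sketched below (`file_path_prefix`, `file_size`, `numjobs`, `metadata_server_url`); the values and the temp path are placeholders, and omitting `metadata_server_url` selects the local metadata client, which this patch now rejects for MLA models:

```python
import json
import os

config = {
    "file_path_prefix": "/data/hicache",        # per-rank file: <prefix>.<rank>.bin
    "file_size": 32 * (1 << 30),                # bytes reserved per rank
    "numjobs": 16,
    "metadata_server_url": "http://meta-host:18000",  # omit for single-machine use
}

with open("/tmp/hf3fs_config.json", "w") as f:
    json.dump(config, f)
os.environ["SGLANG_HICACHE_HF3FS_CONFIG_PATH"] = "/tmp/hf3fs_config.json"
```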
@@ -194,6 +219,7 @@ class HiCacheHF3FS(HiCacheStorage):
                 entries=8,
                 dtype=dtype,
                 metadata_client=Hf3fsLocalMetadataClient(),
+                is_page_first_layout=is_page_first_layout,
             )

         try:
@@ -214,25 +240,27 @@ class HiCacheHF3FS(HiCacheStorage):
             raise ValueError(f"Missing required keys in config: {missing_keys}")

         # Choose metadata client based on configuration
-
-        if "metadata_server_url" in config and config["metadata_server_url"]:
+        if config.get("metadata_server_url"):
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)

-            # Enable MLA optimization only when using the global metadata client
-            is_mla_model = storage_config.is_mla_model if storage_config else False
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
         else:
+            # Enable MLA optimization only when using the global metadata client
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             # Use local metadata client for single-machine deployment
             metadata_client = Hf3fsLocalMetadataClient()

+        rank_for_path = 0 if is_mla_model else rank
         return HiCacheHF3FS(
             rank=rank,
             # Let all ranks use the same file path for MLA model
-            file_path=f"{config['file_path_prefix']}.{
+            file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
@@ -240,6 +268,7 @@ class HiCacheHF3FS(HiCacheStorage):
             dtype=dtype,
             metadata_client=metadata_client,
             is_mla_model=is_mla_model,
+            is_page_first_layout=is_page_first_layout,
         )

     def get(
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -1,4 +1,3 @@
-import hashlib
 import json
 import logging
 import os
@@ -6,10 +5,8 @@ import uuid
 from dataclasses import dataclass
 from typing import Any, List, Optional

-import numpy as np
 import torch

-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig

 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
@@ -154,21 +151,36 @@ class MooncakeStore(HiCacheStorage):
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
-
+        # Only support zero copy set for now
+        assert target_location is not None and target_sizes is not None
+        exist_result = self._batch_exist([key])
+        if exist_result[0] == 1:
+            return True
+        put_result = self._put_batch_zero_copy_impl(
+            [key], [target_location], [target_sizes]
+        )
+        return put_result[0] == 0

     def batch_set(
         self,
         keys: List[str],
         values: Optional[List[torch.Tensor]] = None,
-
+        target_locations: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
-
+        # Only support zero copy set for now
+        assert target_locations is not None and target_sizes is not None
+        assert len(keys) == len(target_locations) == len(target_sizes)
+
         if len(keys) == 0:
             return False

         for i in range(len(keys)):
-            if
+            if (
+                keys[i] is None
+                or target_locations[i] is None
+                or target_sizes[i] is None
+            ):
                 return False

         exist_result = self._batch_exist(keys)
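Note: `set`/`batch_set` are now explicitly zero-copy only: they assert the location/size arguments and consult `_batch_exist` first so already-stored keys are never rewritten. The sketch below restates that dedup-before-write flow around the patch's private helpers, keeping its return conventions (`1` from `_batch_exist` means present, `0` from the put means success); it is an illustration, not the shipped method:

```python
def batch_set_zero_copy(store, keys, locations, sizes) -> bool:
    """Sketch: write only the keys the backend does not already hold."""
    assert len(keys) == len(locations) == len(sizes)
    exist = store._batch_exist(keys)          # 1 => already stored
    todo = [i for i, e in enumerate(exist) if e != 1]
    if not todo:
        return True                           # everything was already present
    put = store._put_batch_zero_copy_impl(
        [keys[i] for i in todo],
        [locations[i] for i in todo],
        [sizes[i] for i in todo],
    )
    return all(r == 0 for r in put)
```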
@@ -179,7 +191,7 @@ class MooncakeStore(HiCacheStorage):
         for i in range(len(keys)):
             if exist_result[i] != 1:
                 set_keys.append(keys[i])
-                set_target_locations.append(
+                set_target_locations.append(target_locations[i])
                 set_target_sizes.append(target_sizes[i])
                 set_indices.append(i)
         # Only set non-existing keys to storage
@@ -204,18 +216,24 @@ class MooncakeStore(HiCacheStorage):
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
-
+        assert target_location is not None and target_sizes is not None
+        get_result = self._get_batch_zero_copy_impl(
+            [key], [target_location], [target_sizes]
+        )
+        return get_result[0] >= 0

     def batch_get(
         self,
         keys: List[str],
-
+        target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> int:
-        assert len(keys) == len(
+        assert len(keys) == len(target_locations) == len(target_sizes)
         if len(keys) == 0:
             return 0
-        get_result = self._get_batch_zero_copy_impl(
+        get_result = self._get_batch_zero_copy_impl(
+            keys, target_locations, target_sizes
+        )
         if self.is_mla_backend:
             key_multiplier = 1
         else:
@@ -226,7 +244,8 @@ class MooncakeStore(HiCacheStorage):
         return len(keys) // key_multiplier

     def exists(self, key) -> bool:
-
+        exist_result = self._batch_exist([key])
+        return exist_result[0] == 1

     def batch_exists(self, keys) -> int:
         if self.is_mla_backend:
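Note: `batch_get` converts fetched storage keys back into page counts with a key multiplier: MLA stores one fused KV object per page, while the MHA path stores separate K and V objects (the paired `-k`/`-v` keys from `memory_pool_host.py`). The non-MLA multiplier of 2 is implied by that pairing; its assignment is truncated in the hunk above, so treat it as an assumption in this sketch:

```python
def pages_retrieved(num_fetched_keys: int, is_mla_backend: bool) -> int:
    # MLA: one object per page; MHA: a K object and a V object per page (assumed).
    key_multiplier = 1 if is_mla_backend else 2
    return num_fetched_keys // key_multiplier

assert pages_retrieved(8, is_mla_backend=True) == 8
assert pages_retrieved(8, is_mla_backend=False) == 4
```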
|