sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (84):
  1. sglang/lang/interpreter.py +1 -1
  2. sglang/srt/configs/internvl.py +6 -0
  3. sglang/srt/configs/model_config.py +2 -1
  4. sglang/srt/disaggregation/mini_lb.py +2 -2
  5. sglang/srt/distributed/parallel_state.py +46 -41
  6. sglang/srt/entrypoints/engine.py +1 -1
  7. sglang/srt/entrypoints/http_server.py +5 -1
  8. sglang/srt/entrypoints/openai/protocol.py +3 -3
  9. sglang/srt/entrypoints/openai/serving_chat.py +3 -3
  10. sglang/srt/entrypoints/openai/serving_completions.py +3 -1
  11. sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
  12. sglang/srt/entrypoints/openai/serving_responses.py +1 -1
  13. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  14. sglang/srt/layers/attention/aiter_backend.py +93 -68
  15. sglang/srt/layers/communicator.py +45 -7
  16. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  17. sglang/srt/layers/moe/ep_moe/layer.py +2 -7
  18. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  21. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  24. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  25. sglang/srt/layers/moe/utils.py +0 -1
  26. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
  27. sglang/srt/layers/quantization/modelopt_quant.py +35 -2
  28. sglang/srt/layers/quantization/mxfp4.py +4 -1
  29. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  30. sglang/srt/layers/quantization/quark/utils.py +97 -0
  31. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  32. sglang/srt/layers/quantization/w4afp8.py +30 -25
  33. sglang/srt/layers/rocm_linear_utils.py +44 -0
  34. sglang/srt/layers/rotary_embedding.py +0 -18
  35. sglang/srt/managers/cache_controller.py +42 -39
  36. sglang/srt/managers/detokenizer_manager.py +0 -34
  37. sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
  38. sglang/srt/managers/schedule_policy.py +3 -2
  39. sglang/srt/managers/scheduler.py +7 -100
  40. sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
  41. sglang/srt/managers/template_manager.py +3 -3
  42. sglang/srt/managers/tokenizer_manager.py +1 -0
  43. sglang/srt/mem_cache/allocator.py +1 -1
  44. sglang/srt/mem_cache/hicache_storage.py +15 -10
  45. sglang/srt/mem_cache/hiradix_cache.py +16 -0
  46. sglang/srt/mem_cache/memory_pool_host.py +18 -11
  47. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  48. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
  49. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
  50. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  51. sglang/srt/metrics/collector.py +12 -4
  52. sglang/srt/metrics/utils.py +48 -0
  53. sglang/srt/model_executor/forward_batch_info.py +16 -17
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +245 -36
  56. sglang/srt/models/glm4_moe.py +10 -1
  57. sglang/srt/models/gpt_oss.py +5 -4
  58. sglang/srt/models/internvl.py +28 -0
  59. sglang/srt/models/longcat_flash.py +26 -15
  60. sglang/srt/models/longcat_flash_nextn.py +23 -15
  61. sglang/srt/models/minicpmv.py +165 -3
  62. sglang/srt/models/qwen2_moe.py +4 -1
  63. sglang/srt/models/qwen3.py +8 -2
  64. sglang/srt/models/qwen3_moe.py +39 -8
  65. sglang/srt/models/torch_native_llama.py +1 -1
  66. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  67. sglang/srt/server_args.py +79 -2
  68. sglang/srt/speculative/eagle_worker.py +158 -112
  69. sglang/srt/utils.py +12 -10
  70. sglang/test/few_shot_gsm8k.py +1 -0
  71. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  72. sglang/utils.py +1 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
  75. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
  76. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  77. sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  78. sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  79. sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  80. sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  81. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  82. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
  83. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
@@ -283,7 +283,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
283
283
  self.swa_attn_allocator.clear()
284
284
  self.full_attn_allocator.clear()
285
285
  self.full_to_swa_index_mapping.fill_(0)
286
- self.is_in_free_group = False
286
+ self.is_not_in_free_group = True
287
287
  self.free_group = []
288
288
 
289
289
 
@@ -27,6 +27,7 @@ class HiCacheStorageConfig:
27
27
  tp_rank: int
28
28
  tp_size: int
29
29
  is_mla_model: bool
30
+ is_page_first_layout: bool
30
31
  model_name: Optional[str]
31
32
  extra_config: Optional[dict] = None
32
33
 
@@ -135,18 +136,24 @@ class HiCacheFile(HiCacheStorage):
135
136
  ):
136
137
  self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
137
138
 
138
- tp_rank, tp_size, is_mla = (
139
+ tp_rank, tp_size, model_name, is_mla_model = (
139
140
  storage_config.tp_rank,
140
141
  storage_config.tp_size,
142
+ storage_config.model_name,
141
143
  storage_config.is_mla_model,
142
144
  )
143
- self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
145
+ model_name = "-".join(model_name.split("/")) if model_name else ""
146
+ if is_mla_model:
147
+ self.config_suffix = f"_{model_name}"
148
+ else:
149
+ self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
150
+
144
151
  if not os.path.exists(self.file_path) and tp_rank == 0:
145
152
  os.makedirs(self.file_path)
146
153
  logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
147
154
 
148
155
  def _get_suffixed_key(self, key: str) -> str:
149
- return key + self.tp_suffix
156
+ return key + self.config_suffix
150
157
 
151
158
  def get(
152
159
  self,
@@ -157,13 +164,11 @@ class HiCacheFile(HiCacheStorage):
157
164
  key = self._get_suffixed_key(key)
158
165
  tensor_path = os.path.join(self.file_path, f"{key}.bin")
159
166
  try:
160
- # Load directly into target_location's memory buffer
161
- with open(tensor_path, "rb") as f:
162
- target_location.set_(
163
- torch.frombuffer(f.read(), dtype=target_location.dtype)
164
- .reshape(target_location.shape)
165
- .untyped_storage()
166
- )
167
+ expected = target_location.numel() * target_location.element_size()
168
+ with open(tensor_path, "rb", buffering=0) as f:
169
+ buf = memoryview(target_location.view(torch.uint8).contiguous().numpy())
170
+ if f.readinto(buf) != expected:
171
+ raise IOError(f"Short read for {key}")
167
172
  return target_location
168
173
  except FileNotFoundError:
169
174
  logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
@@ -771,3 +771,19 @@ class HiRadixCache(RadixCache):
771
771
  if not cur_child.evicted:
772
772
  stack.append(cur_child)
773
773
  return ret_list
774
+
775
+ def release_aborted_request(self, rid: str):
776
+ if rid not in self.ongoing_prefetch:
777
+ return
778
+
779
+ last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
780
+ if operation.host_indices is None:
781
+ return
782
+
783
+ completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
784
+ if self.tp_world_size > 1:
785
+ torch.distributed.barrier(group=self.tp_group)
786
+ last_host_node.release_host()
787
+ del self.ongoing_prefetch[rid]
788
+ self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
789
+ self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
@@ -467,6 +467,7 @@ class MHATokenToKVPoolHost(HostKVCache):
467
467
  ptr_list = []
468
468
  key_list = []
469
469
  kv_buffer_data_ptr = self.kv_buffer.data_ptr()
470
+ indices = indices.tolist()
470
471
  v_offset = (
471
472
  self.layer_num
472
473
  * self.size
@@ -499,20 +500,23 @@ class MHATokenToKVPoolHost(HostKVCache):
499
500
  element_size_list = [element_size] * len(key_list)
500
501
  return key_list, ptr_list, element_size_list
501
502
 
502
- def get_buffer_with_hash(self, keys, indices):
503
+ def get_buffer_with_hash(self, keys, indices=None):
503
504
  assert self.layout == "page_first"
504
- assert len(keys) == (len(indices) // self.page_size)
505
+ assert indices is None or (len(keys) == (len(indices) // self.page_size))
505
506
 
506
507
  key_list = []
507
508
  buf_list = []
508
509
 
509
- for key, i in zip(keys, range(0, len(indices), self.page_size)):
510
+ for i in range(len(keys)):
511
+ key = keys[i]
510
512
  key_list.append(f"{key}-k")
511
- buf_list.append(self.k_buffer[i : i + self.page_size])
512
513
  key_list.append(f"{key}-v")
513
- buf_list.append(self.v_buffer[i : i + self.page_size])
514
+ if indices is not None:
515
+ index = indices[i * self.page_size]
516
+ buf_list.append(self.k_buffer[index : index + self.page_size])
517
+ buf_list.append(self.v_buffer[index : index + self.page_size])
514
518
 
515
- return key_list, buf_list
519
+ return key_list, buf_list, 2
516
520
 
517
521
 
518
522
  class MLATokenToKVPoolHost(HostKVCache):
@@ -706,6 +710,7 @@ class MLATokenToKVPoolHost(HostKVCache):
706
710
  ptr_list = []
707
711
  key_list = []
708
712
  kv_buffer_data_ptr = self.kv_buffer.data_ptr()
713
+ indices = indices.tolist()
709
714
  for index in range(0, len(indices), self.page_size):
710
715
  k_ptr = (
711
716
  kv_buffer_data_ptr
@@ -726,13 +731,15 @@ class MLATokenToKVPoolHost(HostKVCache):
726
731
  element_size_list = [element_size] * len(key_list)
727
732
  return key_list, ptr_list, element_size_list
728
733
 
729
- def get_buffer_with_hash(self, keys, indices):
734
+ def get_buffer_with_hash(self, keys, indices=None):
730
735
  assert self.layout == "page_first"
731
- assert len(keys) == (len(indices) // self.page_size)
736
+ assert indices is None or (len(keys) == (len(indices) // self.page_size))
732
737
 
733
738
  buf_list = []
734
739
 
735
- for i in range(0, len(indices), self.page_size):
736
- buf_list.append(self.kv_buffer[i : i + self.page_size])
740
+ if indices is not None:
741
+ for i in range(len(keys)):
742
+ index = indices[i * self.page_size]
743
+ buf_list.append(self.kv_buffer[index : index + self.page_size])
737
744
 
738
- return keys, buf_list
745
+ return keys, buf_list, 1
@@ -4,10 +4,12 @@ import json
4
4
  import logging
5
5
  import threading
6
6
  from pathlib import Path
7
- from typing import Dict, List, Optional, Tuple
7
+ from typing import Dict, List, Optional, OrderedDict, Tuple
8
8
 
9
+ import orjson
9
10
  import requests
10
- from fastapi import FastAPI, HTTPException, Request, status
11
+ from fastapi import FastAPI, HTTPException, Request, Response
12
+ from fastapi.responses import ORJSONResponse
11
13
  from requests.adapters import HTTPAdapter
12
14
  from urllib3.util.retry import Retry
13
15
 
@@ -24,10 +26,10 @@ class RankMetadata:
24
26
  """Holds all metadata for a single rank."""
25
27
 
26
28
  def __init__(self, num_pages: int):
27
- self.lock = threading.RLock()
29
+ self.lock = threading.Lock()
28
30
  self.num_pages = num_pages
29
31
  self.free_pages: List[int] = list(range(num_pages))
30
- self.key_to_index: Dict[str, int] = {}
32
+ self.key_to_index: OrderedDict[str, int] = OrderedDict()
31
33
  # Todo: Support multi files for HF3FS
32
34
 
33
35
  def exists_keys(self, keys: List[str]) -> List[bool]:
@@ -46,16 +48,18 @@ class RankMetadata:
46
48
  for i, (key, prefix_key) in enumerate(keys):
47
49
  if key in self.key_to_index:
48
50
  results[i] = (True, self.key_to_index[key])
51
+ self.key_to_index.move_to_end(key)
49
52
  else:
50
53
  new_keys_to_process.append((i, key, prefix_key))
51
54
 
52
55
  # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
53
56
  for i, key, prefix_key in new_keys_to_process:
54
57
  if len(self.free_pages) > 0:
55
- page_idx = self.free_pages.pop()
56
- results[i] = (False, page_idx)
58
+ page_index = self.free_pages.pop()
57
59
  else:
58
- results[i] = (False, -1)
60
+ page_index = self.key_to_index.popitem(last=False)[1]
61
+
62
+ results[i] = (False, page_index)
59
63
 
60
64
  return results
61
65
 
@@ -68,6 +72,7 @@ class RankMetadata:
68
72
  with self.lock:
69
73
  for key, page_index in written_keys_to_confirm:
70
74
  self.key_to_index[key] = page_index
75
+ self.key_to_index.move_to_end(key)
71
76
 
72
77
  for page_index in pages_to_release:
73
78
  if page_index not in self.free_pages:
@@ -94,7 +99,14 @@ class RankMetadata:
94
99
  def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
95
100
  """Get page indices for keys."""
96
101
  with self.lock:
97
- return [self.key_to_index.get(key) for key in keys]
102
+ results = []
103
+ for key in keys:
104
+ if key in self.key_to_index:
105
+ results.append(self.key_to_index[key])
106
+ self.key_to_index.move_to_end(key)
107
+ else:
108
+ results.append(None)
109
+ return results
98
110
 
99
111
 
100
112
  class GlobalMetadataState:
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:
182
194
 
183
195
  def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
184
196
  self.state = GlobalMetadataState(persistence_path, save_interval)
185
- self.app = FastAPI()
197
+ self.app = FastAPI(default_response_class=ORJSONResponse)
198
+
186
199
  self._setup_routes()
187
200
 
188
201
  def _setup_routes(self):
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:
199
212
 
200
213
  def get_rank_metadata(self, rank: int) -> RankMetadata:
201
214
  """Get rank metadata with proper error handling."""
202
- with self.state.global_lock:
203
- if rank not in self.state.ranks:
204
- raise HTTPException(
205
- status_code=404,
206
- detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
207
- )
208
- return self.state.ranks[rank]
215
+ if rank not in self.state.ranks:
216
+ raise HTTPException(
217
+ status_code=404,
218
+ detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
219
+ )
220
+ return self.state.ranks[rank]
221
+
222
+ async def _read_json(self, request: Request) -> dict:
223
+ """Parse request JSON using orjson if available."""
224
+ body = await request.body()
225
+ return orjson.loads(body)
226
+
227
+ def _json_response(self, content: dict):
228
+ """Return ORJSONResponse when available to bypass jsonable_encoder."""
229
+ return ORJSONResponse(content)
209
230
 
210
231
  async def initialize(self, rank: int, request: Request):
211
232
  """Initialize a rank with specified number of pages."""
212
- data = await request.json()
233
+ data = await self._read_json(request)
213
234
  num_pages = data["num_pages"]
214
235
  with self.state.global_lock:
215
236
  if rank in self.state.ranks:
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
223
244
  else:
224
245
  logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
225
246
  self.state.ranks[rank] = RankMetadata(num_pages)
226
- return {"message": f"Rank {rank} is ready."}
247
+ return Response(status_code=204)
227
248
 
228
249
  async def exists(self, rank: int, request: Request):
229
250
  """Check if keys exist in metadata."""
230
- data = await request.json()
251
+ data = await self._read_json(request)
231
252
  keys = data["keys"]
232
253
  metadata = self.get_rank_metadata(rank)
233
254
  results = metadata.exists_keys(keys)
234
- return {"exists": results}
255
+ return self._json_response({"exists": results})
235
256
 
236
257
  async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
237
258
  """Reserve and allocate page indices for keys."""
238
- data = await request.json()
259
+ data = await self._read_json(request)
239
260
  metadata = self.get_rank_metadata(rank)
240
261
  keys = data["keys"]
241
262
  results = metadata.reserve_and_allocate_page_indices(keys)
242
- return {"indices": results}
263
+ return self._json_response({"indices": results})
243
264
 
244
265
  async def confirm_write(self, rank: int, request: Request):
245
266
  """Confirm write operations and release pages."""
246
- data = await request.json()
267
+ data = await self._read_json(request)
247
268
  metadata = self.get_rank_metadata(rank)
248
269
  success_written_keys = data.get("written_keys_to_confirm", [])
249
270
  released_pages = data.get("pages_to_release", [])
250
271
 
251
272
  metadata.confirm_write(success_written_keys, released_pages)
252
273
 
253
- return {
254
- "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
255
- }
274
+ return Response(status_code=204)
256
275
 
257
276
  async def delete_keys(self, rank: int, request: Request):
258
277
  """Delete keys from metadata."""
259
- data = await request.json()
278
+ data = await self._read_json(request)
260
279
  metadata = self.get_rank_metadata(rank)
261
280
  count = metadata.delete_keys(data["keys"])
262
- return {"message": f"Rank {rank}: {count} keys deleted."}
281
+ return Response(status_code=204)
263
282
 
264
283
  async def clear(self, rank: int):
265
284
  """Clear all metadata for a rank."""
266
285
  metadata = self.get_rank_metadata(rank)
267
286
  metadata.clear_all()
268
- return {"message": f"Rank {rank}: Metadata cleared."}
287
+ return Response(status_code=204)
269
288
 
270
289
  async def get_page_indices(self, rank: int, request: Request):
271
290
  """Get page indices for keys."""
272
- data = await request.json()
291
+ data = await self._read_json(request)
273
292
  metadata = self.get_rank_metadata(rank)
274
293
  keys = data["keys"]
275
294
  results = metadata.get_page_indices(keys)
276
- return {"indices": results}
295
+ return self._json_response({"indices": results})
277
296
 
278
297
  def run(self, host: str = "0.0.0.0", port: int = 18000):
279
298
  """Run the metadata server."""
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
309
328
  status_forcelist=[500, 502, 503, 504],
310
329
  allowed_methods=["GET", "POST"],
311
330
  )
312
- adapter = HTTPAdapter(max_retries=retry_strategy)
331
+ adapter = HTTPAdapter(
332
+ max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
333
+ )
313
334
  self._session.mount("http://", adapter)
314
335
 
315
336
  def _post(self, endpoint: str, json_data: dict) -> dict:
316
337
  try:
317
- response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
338
+ url = f"{self.base_url}/{endpoint}"
339
+ headers = {"Content-Type": "application/json"}
340
+ payload = orjson.dumps(json_data) # type: ignore[union-attr]
341
+ response = self._session.post(url, data=payload, headers=headers)
318
342
  response.raise_for_status()
319
- return response.json()
343
+
344
+ if response.status_code == 204 or not response.content:
345
+ return {}
346
+ return orjson.loads(response.content) # type: ignore[union-attr]
320
347
  except requests.exceptions.RequestException as e:
321
348
  logging.error(f"Failed to POST to {endpoint} after retries: {e}")
322
349
  raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
@@ -113,6 +113,8 @@ def synchronized():
113
113
 
114
114
 
115
115
  class HiCacheHF3FS(HiCacheStorage):
116
+ """HiCache backend that stores KV cache pages in HF3FS files."""
117
+
116
118
  default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
117
119
 
118
120
  def __init__(
@@ -126,6 +128,7 @@ class HiCacheHF3FS(HiCacheStorage):
126
128
  dtype: torch.dtype,
127
129
  metadata_client: Hf3fsMetadataInterface,
128
130
  is_mla_model: bool = False,
131
+ is_page_first_layout: bool = False,
129
132
  ):
130
133
  self.rank = rank
131
134
  self.file_path = file_path
@@ -136,6 +139,7 @@ class HiCacheHF3FS(HiCacheStorage):
136
139
  self.dtype = dtype
137
140
  self.metadata_client = metadata_client
138
141
  self.is_mla_model = is_mla_model
142
+ self.is_page_first_layout = is_page_first_layout
139
143
  self.numel = self.bytes_per_page // self.dtype.itemsize
140
144
  self.num_pages = self.file_size // self.bytes_per_page
141
145
  self.skip_backup = False
@@ -176,15 +180,36 @@ class HiCacheHF3FS(HiCacheStorage):
176
180
  dtype: torch.dtype,
177
181
  storage_config: HiCacheStorageConfig = None,
178
182
  ) -> "HiCacheHF3FS":
183
+ """Create a HiCacheHF3FS instance from environment configuration.
184
+
185
+ Environment:
186
+ - Uses env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config.
187
+ - Falls back to a local single-machine config when the env var is not set.
188
+
189
+ Raises:
190
+ ValueError: If MLA Model is requested without global metadata server or required keys are missing.
191
+ """
179
192
  from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
180
193
  Hf3fsGlobalMetadataClient,
181
194
  Hf3fsLocalMetadataClient,
182
195
  )
183
196
 
184
- rank = storage_config.tp_rank if storage_config is not None else 0
197
+ if storage_config is not None:
198
+ rank, is_mla_model, is_page_first_layout = (
199
+ storage_config.tp_rank,
200
+ storage_config.is_mla_model,
201
+ storage_config.is_page_first_layout,
202
+ )
203
+ else:
204
+ rank, is_mla_model, is_page_first_layout = 0, False, False
205
+
206
+ mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
185
207
 
186
208
  config_path = os.getenv(HiCacheHF3FS.default_env_var)
187
209
  if not config_path:
210
+ if is_mla_model:
211
+ raise ValueError(mla_unsupported_msg)
212
+
188
213
  return HiCacheHF3FS(
189
214
  rank=rank,
190
215
  file_path=f"/data/hicache.{rank}.bin",
@@ -194,6 +219,7 @@ class HiCacheHF3FS(HiCacheStorage):
194
219
  entries=8,
195
220
  dtype=dtype,
196
221
  metadata_client=Hf3fsLocalMetadataClient(),
222
+ is_page_first_layout=is_page_first_layout,
197
223
  )
198
224
 
199
225
  try:
@@ -214,25 +240,27 @@ class HiCacheHF3FS(HiCacheStorage):
214
240
  raise ValueError(f"Missing required keys in config: {missing_keys}")
215
241
 
216
242
  # Choose metadata client based on configuration
217
- is_mla_model = False
218
- if "metadata_server_url" in config and config["metadata_server_url"]:
243
+ if config.get("metadata_server_url"):
219
244
  # Use global metadata client to connect to metadata server
220
245
  metadata_server_url = config["metadata_server_url"]
221
246
  metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
222
247
 
223
- # Enable MLA optimization only when using the global metadata client
224
- is_mla_model = storage_config.is_mla_model if storage_config else False
225
248
  logger.info(
226
249
  f"Using global metadata client with server url: {metadata_server_url}"
227
250
  )
228
251
  else:
252
+ # Enable MLA optimization only when using the global metadata client
253
+ if is_mla_model:
254
+ raise ValueError(mla_unsupported_msg)
255
+
229
256
  # Use local metadata client for single-machine deployment
230
257
  metadata_client = Hf3fsLocalMetadataClient()
231
258
 
259
+ rank_for_path = 0 if is_mla_model else rank
232
260
  return HiCacheHF3FS(
233
261
  rank=rank,
234
262
  # Let all ranks use the same file path for MLA model
235
- file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
263
+ file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
236
264
  file_size=int(config["file_size"]),
237
265
  numjobs=int(config["numjobs"]),
238
266
  bytes_per_page=bytes_per_page,
@@ -240,6 +268,7 @@ class HiCacheHF3FS(HiCacheStorage):
240
268
  dtype=dtype,
241
269
  metadata_client=metadata_client,
242
270
  is_mla_model=is_mla_model,
271
+ is_page_first_layout=is_page_first_layout,
243
272
  )
244
273
 
245
274
  def get(
@@ -1,4 +1,3 @@
1
- import hashlib
2
1
  import json
3
2
  import logging
4
3
  import os
@@ -6,10 +5,8 @@ import uuid
6
5
  from dataclasses import dataclass
7
6
  from typing import Any, List, Optional
8
7
 
9
- import numpy as np
10
8
  import torch
11
9
 
12
- from sglang.srt.distributed import get_tensor_model_parallel_rank
13
10
  from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
14
11
 
15
12
  DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
@@ -154,21 +151,36 @@ class MooncakeStore(HiCacheStorage):
154
151
  target_location: Optional[List[int]] = None,
155
152
  target_sizes: Optional[List[int]] = None,
156
153
  ) -> bool:
157
- return self.batch_set([key], [value], [target_location], [target_sizes])
154
+ # Only support zero copy set for now
155
+ assert target_location is not None and target_sizes is not None
156
+ exist_result = self._batch_exist([key])
157
+ if exist_result[0] == 1:
158
+ return True
159
+ put_result = self._put_batch_zero_copy_impl(
160
+ [key], [target_location], [target_sizes]
161
+ )
162
+ return put_result[0] == 0
158
163
 
159
164
  def batch_set(
160
165
  self,
161
166
  keys: List[str],
162
167
  values: Optional[List[torch.Tensor]] = None,
163
- target_location: Optional[List[int]] = None,
168
+ target_locations: Optional[List[int]] = None,
164
169
  target_sizes: Optional[List[int]] = None,
165
170
  ) -> bool:
166
- assert len(keys) == len(target_location) == len(target_sizes)
171
+ # Only support zero copy set for now
172
+ assert target_locations is not None and target_sizes is not None
173
+ assert len(keys) == len(target_locations) == len(target_sizes)
174
+
167
175
  if len(keys) == 0:
168
176
  return False
169
177
 
170
178
  for i in range(len(keys)):
171
- if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
179
+ if (
180
+ keys[i] is None
181
+ or target_locations[i] is None
182
+ or target_sizes[i] is None
183
+ ):
172
184
  return False
173
185
 
174
186
  exist_result = self._batch_exist(keys)
@@ -179,7 +191,7 @@ class MooncakeStore(HiCacheStorage):
179
191
  for i in range(len(keys)):
180
192
  if exist_result[i] != 1:
181
193
  set_keys.append(keys[i])
182
- set_target_locations.append(target_location[i])
194
+ set_target_locations.append(target_locations[i])
183
195
  set_target_sizes.append(target_sizes[i])
184
196
  set_indices.append(i)
185
197
  # Only set non-existing keys to storage
@@ -204,18 +216,24 @@ class MooncakeStore(HiCacheStorage):
204
216
  target_location: Optional[Any] = None,
205
217
  target_sizes: Optional[Any] = None,
206
218
  ) -> bool:
207
- return self.batch_get([key], [target_location], [target_sizes]) == 1
219
+ assert target_location is not None and target_sizes is not None
220
+ get_result = self._get_batch_zero_copy_impl(
221
+ [key], [target_location], [target_sizes]
222
+ )
223
+ return get_result[0] >= 0
208
224
 
209
225
  def batch_get(
210
226
  self,
211
227
  keys: List[str],
212
- target_location: Optional[Any] = None,
228
+ target_locations: Optional[Any] = None,
213
229
  target_sizes: Optional[Any] = None,
214
230
  ) -> int:
215
- assert len(keys) == len(target_location) == len(target_sizes)
231
+ assert len(keys) == len(target_locations) == len(target_sizes)
216
232
  if len(keys) == 0:
217
233
  return 0
218
- get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
234
+ get_result = self._get_batch_zero_copy_impl(
235
+ keys, target_locations, target_sizes
236
+ )
219
237
  if self.is_mla_backend:
220
238
  key_multiplier = 1
221
239
  else:
@@ -226,7 +244,8 @@ class MooncakeStore(HiCacheStorage):
226
244
  return len(keys) // key_multiplier
227
245
 
228
246
  def exists(self, key) -> bool:
229
- return self.batch_exists([key]) > 0
247
+ exist_result = self._batch_exist([key])
248
+ return exist_result[0] == 1
230
249
 
231
250
  def batch_exists(self, keys) -> int:
232
251
  if self.is_mla_backend: