sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +6 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +23 -3
  11. sglang/srt/entrypoints/openai/protocol.py +3 -1
  12. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  13. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  14. sglang/srt/eplb/expert_distribution.py +5 -0
  15. sglang/srt/eplb/expert_location.py +17 -6
  16. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  17. sglang/srt/eplb/expert_location_updater.py +2 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/step3_detector.py +436 -0
  20. sglang/srt/hf_transformers_utils.py +2 -0
  21. sglang/srt/jinja_template_utils.py +4 -1
  22. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  23. sglang/srt/layers/moe/ep_moe/layer.py +98 -603
  24. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  29. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  30. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  31. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  32. sglang/srt/layers/moe/topk.py +6 -2
  33. sglang/srt/layers/quantization/fp8.py +0 -18
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  35. sglang/srt/layers/quantization/unquant.py +0 -8
  36. sglang/srt/layers/quantization/w4afp8.py +1 -0
  37. sglang/srt/managers/cache_controller.py +143 -45
  38. sglang/srt/managers/data_parallel_controller.py +6 -0
  39. sglang/srt/managers/io_struct.py +12 -2
  40. sglang/srt/managers/scheduler.py +116 -669
  41. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  42. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  43. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  44. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  45. sglang/srt/managers/template_manager.py +62 -19
  46. sglang/srt/managers/tokenizer_manager.py +166 -83
  47. sglang/srt/managers/tp_worker.py +9 -0
  48. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  49. sglang/srt/mem_cache/hicache_storage.py +45 -11
  50. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  51. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  52. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  53. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  54. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  55. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  56. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  57. sglang/srt/model_executor/model_runner.py +20 -13
  58. sglang/srt/models/arcee.py +532 -0
  59. sglang/srt/models/deepseek_v2.py +15 -56
  60. sglang/srt/models/glm4_moe.py +3 -1
  61. sglang/srt/models/granitemoe.py +3 -0
  62. sglang/srt/models/grok.py +3 -0
  63. sglang/srt/models/hunyuan.py +1 -0
  64. sglang/srt/models/llama4.py +3 -0
  65. sglang/srt/models/mixtral.py +3 -0
  66. sglang/srt/models/olmoe.py +3 -0
  67. sglang/srt/models/phimoe.py +1 -0
  68. sglang/srt/models/qwen3_moe.py +12 -69
  69. sglang/srt/models/step3_vl.py +994 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/poll_based_barrier.py +31 -0
  73. sglang/srt/reasoning_parser.py +2 -1
  74. sglang/srt/server_args.py +18 -13
  75. sglang/srt/speculative/eagle_worker.py +2 -0
  76. sglang/srt/two_batch_overlap.py +8 -3
  77. sglang/test/test_utils.py +53 -0
  78. sglang/utils.py +0 -11
  79. sglang/version.py +1 -1
  80. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
  81. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
  82. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hiradix_cache.py
@@ -79,7 +79,9 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 3
         )
-        self.write_through_threshold_storage = 3
+        self.write_through_threshold_storage = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -111,6 +113,7 @@ class HiRadixCache(RadixCache):
             )
         if host_indices is not None:
             node.host_value = host_indices
+            assert len(node.host_value) > 0
             self.ongoing_write_through[node.id] = node
             if not write_back:
                 # no need to lock nodes if write back
@@ -388,10 +391,14 @@ class HiRadixCache(RadixCache):
                 self.cache_controller.ack_backup_queue.get()
             )
             host_node = self.ongoing_backup[ack_id]
-            if completed_tokens < len(host_node.key):
+            if completed_tokens == 0:
+                host_node.hash_value = None
+            elif completed_tokens < len(host_node.key):
                 # backup is only partially successful, split the node
                 new_node = self._split_node(host_node.key, host_node, completed_tokens)
                 new_node.hash_value = hash_value
+            else:
+                host_node.hash_value = hash_value
             host_node.release_host()
             del self.ongoing_backup[ack_id]

@@ -431,6 +438,8 @@ class HiRadixCache(RadixCache):
             written_indices,
             hash_value[:min_completed_tokens],
         )
+        if len(written_indices):
+            self.cache_controller.mem_pool_host.update_prefetch(written_indices)

         self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
         self.cache_controller.mem_pool_host.free(
@@ -551,13 +560,11 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
-                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
-                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child
@@ -587,6 +594,10 @@ class HiRadixCache(RadixCache):
         if child.backuped:
             new_node.host_value = child.host_value[:split_len]
             child.host_value = child.host_value[split_len:]
+
+            if child.hash_value:
+                new_node.hash_value = child.hash_value[: split_len // self.page_size]
+                child.hash_value = child.hash_value[split_len // self.page_size :]
         child.parent = new_node
         child.key = child.key[split_len:]
         new_node.parent.children[self.get_child_key_fn(key)] = new_node
sglang/srt/mem_cache/memory_pool_host.py
@@ -25,7 +25,6 @@ def synchronized(debug_only=False):
         @wraps(func)
         def wrapper(self, *args, **kwargs):
             if (not debug_only) or self.debug:
-                return func(self, *args, **kwargs)
                 with self.lock:
                     return func(self, *args, **kwargs)
             else:
@@ -181,6 +180,15 @@ class HostKVCache(abc.ABC):
             )
         self.mem_state[indices] = MemoryStateInt.BACKUP

+    @synchronized(debug_only=True)
+    def update_prefetch(self, indices: torch.Tensor):
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
     @synchronized(debug_only=True)
     def update_synced(self, indices: torch.Tensor):
         self.mem_state[indices] = MemoryStateInt.SYNCED
@@ -257,6 +265,43 @@ class MHATokenToKVPoolHost(HostKVCache):
             self.head_dim,
         )

+    def get_buffer_meta(self, keys, indices):
+        ptr_list = []
+        key_list = []
+        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        v_offset = (
+            self.layer_num
+            * self.size
+            * self.head_num
+            * self.head_dim
+            * self.dtype.itemsize
+        )
+        for index in range(0, len(indices), self.page_size):
+            for layer_id in range(self.layer_num):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                    + layer_id
+                    * self.size
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                )
+                v_ptr = k_ptr + v_offset
+                ptr_list.append(k_ptr)
+                ptr_list.append(v_ptr)
+                key_ = keys[index // self.page_size]
+                key_list.append(f"{key_}_{layer_id}_k")
+                key_list.append(f"{key_}_{layer_id}_v")
+        element_size = (
+            self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+        )
+        element_size_list = [element_size] * len(key_list)
+        return key_list, ptr_list, element_size_list
+
     @property
     def k_buffer(self):
         return self.kv_buffer[0]
@@ -317,3 +362,30 @@ class MLATokenToKVPoolHost(HostKVCache):
             1,
             self.kv_lora_rank + self.qk_rope_head_dim,
         )
+
+    def get_buffer_meta(self, keys, indices):
+        ptr_list = []
+        key_list = []
+        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        for index in range(0, len(indices), self.page_size):
+            for layer_id in range(self.layer_num):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                    + layer_id
+                    * self.size
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                )
+                ptr_list.append(k_ptr)
+                key_ = keys[index // self.page_size]
+                key_list.append(f"{key_}_{layer_id}_k")
+        element_size = (
+            self.dtype.itemsize
+            * self.page_size
+            * (self.kv_lora_rank + self.qk_rope_head_dim)
+        )
+        element_size_list = [element_size] * len(key_list)
+        return key_list, ptr_list, element_size_list
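As a reading aid for the two `get_buffer_meta` methods added above: both walk the host KV buffer one page at a time and one layer at a time, turning each (page, layer) pair into a raw byte address plus a fixed per-page element size that a zero-copy storage backend can consume directly. Below is a minimal standalone sketch of the MHA offset arithmetic; it assumes the flat [2, layer_num, size, head_num, head_dim] host layout that the `v_offset` term implies, and the helper name is illustrative, not part of the sglang API.

def mha_kv_byte_offsets(index, layer_id, size, head_num, head_dim, itemsize, layer_num):
    # Byte offsets of the K and V slots for the page starting at token `index`,
    # in a flat [2, layer_num, size, head_num, head_dim] host buffer (illustrative only).
    token_bytes = head_num * head_dim * itemsize             # one token in one layer
    k_off = index * token_bytes + layer_id * size * token_bytes
    v_off = k_off + layer_num * size * token_bytes           # the V half follows all K layers
    return k_off, v_off

# Example: 8-token pool, 2 layers, 4 heads of dim 16, 2-byte (fp16) entries,
# page starting at token index 0, layer 1:
print(mha_kv_byte_offsets(0, 1, size=8, head_num=4, head_dim=16, itemsize=2, layer_num=2))
# -> (1024, 3072), i.e. k_ptr/v_ptr in the diff above minus kv_buffer.data_ptr()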
sglang/srt/mem_cache/mooncake_store/mooncake_store.py (new file)
@@ -0,0 +1,264 @@
+import hashlib
+import json
+import logging
+import os
+import uuid
+from dataclasses import dataclass
+from typing import Any, List, Optional
+
+import numpy as np
+import torch
+
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+
+DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
+DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MB
+
+logger = logging.getLogger(__name__)
+
+
+def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
+    local_rank = get_tensor_model_parallel_rank()
+    prefix_str = ""
+    if prefix_block_key:
+        if len(prefix_block_key):
+            prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest()
+    current_token_ids_bytes = np.array(current_page_ids).tobytes()
+    current_hash_object = hashlib.sha256(current_token_ids_bytes)
+    current_hash_hex = current_hash_object.hexdigest()
+    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
+
+
+@dataclass
+class MooncakeStoreConfig:
+    local_hostname: str
+    metadata_server: str
+    global_segment_size: int
+    local_buffer_size: int
+    protocol: str
+    device_name: str
+    master_server_address: str
+
+    @staticmethod
+    def from_file() -> "MooncakeStoreConfig":
+        """Load the config from a JSON file."""
+        file_path = os.getenv("MOONCAKE_CONFIG_PATH")
+        if file_path is None:
+            raise ValueError(
+                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
+            )
+        with open(file_path) as fin:
+            config = json.load(fin)
+        return MooncakeStoreConfig(
+            local_hostname=config.get("local_hostname"),
+            metadata_server=config.get("metadata_server"),
+            global_segment_size=config.get(
+                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
+            ),
+            local_buffer_size=config.get(
+                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
+            ),
+            protocol=config.get("protocol", "tcp"),
+            device_name=config.get("device_name", "auto"),
+            master_server_address=config.get("master_server_address"),
+        )
+
+    @staticmethod
+    def load_from_env() -> "MooncakeStoreConfig":
+        """Load config from a file specified in the environment variable.
+        export MOONCAKE_MASTER=10.13.3.232:50051
+        export MOONCAKE_PROTOCOL="rdma"
+        export MOONCAKE_DEVICE="auto"
+        export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE"
+        """
+        # other required environment variables...
+        if not os.getenv("MOONCAKE_MASTER"):
+            raise ValueError("The environment variable 'MOONCAKE_MASTER' is not set.")
+        return MooncakeStoreConfig(
+            local_hostname=os.getenv("LOCAL_HOSTNAME", "localhost"),
+            metadata_server=os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"),
+            global_segment_size=int(
+                os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
+            ),
+            local_buffer_size=int(
+                os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
+            ),
+            protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
+            device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
+            master_server_address=os.getenv("MOONCAKE_MASTER"),
+        )
+
+    def __post_init__(self):
+        if self.device_name == "auto":
+            os.environ["MC_MS_AUTO_DISC"] = "1"
+            os.environ["MC_MS_FILTERS"] = (
+                "mlx5_bond_0, mlx5_bond_1, mlx5_bond_2, mlx5_bond_3"
+            )
+
+
+class MooncakeStore(HiCacheStorage):
+    def __init__(self):
+        try:
+            from mooncake.store import MooncakeDistributedStore
+        except ImportError as e:
+            raise ImportError(
+                "Please install mooncake by following the instructions at "
+                "https://kvcache-ai.github.io/Mooncake/getting_started/build.html"
+                "to run SGLang with MooncakeConnector."
+            ) from e
+
+        try:
+            self.store = MooncakeDistributedStore()
+            self.config = MooncakeStoreConfig.load_from_env()
+            logger.info("Mooncake Configuration loaded from env successfully.")
+
+            ret_code = self.store.setup(
+                self.config.local_hostname,
+                self.config.metadata_server,
+                self.config.global_segment_size,
+                self.config.local_buffer_size,
+                self.config.protocol,
+                self.config.device_name,
+                self.config.master_server_address,
+            )
+            if ret_code:
+                logger.error(f"failed to setup mooncake store, error code: {ret_code}")
+
+            logger.info("Connect to Mooncake store successfully.")
+            self.warmup()
+            logger.info("Mooncake store warmup successfully.")
+
+        except ValueError as e:
+            logger.error("Configuration loading failed: %s", e)
+            raise
+        except Exception as exc:
+            logger.error("An error occurred while loading the configuration: %s", exc)
+            raise
+
+    def warmup(self):
+        warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
+        # 10 MB
+        warmup_value = bytes(10 * 1024 * 1024)
+        self.store.put(warmup_key, warmup_value)
+        assert self.store.is_exist(warmup_key) == 1
+        self.store.get(warmup_key)
+        self.store.remove(warmup_key)
+
+    def register_buffer(self, buffer: torch.Tensor) -> None:
+        try:
+            buffer_ptr = buffer.data_ptr()
+            buffer_size = buffer.numel() * buffer.element_size()
+            ret_code = self.store.register_buffer(buffer_ptr, buffer_size)
+            if ret_code:
+                logger.error(f"failed to register buffer, error code: {ret_code}")
+        except TypeError as err:
+            logger.error("Failed to register buffer to Mooncake Store: %s", err)
+            raise TypeError("Mooncake Store Register Buffer Error.") from err
+
+    def set(
+        self,
+        key,
+        value: Optional[Any] = None,
+        target_location: Optional[List[int]] = None,
+        target_sizes: Optional[List[int]] = None,
+    ) -> bool:
+        assert len(key) == len(target_location) == len(target_sizes)
+        if len(key) == 0:
+            return
+
+        for i in range(len(key)):
+            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
+                return
+
+        self._put_batch_zero_copy_impl(key, target_location, target_sizes)
+
+    def batch_set(
+        self,
+        keys: List[str],
+        value: Optional[Any] = None,
+        target_location: Optional[List[int]] = None,
+        target_sizes: Optional[List[int]] = None,
+    ) -> bool:
+        assert len(keys) == len(target_location) == len(target_sizes)
+        if len(keys) == 0:
+            return
+
+        for i in range(len(keys)):
+            if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
+                return
+
+        self._put_batch_zero_copy_impl(keys, target_location, target_sizes)
+
+    def get(
+        self,
+        key,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> torch.Tensor | None:
+        assert len(key) == len(target_location) == len(target_sizes)
+        if len(key) == 0:
+            return
+
+        for i in range(len(key)):
+            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
+                return
+
+        return self._get_batch_zero_copy_impl(key, target_location, target_sizes)
+
+    def batch_get(
+        self,
+        keys: List[str],
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> torch.Tensor | None:
+        assert len(keys) == len(target_location) == len(target_sizes)
+        if len(keys) == 0:
+            return
+
+        for i in range(len(keys)):
+            if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
+                return
+
+        return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
+
+    def exists(self, keys) -> bool | dict:
+        _keys = []
+        local_rank = torch.cuda.current_device()
+        for key in keys:
+            if key is None:
+                return None
+            # Since mooncake store is stored in layer by layer,
+            # only the first layer is checked here.
+            _keys.append(f"{key}_{local_rank}_k")
+        result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
+        return result
+
+    def delete(self, key) -> None:
+        raise (NotImplementedError)
+
+    def close(self):
+        # MooncakeDistributedStore will automatically call the destructor, so
+        # it is unnecessary to close it manually.
+        pass
+
+    def clear(self) -> None:
+        raise (NotImplementedError)
+
+    def _put_batch_zero_copy_impl(
+        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
+    ) -> None:
+        try:
+            self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
+        except TypeError as err:
+            logger.error("Failed to put value to Mooncake Store: %s", err)
+            raise TypeError("Mooncake Store Put Type Error.") from err
+
+    def _get_batch_zero_copy_impl(
+        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
+    ) -> None:
+        try:
+            self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
+        except TypeError as err:
+            logger.error("Failed to get value from Mooncake Store: %s", err)
+            raise TypeError("Mooncake Store Get Type Error.") from err
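For orientation, `MooncakeStore` above configures itself entirely from environment variables via `MooncakeStoreConfig.load_from_env`, with `MOONCAKE_MASTER` as the only mandatory one. A minimal sketch of how it might be wired up follows; the master address is illustrative, and it assumes the `mooncake` package is installed and a Mooncake master is reachable at that address.

import os
import torch

# Only MOONCAKE_MASTER is required; the others fall back to the defaults in load_from_env.
os.environ["MOONCAKE_MASTER"] = "10.13.3.232:50051"         # illustrative address
os.environ["MOONCAKE_PROTOCOL"] = "rdma"
os.environ["MOONCAKE_DEVICE"] = "auto"
os.environ["MOONCAKE_TE_META_DATA_SERVER"] = "P2PHANDSHAKE"

from sglang.srt.mem_cache.mooncake_store.mooncake_store import MooncakeStore

store = MooncakeStore()                         # loads config, connects, runs warmup()
host_kv = torch.zeros(1 << 20, dtype=torch.uint8, pin_memory=True)
store.register_buffer(host_kv)                  # register host memory for zero-copy put/get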
sglang/srt/mem_cache/mooncake_store/unit_test.py (new file)
@@ -0,0 +1,40 @@
+import torch
+from mooncake_store import MooncakeStore
+
+
+def test_init_and_warmup():
+    store = MooncakeStore()
+    assert store.store is not None
+
+
+def test_register_buffer():
+    store = MooncakeStore()
+    tensor = torch.zeros(1024, dtype=torch.float32)
+    store.register_buffer(tensor)
+
+
+def test_set_and_get():
+    store = MooncakeStore()
+
+    key = ["test_key_" + str(i) for i in range(2)]
+    tensor = torch.arange(256, dtype=torch.float32).cuda()
+    ptrs = [tensor.data_ptr(), tensor.data_ptr()]
+    sizes = [tensor.numel() * tensor.element_size()] * 2
+
+    store.set(key, target_location=ptrs, target_sizes=sizes)
+    store.get(key, target_location=ptrs, target_sizes=sizes)
+
+
+def test_exists():
+    store = MooncakeStore()
+    keys = ["test_key_0", "non_existent_key"]
+    result = store.exists(keys)
+    assert isinstance(result, dict)
+    assert "test_key_0" in result
+
+
+if __name__ == "__main__":
+    test_init_and_warmup()
+    test_register_buffer()
+    test_set_and_get()
+    test_exists()
sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py (new file)
@@ -0,0 +1,177 @@
+import logging
+import multiprocessing
+import os
+import threading
+from functools import wraps
+from pathlib import Path
+from typing import List
+
+import torch
+from torch.utils.cpp_extension import load
+
+root = Path(__file__).parent.resolve()
+hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
+
+logger = logging.getLogger(__name__)
+
+try:
+    from hf3fs_fuse.io import (
+        deregister_fd,
+        extract_mount_point,
+        make_ioring,
+        make_iovec,
+        register_fd,
+    )
+except ImportError as e:
+    logger.warning(f"hf3fs_fuse.io is not available: {e}")
+
+
+def rsynchronized():
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            with self.rlock:
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return _decorator
+
+
+def wsynchronized():
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            with self.wlock:
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return _decorator
+
+
+class Hf3fsClient:
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        self.path = path
+        self.size = size
+        self.bytes_per_page = bytes_per_page
+        self.entries = entries
+
+        self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
+        os.ftruncate(self.file, size)
+        register_fd(self.file)
+
+        self.hf3fs_mount_point = extract_mount_point(path)
+        self.bs = self.bytes_per_page
+        self.shm_r = multiprocessing.shared_memory.SharedMemory(
+            size=self.bs * self.entries, create=True
+        )
+        self.shm_w = multiprocessing.shared_memory.SharedMemory(
+            size=self.bs * self.entries, create=True
+        )
+
+        self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8)
+        self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8)
+
+        self.numa = -1
+        self.ior_r = make_ioring(
+            self.hf3fs_mount_point,
+            self.entries,
+            for_read=True,
+            timeout=1,
+            numa=self.numa,
+        )
+        self.ior_w = make_ioring(
+            self.hf3fs_mount_point,
+            self.entries,
+            for_read=False,
+            timeout=1,
+            numa=self.numa,
+        )
+        self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
+        self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
+
+        self.rlock = threading.RLock()
+        self.wlock = threading.RLock()
+
+    @rsynchronized()
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        self.check(offsets, tensors)
+
+        # prepare
+        current = 0
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+            self.ior_r.prepare(
+                self.iov_r[current : current + size], True, self.file, offset
+            )
+            current += size
+
+        # submit
+        ionum = len(offsets)
+        resv = self.ior_r.submit().wait(min_results=ionum)
+
+        # results
+        hf3fs_utils.read_shm(self.shm_r_tensor, tensors)
+        results = [res.result for res in resv]
+
+        return results
+
+    @wsynchronized()
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        self.check(offsets, tensors)
+
+        # prepare
+        hf3fs_utils.write_shm(tensors, self.shm_w_tensor)
+        current = 0
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+            self.ior_w.prepare(
+                self.iov_w[current : current + size], False, self.file, offset
+            )
+            current += size
+
+        # submit
+        ionum = len(offsets)
+        resv = self.ior_w.submit().wait(min_results=ionum)
+
+        # results
+        results = [res.result for res in resv]
+
+        return results
+
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        sizes = [t.numel() * t.itemsize for t in tensors]
+        if any(
+            [
+                len(offsets) > self.entries,
+                len(offsets) != len(sizes),
+                all(
+                    [
+                        offset < 0 or offset + size > self.size
+                        for offset, size in zip(offsets, sizes)
+                    ]
+                ),
+                all([size > self.bytes_per_page for size in sizes]),
+            ]
+        ):
+            self.close()
+            raise ValueError(f"Hf3fsClient.check: {offsets=}, {sizes=}")
+
+    def get_size(self) -> int:
+        return self.size
+
+    def close(self) -> None:
+        deregister_fd(self.file)
+        os.close(self.file)
+        del self.ior_r
+        del self.ior_w
+        del self.iov_r
+        del self.iov_w
+        self.shm_r.close()
+        self.shm_w.close()
+        self.shm_r.unlink()
+        self.shm_w.unlink()
+
+    def flush(self) -> None:
+        os.fsync(self.file)
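To make the read/write path above concrete, here is a hypothetical round trip with `Hf3fsClient`. It assumes the backing file sits on an hf3fs mount and that both the `hf3fs_fuse` package and the JIT-built `hf3fs_utils` extension are available; the path and sizes are illustrative only.

import torch
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient

BYTES_PER_PAGE = 4096
client = Hf3fsClient(
    path="/hf3fs/cache/kv.bin",          # hypothetical file on an hf3fs mount
    size=16 * BYTES_PER_PAGE,            # backing file size in bytes
    bytes_per_page=BYTES_PER_PAGE,
    entries=8,                           # max I/Os per batched call
)

page = torch.randint(0, 256, (BYTES_PER_PAGE,), dtype=torch.uint8)
client.batch_write([0], [page])          # stage into shared memory, write at byte offset 0
out = torch.empty_like(page)
client.batch_read([0], [out])            # read back into `out` via the read io ring
assert torch.equal(page, out)
client.close()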