sglang 0.4.9.post6__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their registry, and is provided for informational purposes only.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +3 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -640
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +89 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +5 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +2 -0
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +10 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/RECORD +69 -56
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hiradix_cache.py

```diff
@@ -79,7 +79,9 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 3
         )
-        self.write_through_threshold_storage = 1 if hicache_write_policy == "write_through" else 3
+        self.write_through_threshold_storage = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -111,6 +113,7 @@ class HiRadixCache(RadixCache):
         )
         if host_indices is not None:
             node.host_value = host_indices
+            assert len(node.host_value) > 0
             self.ongoing_write_through[node.id] = node
             if not write_back:
                 # no need to lock nodes if write back
@@ -388,10 +391,14 @@ class HiRadixCache(RadixCache):
                 self.cache_controller.ack_backup_queue.get()
             )
             host_node = self.ongoing_backup[ack_id]
-            if completed_tokens < len(host_node.key):
+            if completed_tokens == 0:
+                host_node.hash_value = None
+            elif completed_tokens < len(host_node.key):
                 # backup is only partially successful, split the node
                 new_node = self._split_node(host_node.key, host_node, completed_tokens)
                 new_node.hash_value = hash_value
+            else:
+                host_node.hash_value = hash_value
             host_node.release_host()
             del self.ongoing_backup[ack_id]

@@ -431,6 +438,8 @@ class HiRadixCache(RadixCache):
                 written_indices,
                 hash_value[:min_completed_tokens],
             )
+            if len(written_indices):
+                self.cache_controller.mem_pool_host.update_prefetch(written_indices)

             self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
             self.cache_controller.mem_pool_host.free(
@@ -551,13 +560,11 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
-                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
-                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child
@@ -587,6 +594,10 @@ class HiRadixCache(RadixCache):
         if child.backuped:
             new_node.host_value = child.host_value[:split_len]
             child.host_value = child.host_value[split_len:]
+
+        if child.hash_value:
+            new_node.hash_value = child.hash_value[: split_len // self.page_size]
+            child.hash_value = child.hash_value[split_len // self.page_size :]
         child.parent = new_node
         child.key = child.key[split_len:]
         new_node.parent.children[self.get_child_key_fn(key)] = new_node
```
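The new `hash_value` bookkeeping in `_split_node` slices at page granularity: `hash_value` carries one hash per `page_size` tokens, so a split at `split_len` tokens hands the first `split_len // page_size` hashes to the new parent node. A minimal standalone sketch with illustrative values (not sglang code):

```python
# One hash per page; page_size and the hash list are made-up values.
page_size = 64
hash_value = ["h0", "h1", "h2", "h3"]  # hashes for a 256-token node
split_len = 128                        # split after 128 tokens (2 pages)

# Mirrors the new _split_node slicing: the new parent keeps the hashes of
# the leading pages, the child keeps the trailing ones.
new_node_hash = hash_value[: split_len // page_size]  # ["h0", "h1"]
child_hash = hash_value[split_len // page_size :]     # ["h2", "h3"]
assert new_node_hash + child_hash == hash_value
```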
sglang/srt/mem_cache/memory_pool_host.py

```diff
@@ -25,7 +25,6 @@ def synchronized(debug_only=False):
         @wraps(func)
         def wrapper(self, *args, **kwargs):
             if (not debug_only) or self.debug:
-                return func(self, *args, **kwargs)
                 with self.lock:
                     return func(self, *args, **kwargs)
             else:
```
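The removed line was an unconditional early `return` that made the `with self.lock:` block unreachable, so these calls were never actually serialized. A standalone sketch of the corrected decorator shape (the `else` branch is outside this hunk and elided; the `Pool` class is purely illustrative):

```python
import threading
from functools import wraps

def synchronized(debug_only=False):
    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if (not debug_only) or self.debug:
                with self.lock:  # reachable again after the fix
                    return func(self, *args, **kwargs)
            # else branch unchanged and not shown in this hunk
        return wrapper
    return _decorator

class Pool:
    def __init__(self):
        self.lock = threading.RLock()
        self.debug = True

    @synchronized(debug_only=True)
    def touch(self):
        return "locked call"

assert Pool().touch() == "locked call"
```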
sglang/srt/mem_cache/memory_pool_host.py

```diff
@@ -181,6 +180,15 @@ class HostKVCache(abc.ABC):
             )
         self.mem_state[indices] = MemoryStateInt.BACKUP

+    @synchronized(debug_only=True)
+    def update_prefetch(self, indices: torch.Tensor):
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
     @synchronized(debug_only=True)
     def update_synced(self, indices: torch.Tensor):
         self.mem_state[indices] = MemoryStateInt.SYNCED
@@ -257,6 +265,43 @@ class MHATokenToKVPoolHost(HostKVCache):
             self.head_dim,
         )

+    def get_buffer_meta(self, keys, indices):
+        ptr_list = []
+        key_list = []
+        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        v_offset = (
+            self.layer_num
+            * self.size
+            * self.head_num
+            * self.head_dim
+            * self.dtype.itemsize
+        )
+        for index in range(0, len(indices), self.page_size):
+            for layer_id in range(self.layer_num):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                    + layer_id
+                    * self.size
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                )
+                v_ptr = k_ptr + v_offset
+                ptr_list.append(k_ptr)
+                ptr_list.append(v_ptr)
+                key_ = keys[index // self.page_size]
+                key_list.append(f"{key_}_{layer_id}_k")
+                key_list.append(f"{key_}_{layer_id}_v")
+        element_size = (
+            self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+        )
+        element_size_list = [element_size] * len(key_list)
+        return key_list, ptr_list, element_size_list
+
     @property
     def k_buffer(self):
         return self.kv_buffer[0]
@@ -317,3 +362,30 @@ class MLATokenToKVPoolHost(HostKVCache):
             1,
             self.kv_lora_rank + self.qk_rope_head_dim,
         )
+
+    def get_buffer_meta(self, keys, indices):
+        ptr_list = []
+        key_list = []
+        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        for index in range(0, len(indices), self.page_size):
+            for layer_id in range(self.layer_num):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                    + layer_id
+                    * self.size
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                )
+                ptr_list.append(k_ptr)
+                key_ = keys[index // self.page_size]
+                key_list.append(f"{key_}_{layer_id}_k")
+        element_size = (
+            self.dtype.itemsize
+            * self.page_size
+            * (self.kv_lora_rank + self.qk_rope_head_dim)
+        )
+        element_size_list = [element_size] * len(key_list)
+        return key_list, ptr_list, element_size_list
```
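Both `get_buffer_meta` additions linearize a page-aligned index range into parallel `(key, pointer, size)` lists, one entry per layer per page: separate `_k`/`_v` entries for MHA, a single entry for MLA (whose k and v share one buffer). The MHA pointer math implies a `[2, layer_num, size, head_num, head_dim]` buffer layout, consistent with the `kv_buffer[0]`/`kv_buffer[1]` k/v split above. A toy recomputation with made-up sizes (none of these numbers come from a real sglang config):

```python
# Toy walk-through of the MHA pointer math in get_buffer_meta.
# All sizes are illustrative; `base` stands in for kv_buffer.data_ptr().
layer_num, size, head_num, head_dim, itemsize, page_size = 2, 8, 4, 16, 2, 4
base = 0

token_bytes = head_num * head_dim * itemsize  # one token in one layer
v_offset = layer_num * size * token_bytes     # v data sits after all k data

keys, ptrs = [], []
for index in range(0, size, page_size):       # first token of each page
    for layer_id in range(layer_num):
        k_ptr = base + index * token_bytes + layer_id * size * token_bytes
        keys += [f"page{index // page_size}_{layer_id}_k",
                 f"page{index // page_size}_{layer_id}_v"]
        ptrs += [k_ptr, k_ptr + v_offset]

element_size = itemsize * page_size * head_num * head_dim  # bytes per entry
assert len(keys) == len(ptrs) == layer_num * (size // page_size) * 2
```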
sglang/srt/mem_cache/mooncake_store/mooncake_store.py (new file)

```python
import hashlib
import json
import logging
import os
import uuid
from dataclasses import dataclass
from typing import Any, List, Optional

import numpy as np
import torch

from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage

DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MB

logger = logging.getLogger(__name__)


def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
    local_rank = get_tensor_model_parallel_rank()
    prefix_str = ""
    if prefix_block_key:
        if len(prefix_block_key):
            prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest()
    current_token_ids_bytes = np.array(current_page_ids).tobytes()
    current_hash_object = hashlib.sha256(current_token_ids_bytes)
    current_hash_hex = current_hash_object.hexdigest()
    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"


@dataclass
class MooncakeStoreConfig:
    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str

    @staticmethod
    def from_file() -> "MooncakeStoreConfig":
        """Load the config from a JSON file."""
        file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        with open(file_path) as fin:
            config = json.load(fin)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=config.get(
                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
            ),
            local_buffer_size=config.get(
                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
            ),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", "auto"),
            master_server_address=config.get("master_server_address"),
        )

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Load config from a file specified in the environment variable.
        export MOONCAKE_MASTER=10.13.3.232:50051
        export MOONCAKE_PROTOCOL="rdma"
        export MOONCAKE_DEVICE="auto"
        export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE"
        """
        # other required environment variables...
        if not os.getenv("MOONCAKE_MASTER"):
            raise ValueError("The environment variable 'MOONCAKE_MASTER' is not set.")
        return MooncakeStoreConfig(
            local_hostname=os.getenv("LOCAL_HOSTNAME", "localhost"),
            metadata_server=os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"),
            global_segment_size=int(
                os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
            ),
            local_buffer_size=int(
                os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
            ),
            protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
            device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
            master_server_address=os.getenv("MOONCAKE_MASTER"),
        )

    def __post_init__(self):
        if self.device_name == "auto":
            os.environ["MC_MS_AUTO_DISC"] = "1"
            os.environ["MC_MS_FILTERS"] = (
                "mlx5_bond_0, mlx5_bond_1, mlx5_bond_2, mlx5_bond_3"
            )


class MooncakeStore(HiCacheStorage):
    def __init__(self):
        try:
            from mooncake.store import MooncakeDistributedStore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://kvcache-ai.github.io/Mooncake/getting_started/build.html"
                "to run SGLang with MooncakeConnector."
            ) from e

        try:
            self.store = MooncakeDistributedStore()
            self.config = MooncakeStoreConfig.load_from_env()
            logger.info("Mooncake Configuration loaded from env successfully.")

            ret_code = self.store.setup(
                self.config.local_hostname,
                self.config.metadata_server,
                self.config.global_segment_size,
                self.config.local_buffer_size,
                self.config.protocol,
                self.config.device_name,
                self.config.master_server_address,
            )
            if ret_code:
                logger.error(f"failed to setup mooncake store, error code: {ret_code}")

            logger.info("Connect to Mooncake store successfully.")
            self.warmup()
            logger.info("Mooncake store warmup successfully.")

        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error("An error occurred while loading the configuration: %s", exc)
            raise

    def warmup(self):
        warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
        # 10 MB
        warmup_value = bytes(10 * 1024 * 1024)
        self.store.put(warmup_key, warmup_value)
        assert self.store.is_exist(warmup_key) == 1
        self.store.get(warmup_key)
        self.store.remove(warmup_key)

    def register_buffer(self, buffer: torch.Tensor) -> None:
        try:
            buffer_ptr = buffer.data_ptr()
            buffer_size = buffer.numel() * buffer.element_size()
            ret_code = self.store.register_buffer(buffer_ptr, buffer_size)
            if ret_code:
                logger.error(f"failed to register buffer, error code: {ret_code}")
        except TypeError as err:
            logger.error("Failed to register buffer to Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Register Buffer Error.") from err

    def set(
        self,
        key,
        value: Optional[Any] = None,
        target_location: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> bool:
        assert len(key) == len(target_location) == len(target_sizes)
        if len(key) == 0:
            return

        for i in range(len(key)):
            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
                return

        self._put_batch_zero_copy_impl(key, target_location, target_sizes)

    def batch_set(
        self,
        keys: List[str],
        value: Optional[Any] = None,
        target_location: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> bool:
        assert len(keys) == len(target_location) == len(target_sizes)
        if len(keys) == 0:
            return

        for i in range(len(keys)):
            if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
                return

        self._put_batch_zero_copy_impl(keys, target_location, target_sizes)

    def get(
        self,
        key,
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> torch.Tensor | None:
        assert len(key) == len(target_location) == len(target_sizes)
        if len(key) == 0:
            return

        for i in range(len(key)):
            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
                return

        return self._get_batch_zero_copy_impl(key, target_location, target_sizes)

    def batch_get(
        self,
        keys: List[str],
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> torch.Tensor | None:
        assert len(keys) == len(target_location) == len(target_sizes)
        if len(keys) == 0:
            return

        for i in range(len(keys)):
            if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
                return

        return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)

    def exists(self, keys) -> bool | dict:
        _keys = []
        local_rank = torch.cuda.current_device()
        for key in keys:
            if key is None:
                return None
            # Since mooncake store is stored in layer by layer,
            # only the first layer is checked here.
            _keys.append(f"{key}_{local_rank}_k")
        result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
        return result

    def delete(self, key) -> None:
        raise (NotImplementedError)

    def close(self):
        # MooncakeDistributedStore will automatically call the destructor, so
        # it is unnecessary to close it manually.
        pass

    def clear(self) -> None:
        raise (NotImplementedError)

    def _put_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> None:
        try:
            self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
        except TypeError as err:
            logger.error("Failed to put value to Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Put Type Error.") from err

    def _get_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> None:
        try:
            self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
        except TypeError as err:
            logger.error("Failed to get value from Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Get Type Error.") from err
```
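`MooncakeStore` is configured entirely through environment variables via `load_from_env`, with `MOONCAKE_MASTER` as the only hard requirement. A minimal sketch of loading the config, assuming sglang and its dependencies are importable (the address and sizes are placeholders for a real deployment):

```python
import os

# Placeholder values; only MOONCAKE_MASTER is mandatory, the rest fall back
# to the defaults shown in load_from_env above.
os.environ["MOONCAKE_MASTER"] = "10.0.0.1:50051"
os.environ["MOONCAKE_PROTOCOL"] = "rdma"  # default: "tcp"
os.environ["MOONCAKE_GLOBAL_SEGMENT_SIZE"] = str(4 * 1024**3)

from sglang.srt.mem_cache.mooncake_store.mooncake_store import MooncakeStoreConfig

config = MooncakeStoreConfig.load_from_env()
assert config.master_server_address == "10.0.0.1:50051"
assert config.device_name == "auto"  # triggers MC_MS_AUTO_DISC in __post_init__
```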
sglang/srt/mem_cache/mooncake_store/unit_test.py (new file)

```python
import torch
from mooncake_store import MooncakeStore


def test_init_and_warmup():
    store = MooncakeStore()
    assert store.store is not None


def test_register_buffer():
    store = MooncakeStore()
    tensor = torch.zeros(1024, dtype=torch.float32)
    store.register_buffer(tensor)


def test_set_and_get():
    store = MooncakeStore()

    key = ["test_key_" + str(i) for i in range(2)]
    tensor = torch.arange(256, dtype=torch.float32).cuda()
    ptrs = [tensor.data_ptr(), tensor.data_ptr()]
    sizes = [tensor.numel() * tensor.element_size()] * 2

    store.set(key, target_location=ptrs, target_sizes=sizes)
    store.get(key, target_location=ptrs, target_sizes=sizes)


def test_exists():
    store = MooncakeStore()
    keys = ["test_key_0", "non_existent_key"]
    result = store.exists(keys)
    assert isinstance(result, dict)
    assert "test_key_0" in result


if __name__ == "__main__":
    test_init_and_warmup()
    test_register_buffer()
    test_set_and_get()
    test_exists()
```
sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py (new file)

```python
import logging
import multiprocessing
import os
import threading
from functools import wraps
from pathlib import Path
from typing import List

import torch
from torch.utils.cpp_extension import load

root = Path(__file__).parent.resolve()
hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])

logger = logging.getLogger(__name__)

try:
    from hf3fs_fuse.io import (
        deregister_fd,
        extract_mount_point,
        make_ioring,
        make_iovec,
        register_fd,
    )
except ImportError as e:
    logger.warning(f"hf3fs_fuse.io is not available: {e}")


def rsynchronized():
    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            with self.rlock:
                return func(self, *args, **kwargs)

        return wrapper

    return _decorator


def wsynchronized():
    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            with self.wlock:
                return func(self, *args, **kwargs)

        return wrapper

    return _decorator


class Hf3fsClient:
    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
        self.path = path
        self.size = size
        self.bytes_per_page = bytes_per_page
        self.entries = entries

        self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
        os.ftruncate(self.file, size)
        register_fd(self.file)

        self.hf3fs_mount_point = extract_mount_point(path)
        self.bs = self.bytes_per_page
        self.shm_r = multiprocessing.shared_memory.SharedMemory(
            size=self.bs * self.entries, create=True
        )
        self.shm_w = multiprocessing.shared_memory.SharedMemory(
            size=self.bs * self.entries, create=True
        )

        self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8)
        self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8)

        self.numa = -1
        self.ior_r = make_ioring(
            self.hf3fs_mount_point,
            self.entries,
            for_read=True,
            timeout=1,
            numa=self.numa,
        )
        self.ior_w = make_ioring(
            self.hf3fs_mount_point,
            self.entries,
            for_read=False,
            timeout=1,
            numa=self.numa,
        )
        self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
        self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)

        self.rlock = threading.RLock()
        self.wlock = threading.RLock()

    @rsynchronized()
    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        self.check(offsets, tensors)

        # prepare
        current = 0
        for offset, tensor in zip(offsets, tensors):
            size = tensor.numel() * tensor.itemsize
            self.ior_r.prepare(
                self.iov_r[current : current + size], True, self.file, offset
            )
            current += size

        # submit
        ionum = len(offsets)
        resv = self.ior_r.submit().wait(min_results=ionum)

        # results
        hf3fs_utils.read_shm(self.shm_r_tensor, tensors)
        results = [res.result for res in resv]

        return results

    @wsynchronized()
    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        self.check(offsets, tensors)

        # prepare
        hf3fs_utils.write_shm(tensors, self.shm_w_tensor)
        current = 0
        for offset, tensor in zip(offsets, tensors):
            size = tensor.numel() * tensor.itemsize
            self.ior_w.prepare(
                self.iov_w[current : current + size], False, self.file, offset
            )
            current += size

        # submit
        ionum = len(offsets)
        resv = self.ior_w.submit().wait(min_results=ionum)

        # results
        results = [res.result for res in resv]

        return results

    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
        sizes = [t.numel() * t.itemsize for t in tensors]
        if any(
            [
                len(offsets) > self.entries,
                len(offsets) != len(sizes),
                all(
                    [
                        offset < 0 or offset + size > self.size
                        for offset, size in zip(offsets, sizes)
                    ]
                ),
                all([size > self.bytes_per_page for size in sizes]),
            ]
        ):
            self.close()
            raise ValueError(f"Hf3fsClient.check: {offsets=}, {sizes=}")

    def get_size(self) -> int:
        return self.size

    def close(self) -> None:
        deregister_fd(self.file)
        os.close(self.file)
        del self.ior_r
        del self.ior_w
        del self.iov_r
        del self.iov_w
        self.shm_r.close()
        self.shm_w.close()
        self.shm_r.unlink()
        self.shm_w.unlink()

    def flush(self) -> None:
        os.fsync(self.file)
```