sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +39 -674
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
- sglang/srt/layers/quantization/fp8.py +52 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +165 -67
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +90 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +60 -17
- sglang/srt/mem_cache/hiradix_cache.py +36 -8
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +418 -29
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/glm4_moe.py +6 -4
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +991 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +49 -18
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,264 @@
|
|
1
|
+
import hashlib
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import uuid
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from typing import Any, List, Optional
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
import torch
|
11
|
+
|
12
|
+
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
13
|
+
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
14
|
+
|
15
|
+
# Fallback for `MooncakeStoreConfig.global_segment_size` when neither the JSON
# config file nor MOONCAKE_GLOBAL_SEGMENT_SIZE provides a value.
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024**3  # 4 GiB
# Fallback for `MooncakeStoreConfig.local_buffer_size` when neither the JSON
# config file nor MOONCAKE_LOCAL_BUFFER_SIZE provides a value.
DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024**2  # 128 MiB

logger = logging.getLogger(__name__)
|
19
|
+
|
20
|
+
|
21
|
+
def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
    """Build the Mooncake store key for one page of token ids.

    The key has the form ``"{prefix}_{page_hash}_{tp_rank}"`` where:

    * ``prefix`` is the SHA-256 hex digest of ``prefix_block_key``, or an
      empty string when ``prefix_block_key`` is ``None`` or empty;
    * ``page_hash`` is the first 64 bits of the SHA-256 digest of the raw
      bytes of ``np.array(current_page_ids)``, written as a decimal integer;
    * ``tp_rank`` is this process's tensor-model-parallel rank.

    Args:
        current_page_ids: Token ids of the current page.
        prefix_block_key: Key of the preceding block; may be ``None`` or ``""``.

    Returns:
        The key string.
    """
    local_rank = get_tensor_model_parallel_rank()
    # One truthiness test covers both None and the empty string.
    prefix_str = (
        hashlib.sha256(prefix_block_key.encode()).hexdigest()
        if prefix_block_key
        else ""
    )
    # NOTE(review): the hashed bytes depend on the dtype NumPy infers for
    # current_page_ids, which can differ across platforms.
    page_bytes = np.array(current_page_ids).tobytes()
    page_hash_hex = hashlib.sha256(page_bytes).hexdigest()
    return f"{prefix_str}_{int(page_hash_hex[:16], 16)}_{local_rank}"
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass
class MooncakeStoreConfig:
    """Connection settings handed to ``MooncakeDistributedStore.setup``.

    Usually built with :meth:`load_from_env` or :meth:`from_file` rather than
    constructed directly.
    """

    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str

    @staticmethod
    def from_file() -> "MooncakeStoreConfig":
        """Load the config from the JSON file named by ``MOONCAKE_CONFIG_PATH``.

        Optional fields fall back to ``protocol="tcp"``, ``device_name="auto"``
        and the module's default segment / buffer sizes.

        Raises:
            ValueError: If ``MOONCAKE_CONFIG_PATH`` is unset, the file lacks
                ``master_server_address``, or the file is not valid JSON.
            OSError: If the file cannot be opened.
        """
        file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        with open(file_path, encoding="utf-8") as fin:
            config = json.load(fin)
        # Same requirement as load_from_env, which insists on MOONCAKE_MASTER.
        master_server_address = config.get("master_server_address")
        if not master_server_address:
            raise ValueError(
                "'master_server_address' is missing from the Mooncake config "
                f"file {file_path}."
            )
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            # int() accepts sizes written as JSON numbers or strings, matching
            # the parsing done in load_from_env.
            global_segment_size=int(
                config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE)
            ),
            local_buffer_size=int(
                config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)
            ),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", "auto"),
            master_server_address=master_server_address,
        )

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Load the config from environment variables.

        ``MOONCAKE_MASTER`` is required; the rest are optional (default in
        parentheses):

        * ``LOCAL_HOSTNAME`` (``"localhost"``)
        * ``MOONCAKE_TE_META_DATA_SERVER`` (``"P2PHANDSHAKE"``)
        * ``MOONCAKE_GLOBAL_SEGMENT_SIZE`` (``DEFAULT_GLOBAL_SEGMENT_SIZE``)
        * ``MOONCAKE_LOCAL_BUFFER_SIZE`` (``DEFAULT_LOCAL_BUFFER_SIZE``)
        * ``MOONCAKE_PROTOCOL`` (``"tcp"``)
        * ``MOONCAKE_DEVICE`` (``"auto"``)

        Example::

            export MOONCAKE_MASTER=10.13.3.232:50051
            export MOONCAKE_PROTOCOL="rdma"
            export MOONCAKE_DEVICE="auto"
            export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE"

        Raises:
            ValueError: If ``MOONCAKE_MASTER`` is unset or empty, or a size
                variable is not an integer.
        """
        if not os.getenv("MOONCAKE_MASTER"):
            raise ValueError("The environment variable 'MOONCAKE_MASTER' is not set.")
        return MooncakeStoreConfig(
            local_hostname=os.getenv("LOCAL_HOSTNAME", "localhost"),
            metadata_server=os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"),
            global_segment_size=int(
                os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
            ),
            local_buffer_size=int(
                os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
            ),
            protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
            device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
            master_server_address=os.getenv("MOONCAKE_MASTER"),
        )

    def __post_init__(self):
        # When device_name is "auto", set MC_MS_AUTO_DISC and MC_MS_FILTERS in
        # this process's environment, overwriting any existing values.
        # NOTE(review): the names suggest Mooncake device auto-discovery limited
        # to the listed mlx5 bond devices — confirm against Mooncake docs.
        if self.device_name == "auto":
            os.environ["MC_MS_AUTO_DISC"] = "1"
            os.environ["MC_MS_FILTERS"] = (
                "mlx5_bond_0, mlx5_bond_1, mlx5_bond_2, mlx5_bond_3"
            )
|
98
|
+
|
99
|
+
|
100
|
+
class MooncakeStore(HiCacheStorage):
    """HiCache storage backend built on ``MooncakeDistributedStore``.

    All data movement goes through the store's zero-copy batch calls
    (``batch_put_from`` / ``batch_get_into``). Callers pass a list of keys
    plus, per key, a raw buffer address (``target_location``) and a byte count
    (``target_sizes``). The ``value`` argument of ``set``/``batch_set`` is
    accepted for interface compatibility but ignored.

    NOTE(review): presumably those buffers must first be registered via
    :meth:`register_buffer` — confirm against the Mooncake store docs.
    """

    def __init__(self):
        """Connect to the Mooncake store configured by ``MOONCAKE_*`` env vars.

        Raises:
            ImportError: If the ``mooncake`` package is not installed.
            ValueError: If the environment configuration is invalid.
            RuntimeError: If ``setup`` returns a non-zero code or warmup fails.
        """
        try:
            from mooncake.store import MooncakeDistributedStore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://kvcache-ai.github.io/Mooncake/getting_started/build.html "
                "to run SGLang with MooncakeConnector."
            ) from e

        try:
            self.store = MooncakeDistributedStore()
            self.config = MooncakeStoreConfig.load_from_env()
            logger.info("Mooncake Configuration loaded from env successfully.")

            ret_code = self.store.setup(
                self.config.local_hostname,
                self.config.metadata_server,
                self.config.global_segment_size,
                self.config.local_buffer_size,
                self.config.protocol,
                self.config.device_name,
                self.config.master_server_address,
            )
            if ret_code:
                # Fail fast: continuing would run warmup and report a successful
                # connection for a store whose setup failed.
                raise RuntimeError(
                    f"failed to setup mooncake store, error code: {ret_code}"
                )

            logger.info("Connect to Mooncake store successfully.")
            self.warmup()
            logger.info("Mooncake store warmup successfully.")

        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error("Failed to initialize Mooncake store: %s", exc)
            raise

    def warmup(self):
        """Round-trip a 10 MiB zero-filled value through the store.

        Puts a uniquely named key, checks it exists, reads it back and removes it.

        Raises:
            RuntimeError: If the key is not reported as existing after ``put``.
        """
        warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
        warmup_value = bytes(10 * 1024 * 1024)  # 10 MiB of zero bytes
        self.store.put(warmup_key, warmup_value)
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if self.store.is_exist(warmup_key) != 1:
            raise RuntimeError(
                f"Mooncake store warmup failed: key {warmup_key} not found after put."
            )
        self.store.get(warmup_key)
        self.store.remove(warmup_key)

    def register_buffer(self, buffer: torch.Tensor) -> None:
        """Register the memory backing ``buffer`` with the store.

        A non-zero return code from the store is logged, not raised.

        Raises:
            TypeError: If the store rejects the pointer/size arguments.
        """
        try:
            buffer_ptr = buffer.data_ptr()
            buffer_size = buffer.numel() * buffer.element_size()
            ret_code = self.store.register_buffer(buffer_ptr, buffer_size)
            if ret_code:
                logger.error(f"failed to register buffer, error code: {ret_code}")
        except TypeError as err:
            logger.error("Failed to register buffer to Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Register Buffer Error.") from err

    @staticmethod
    def _is_valid_batch(keys, target_location, target_sizes) -> bool:
        """Validate a zero-copy batch request.

        Returns:
            False when the batch is empty or any key, address, or size is
            ``None`` (the caller then skips the whole batch); True otherwise.

        Raises:
            ValueError: If ``target_location``/``target_sizes`` are missing or
                the three lists differ in length.
        """
        if target_location is None or target_sizes is None:
            raise ValueError("MooncakeStore requires target_location and target_sizes.")
        if not len(keys) == len(target_location) == len(target_sizes):
            raise ValueError(
                "keys, target_location and target_sizes must have equal lengths, "
                f"got {len(keys)}, {len(target_location)}, {len(target_sizes)}."
            )
        return len(keys) > 0 and not any(
            k is None or loc is None or size is None
            for k, loc, size in zip(keys, target_location, target_sizes)
        )

    def set(
        self,
        key,
        value: Optional[Any] = None,
        target_location: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> bool:
        """Write a batch of keys from raw buffers.

        ``key`` is a list of keys. This method behaves exactly like
        :meth:`batch_set` with the same arguments.
        """
        return self.batch_set(key, value, target_location, target_sizes)

    def batch_set(
        self,
        keys: List[str],
        value: Optional[Any] = None,
        target_location: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> bool:
        """Store ``target_sizes[i]`` bytes at address ``target_location[i]`` as ``keys[i]``.

        Nothing is written if the batch is empty or any key, address, or size
        is ``None``.

        NOTE(review): annotated ``-> bool`` but always returns ``None``; the
        per-key results of ``batch_put_from`` are not inspected.

        Raises:
            ValueError: If ``target_location``/``target_sizes`` are missing or
                the three lists differ in length.
            TypeError: If the store rejects the argument types.
        """
        if not self._is_valid_batch(keys, target_location, target_sizes):
            return
        self._put_batch_zero_copy_impl(keys, target_location, target_sizes)

    def get(
        self,
        key,
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Read a batch of keys into raw buffers.

        ``key`` is a list of keys. This method behaves exactly like
        :meth:`batch_get` with the same arguments.
        """
        return self.batch_get(key, target_location, target_sizes)

    def batch_get(
        self,
        keys: List[str],
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Read ``keys[i]`` into ``target_sizes[i]`` bytes at address ``target_location[i]``.

        Nothing is read if the batch is empty or any key, address, or size is
        ``None``. Data lands directly in the target buffers; the return value
        is always ``None``.

        Raises:
            ValueError: If ``target_location``/``target_sizes`` are missing or
                the three lists differ in length.
            TypeError: If the store rejects the argument types.
        """
        if not self._is_valid_batch(keys, target_location, target_sizes):
            return
        return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)

    def exists(self, keys) -> bool | dict:
        """Report which page keys are present in the store.

        Only one per-layer entry per key is probed, on the assumption (per the
        original comment) that all layers of a page are stored together.

        Returns:
            ``None`` if any key is ``None``; otherwise a dict mapping each input
            key to the corresponding ``batch_is_exist`` result.
        """
        # NOTE(review): the probe uses torch.cuda.current_device() as the middle
        # key component, yet the intent is to check the *first layer*. Confirm
        # this matches how per-layer keys are built when pages are written.
        local_rank = torch.cuda.current_device()
        probe_keys = []
        for key in keys:
            if key is None:
                return None
            probe_keys.append(f"{key}_{local_rank}_k")
        return dict(zip(keys, self.store.batch_is_exist(probe_keys)))

    def delete(self, key) -> None:
        """Not supported by this backend."""
        raise NotImplementedError

    def close(self):
        # MooncakeDistributedStore will automatically call the destructor, so
        # it is unnecessary to close it manually.
        pass

    def clear(self) -> None:
        """Not supported by this backend."""
        raise NotImplementedError

    def _put_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> None:
        """Call ``batch_put_from``, re-raising argument type errors with context."""
        try:
            self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
        except TypeError as err:
            logger.error("Failed to put value to Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Put Type Error.") from err

    def _get_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> None:
        """Call ``batch_get_into``, re-raising argument type errors with context."""
        try:
            self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
        except TypeError as err:
            logger.error("Failed to get value from Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Get Type Error.") from err
|
@@ -0,0 +1,40 @@
|
|
1
|
+
import torch
|
2
|
+
from mooncake_store import MooncakeStore
|
3
|
+
|
4
|
+
|
5
|
+
def test_init_and_warmup():
    """Constructing a MooncakeStore (which runs warmup) yields a live store handle."""
    mooncake = MooncakeStore()
    assert mooncake.store is not None
|
8
|
+
|
9
|
+
|
10
|
+
def test_register_buffer():
    """Registering a small CPU float32 tensor completes without raising."""
    mooncake = MooncakeStore()
    cpu_buffer = torch.zeros(1024, dtype=torch.float32)
    mooncake.register_buffer(cpu_buffer)
|
14
|
+
|
15
|
+
|
16
|
+
def test_set_and_get():
    """Round-trip two pages through the store and check the data survives.

    Values are written from ``src`` and read back into separate zeroed
    buffers. The comparison therefore fails if ``get`` does not actually fill
    them; the original version read back into the source buffer and asserted
    nothing.
    """
    store = MooncakeStore()

    keys = ["test_key_" + str(i) for i in range(2)]
    src = [torch.arange(256, dtype=torch.float32).cuda() for _ in keys]
    dst = [torch.zeros_like(t) for t in src]
    # The zero-copy batch APIs address these buffers by raw pointer, so
    # register them with the store first.
    for buf in src + dst:
        store.register_buffer(buf)

    sizes = [t.numel() * t.element_size() for t in src]
    store.set(keys, target_location=[t.data_ptr() for t in src], target_sizes=sizes)
    store.get(keys, target_location=[t.data_ptr() for t in dst], target_sizes=sizes)

    for expected, actual in zip(src, dst):
        assert torch.equal(expected, actual)
|
26
|
+
|
27
|
+
|
28
|
+
def test_exists():
    """exists() returns a dict keyed by the queried keys."""
    mooncake = MooncakeStore()
    queried = ["test_key_0", "non_existent_key"]
    found = mooncake.exists(queried)
    assert isinstance(found, dict)
    assert "test_key_0" in found
|
34
|
+
|
35
|
+
|
36
|
+
if __name__ == "__main__":
    # Run every check in order when executed as a script.
    for check in (
        test_init_and_warmup,
        test_register_buffer,
        test_set_and_get,
        test_exists,
    ):
        check()
|
@@ -0,0 +1,163 @@
|
|
1
|
+
import hashlib
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
import time
|
5
|
+
import uuid
|
6
|
+
from typing import Dict, List, Optional, Tuple, Union
|
7
|
+
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
11
|
+
|
12
|
+
from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
|
13
|
+
|
14
|
+
try:
|
15
|
+
from nixl._api import nixl_agent, nixl_agent_config
|
16
|
+
except ImportError as e:
|
17
|
+
raise ImportError(
|
18
|
+
"Please install NIXL by following the instructions at "
|
19
|
+
"https://github.com/ai-dynamo/nixl/blob/main/README.md "
|
20
|
+
"to use HiCacheNixl storage backend."
|
21
|
+
) from e
|
22
|
+
|
23
|
+
# Module-level logger for the HiCacheNixl storage backend.
logger = logging.getLogger(__name__)
|
24
|
+
|
25
|
+
|
26
|
+
class HiCacheNixl(HiCacheStorage):
    """HiCacheNixl provides high-performance storage using NIXL plugins.

    For file-based plugins (``mem_type == "FILE"``), each key is mapped to a
    path by ``NixlFileManager``. For object plugins, the key itself names the
    stored object.
    """

    def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
        """Initialize NIXL storage connector.

        Args:
            file_path: Base path handed to ``NixlFileManager``. No file manager
                is created when ``plugin`` is in
                ``NixlBackendSelection.OBJ_PLUGINS``.
            plugin: Plugin name passed to ``NixlBackendSelection``.

        Raises:
            RuntimeError: If the NIXL backend cannot be created.
        """
        self.file_manager = (
            NixlFileManager(file_path)
            if plugin not in NixlBackendSelection.OBJ_PLUGINS
            else None
        )

        agent_config = nixl_agent_config(backends=[])
        # Agent name is unique per instance (uuid4 suffix).
        self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
        self.agent = nixl_agent(self.agent_name, agent_config)

        self.backend_selector = NixlBackendSelection(plugin)
        if not self.backend_selector.create_backend(self.agent):
            raise RuntimeError("Failed to create NIXL backend")

        self.registration = NixlRegistration(self.agent)

    def _execute_transfer(
        self, tensors: List[torch.Tensor], keys: List[str], direction: str
    ) -> bool:
        """Move data between ``tensors`` and the storage entries named by ``keys``.

        Args:
            tensors: Local buffers; ``tensors[i]`` pairs with ``keys[i]``.
            keys: File paths (FILE backends) or object keys (object backends).
            direction: ``"WRITE"`` or ``"READ"``, passed to ``initialize_xfer``.

        Returns:
            True once the transfer state reaches ``"DONE"``; False on any
            failure. Errors are logged, not raised.
        """
        if len(tensors) != len(keys):
            logger.error("Mismatch between number of tensors and files/objects")
            return False

        if not self.registration.register_buffers(tensors):
            logger.error("Failed to register tensors")
            return False

        # Get transfer tuples based on backend type
        tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
        if self.backend_selector.mem_type == "FILE":
            file_tuples = self.file_manager.files_to_nixl_tuples(keys)
            if not file_tuples or not self.registration.register_files(file_tuples):
                logger.error("Failed to prepare files for transfer")
                return False
            # Keep each file tuple's first and third fields; the transfer length
            # is the paired tensor's byte size.
            transfer_tuples = [
                (x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
            ]
        else:
            if not self.registration.register_objects(keys, tensors):
                logger.error("Failed to register objects")
                return False
            transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]

        xfer_req = None
        try:
            # Get transfer descriptors
            if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
                file_descs := self.agent.get_xfer_descs(
                    transfer_tuples, self.backend_selector.mem_type
                )
            ) is None:
                logger.error("Failed to get transfer descriptors")
                return False

            # Initialize and execute transfer
            if (
                xfer_req := self.agent.initialize_xfer(
                    direction, tensor_descs, file_descs, self.agent_name
                )
            ) is None:
                logger.error("Failed to create transfer request")
                return False

            state = self.agent.transfer(xfer_req)
            # transfer() may already report failure; don't start polling then.
            if state == "ERR":
                logger.error("Transfer failed")
                return False
            while state != "DONE":
                state = self.agent.check_xfer_state(xfer_req)
                if state == "ERR":
                    logger.error("Transfer failed")
                    return False
                time.sleep(0.0001)  # Can be changed to os.sched_yield() or parametrized
            return True

        except Exception as e:
            # logger.exception attaches the traceback to the log record.
            logger.exception(f"Failed to execute transfer: {e}")
            return False
        finally:
            # Release the handle on every exit path so handles don't accumulate.
            if xfer_req is not None:
                self._release_xfer(xfer_req)

    def _release_xfer(self, xfer_req) -> None:
        """Best-effort release of a NIXL transfer handle; failures are only logged."""
        try:
            self.agent.release_xfer_handle(xfer_req)
        except Exception as e:
            logger.warning(f"Failed to release NIXL transfer handle: {e}")

    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
        """Write ``values[i]`` under ``keys[i]``.

        For FILE backends, a file is created for each key before the transfer.

        Returns:
            True on success or for an empty ``keys``; False otherwise.
        """
        if not keys:
            return True

        if self.backend_selector.mem_type == "FILE":
            file_paths = []
            for key in keys:
                tensor_path = self.file_manager.get_file_path(key)
                if not self.file_manager.create_file(tensor_path):
                    logger.error(f"Failed to create file {tensor_path}")
                    return False
                file_paths.append(tensor_path)
            return self._execute_transfer(values, file_paths, "WRITE")
        else:
            return self._execute_transfer(values, keys, "WRITE")

    def set(self, key: str, value: torch.Tensor) -> bool:
        """Write a single tensor under ``key``; see :meth:`batch_set`."""
        return self.batch_set([key], [value])

    def get(
        self, key: str, dst_tensor: Optional[torch.Tensor] = None
    ) -> torch.Tensor | None:
        """Read ``key`` into ``dst_tensor``.

        Returns:
            ``dst_tensor`` on success; ``None`` on failure or when no
            destination tensor is given.
        """
        if dst_tensor is None:  # To be removed, being compatible with the current API
            return None
        result = self.batch_get([key], [dst_tensor])
        return result[0] if result else None

    def batch_get(
        self, keys: List[str], dst_tensors: List[torch.Tensor]
    ) -> List[Optional[torch.Tensor]]:
        """Read ``keys[i]`` into ``dst_tensors[i]``.

        Returns:
            ``dst_tensors`` if the transfer succeeded, otherwise a list of
            ``None`` of the same length as ``keys`` (``[]`` for no keys).
        """
        if not keys:
            return []

        if self.backend_selector.mem_type == "FILE":
            file_paths = [self.file_manager.get_file_path(key) for key in keys]
            success = self._execute_transfer(dst_tensors, file_paths, "READ")
        else:
            success = self._execute_transfer(dst_tensors, keys, "READ")
        return dst_tensors if success else [None] * len(keys)

    def exists(self, key: str) -> bool:
        """Return True if the backend reports an entry for ``key``."""
        tuples = self.registration.create_query_tuples(
            key,
            self.backend_selector.mem_type,
            self.file_manager if self.backend_selector.mem_type == "FILE" else None,
        )
        if not tuples:
            return False

        query_res = self.agent.query_memory(
            tuples,
            self.backend_selector.backend_name,
            mem_type=self.backend_selector.mem_type,
        )
        # Guard against an empty result before indexing; only one key is queried.
        return bool(query_res) and query_res[0] is not None
|