sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +20 -0
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +29 -68
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +57 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +43 -39
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +145 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -12
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/model_executor/cuda_graph_runner.py +42 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +28 -23
- sglang/srt/models/glm4_moe.py +85 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -4
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +115 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +4 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/radix_cache_cpp.py (new file)
@@ -0,0 +1,229 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, List, Set
+
+import torch
+
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
+    IOHandle,
+    RadixTreeCpp,
+    TreeNodeCpp,
+)
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
+
+logger = logging.getLogger(__name__)
+
+
+class RadixCacheCpp(BasePrefixCache):
+    def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor:
+        """
+        Merge a list of tensors into a single tensor.
+        Args:
+            l (List[torch.Tensor]): List of tensors to merge.
+        Returns:
+            torch.Tensor: Merged tensor.
+        """
+        if len(l) == 0:
+            return torch.empty(0, dtype=torch.int64, device=self.device)
+        elif len(l) == 1:
+            return l[0]
+        else:
+            return torch.cat(l)
+
+    def __init__(
+        self,
+        disable: bool,
+        use_hicache: bool,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: BaseTokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
+        page_size: int,
+        hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
+        enable_kv_cache_events: bool = False,
+        hicache_oracle: bool = False,
+        enable_write_cancel: bool = False,
+    ):
+        self.disable = disable
+        self.enable_write_cancel = enable_write_cancel
+
+        assert (
+            enable_kv_cache_events is False
+        ), "HiRadixCache does not support kv cache events yet"
+        self.kv_cache = token_to_kv_pool.get_kvcache()
+
+        # record the nodes with ongoing write through
+        self.ongoing_write_through: Set[IOHandle] = set()
+        # record the node segments with ongoing load back
+        self.ongoing_load_back: Set[IOHandle] = set()
+        # todo: dynamically adjust the threshold
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 2
+        )
+        self.device = token_to_kv_pool.device
+        self.token_to_kv_pool = token_to_kv_pool
+        self.req_to_token_pool = req_to_token_pool
+        self.page_size = page_size
+
+        self.tp_group = tp_cache_group
+
+        if not use_hicache:
+            self.tree = RadixTreeCpp(
+                disabled=self.disable,
+                page_size=page_size,
+                host_size=None,  # no host cache, this should be removed in the future
+                write_through_threshold=self.write_through_threshold,
+            )
+            self.cache_controller = None
+            return  # early return if hicache is not used
+
+        raise NotImplementedError("Host cache is not supported yet")
+
+    def reset(self):
+        if self.cache_controller is not None:
+            # need to clear the acks before resetting the cache controller
+            raise NotImplementedError("Host cache is not supported yet")
+        self.tree.reset()
+
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
+        device_indices_vec, host_indices_length, node_gpu, node_cpu = (
+            self.tree.match_prefix(key)
+        )
+        return MatchResult(
+            device_indices=self._merge_tensor(device_indices_vec),
+            last_device_node=node_gpu,
+            last_host_node=node_cpu,
+            host_hit_length=host_indices_length,
+        )
+
+    def _insert(self, key: List[int], value: torch.Tensor) -> int:
+        """
+        Insert a key-value pair into the radix tree.
+        Args:
+            key (List[int]): The key to insert, represented as a list of integers.
+            value (torch.Tensor): The value to associate with the key.
+        Returns:
+            int: Number of device indices that were already present in the tree before the insertion.
+        """
+        ongoing_write, length = self.tree.writing_through(key, value)
+        if self.cache_controller is None:
+            assert len(ongoing_write) == 0, "Implementation error"
+            return length
+
+        raise NotImplementedError("Host cache is not supported yet")
+
+    def dec_lock_ref(self, node: TreeNodeCpp):
+        """
+        Decrement the reference count of a node to root of the radix tree.
+        Args:
+            node (TreeNodeCpp): The handle of the node to decrement the reference count for.
+        """
+        self.tree.lock_ref(node, False)  # do not increment
+
+    def inc_lock_ref(self, node: TreeNodeCpp):
+        """
+        Increment the reference count of from a node to root of the radix tree.
+        Args:
+            node (TreeNodeCpp): The handle of the node to increment the reference count for.
+        """
+        self.tree.lock_ref(node, True)
+
+    def evict(self, num_tokens: int):
+        evicted_device_indices = self.tree.evict(num_tokens)
+        for indice in evicted_device_indices:
+            self.token_to_kv_pool.free(indice)
+
+    def evictable_size(self):
+        return self.tree.evictable_size()
+
+    def protected_size(self):
+        return self.tree.protected_size()
+
+    def total_size(self):
+        return self.tree.total_size()
+
+    def cache_finished_req(self, req: Req):
+        """Cache request when it finishes."""
+        assert req.req_pool_idx is not None
+        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        overall_len = len(token_ids)  # prefill + decode
+        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
+
+        # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
+        # it will automatically align them, but length of them should be equal
+        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
+        new_prefix_len = self._insert(token_ids, kv_indices)
+
+        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+
+        # KVCache between old & new is newly generated, but already exists in the pool
+        # we need to free this newly generated kv indices
+        if old_prefix_len < new_prefix_len:
+            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+
+        # need to free the unaligned part, since it cannot be inserted into the radix tree
+        if self.page_size != 1 and (  # unaligned tail only exists when page_size > 1
+            (unaligned_len := overall_len % self.page_size) > 0
+        ):
+            # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
+            self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
+
+        # Remove req slot release the cache lock
+        self.dec_lock_ref(req.last_node)
+        self.req_to_token_pool.free(req.req_pool_idx)
+
+    def cache_unfinished_req(self, req: Req):
+        """Cache request when it is unfinished."""
+        assert req.req_pool_idx is not None
+        token_ids = req.fill_ids
+        prefill_len = len(token_ids)  # prefill only (maybe chunked)
+        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefill_len]
+
+        # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
+        # it will automatically align them, but length of them should be equal
+        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
+        new_prefix_len = self._insert(token_ids, kv_indices)
+
+        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+
+        # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
+        # The prefix indices need to updated to reuse the kv indices in the pool
+        new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
+        new_indices = self._merge_tensor(new_indices_vec)
+        assert new_prefix_len <= len(new_indices)
+
+        # KVCache between old & new is newly generated, but already exists in the pool
+        # we need to free this newly generated kv indices and reuse the indices in the pool
+        if old_prefix_len < new_prefix_len:
+            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+            reused_indices = new_indices[old_prefix_len:new_prefix_len]
+            self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, old_prefix_len:new_prefix_len
+            ] = reused_indices
+
+        if req.last_node != new_last_node:
+            self.dec_lock_ref(req.last_node)
+            self.inc_lock_ref(new_last_node)
+
+        # NOTE: there might be unaligned tail, so we may need to append it
+        assert len(new_indices) <= prefill_len < len(new_indices) + self.page_size
+        if self.page_size != 1 and len(new_indices) < prefill_len:
+            req.prefix_indices = torch.cat(
+                [new_indices, kv_indices[len(new_indices) :]]
+            )
+        else:
+            req.prefix_indices = new_indices
+        req.last_node = new_last_node
+
+    def pretty_print(self):
+        return self.tree.debug_print()
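To make the page-alignment bookkeeping in cache_finished_req concrete, here is a minimal standalone sketch of the same arithmetic. PAGE_SIZE, the helper name, and the example numbers are illustrative, not sglang API:

# A minimal sketch of the alignment arithmetic in cache_finished_req.
PAGE_SIZE = 4

def alignment_plan(prefix_len: int, overall_len: int, new_prefix_len: int):
    # The old prefix is floored to a page boundary before comparison.
    old_prefix_len = prefix_len // PAGE_SIZE * PAGE_SIZE
    assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
    # KV slots in [old_prefix_len, new_prefix_len) duplicate entries the
    # tree already owns, so they are freed back to the pool.
    overlap = (old_prefix_len, new_prefix_len) if old_prefix_len < new_prefix_len else None
    # A tail shorter than one page cannot live in the tree and is freed too.
    unaligned_len = overall_len % PAGE_SIZE
    tail = (overall_len - unaligned_len, overall_len) if unaligned_len else None
    return overlap, tail

# 6 cached prefix tokens, 13 tokens total, tree holds the first 12 after insert:
print(alignment_plan(6, 13, 12))  # -> ((4, 12), (12, 13))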
sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
@@ -14,6 +14,7 @@ hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
 
 logger = logging.getLogger(__name__)
 
+HF3FS_AVAILABLE = True
 try:
     from hf3fs_fuse.io import (
         deregister_fd,
@@ -22,8 +23,8 @@ try:
         make_iovec,
         register_fd,
     )
-except ImportError
-
+except ImportError:
+    HF3FS_AVAILABLE = False
 
 
 def rsynchronized():
@@ -52,6 +53,11 @@ def wsynchronized():
 
 class Hf3fsClient:
     def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        if not HF3FS_AVAILABLE:
+            raise ImportError(
+                "hf3fs_fuse.io is not available. Please install the hf3fs_fuse package."
+            )
+
         self.path = path
         self.size = size
         self.bytes_per_page = bytes_per_page
sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp (new file)
@@ -0,0 +1,35 @@
+#include <torch/extension.h>
+
+#include <cstring>
+#include <vector>
+
+void read_shm(const torch::Tensor &shm, std::vector<torch::Tensor> dst) {
+  py::gil_scoped_release release;
+  char *src_ptr = static_cast<char *>(shm.data_ptr());
+  size_t current = 0;
+  for (size_t i = 0; i < dst.size(); ++i) {
+    auto &t = dst[i];
+    size_t t_bytes = t.numel() * t.element_size();
+    char *dst_ptr = static_cast<char *>(t.data_ptr());
+    std::memcpy(dst_ptr, src_ptr + current, t_bytes);
+    current += t_bytes;
+  }
+}
+
+void write_shm(const std::vector<torch::Tensor> src, torch::Tensor &shm) {
+  py::gil_scoped_release release;
+  char *dst_ptr = static_cast<char *>(shm.data_ptr());
+  size_t current = 0;
+  for (size_t i = 0; i < src.size(); ++i) {
+    auto &t = src[i];
+    size_t t_bytes = t.numel() * t.element_size();
+    char *src_ptr = static_cast<char *>(t.data_ptr());
+    std::memcpy(dst_ptr + current, src_ptr, t_bytes);
+    current += t_bytes;
+  }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("read_shm", &read_shm, "Read tensors from shared memory");
+  m.def("write_shm", &write_shm, "Write tensors to shared memory");
+}
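For reasoning about the byte layout, here is a rough Python equivalent of write_shm/read_shm above. It is a sketch, not the extension's actual binding, and it assumes contiguous tensors and a uint8 shared-memory buffer:

import torch

def write_shm_py(src: list, shm: torch.Tensor) -> None:
    # Pack each tensor's raw bytes back-to-back into the flat buffer.
    flat = shm.view(torch.uint8)
    cur = 0
    for t in src:
        n = t.numel() * t.element_size()
        flat[cur : cur + n] = t.contiguous().view(-1).view(torch.uint8)
        cur += n

def read_shm_py(shm: torch.Tensor, dst: list) -> None:
    # Unpack bytes in the same order, sized by each destination tensor.
    flat = shm.view(torch.uint8)
    cur = 0
    for t in dst:
        n = t.numel() * t.element_size()
        t.view(-1).view(torch.uint8).copy_(flat[cur : cur + n])
        cur += n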
sglang/srt/mem_cache/storage/nixl/hicache_nixl.py (new file)
@@ -0,0 +1,163 @@
+import hashlib
+import logging
+import os
+import time
+import uuid
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+
+from .nixl_utils import NixlBackendSelection, NixlFileManager, NixlRegistration
+
+try:
+    from nixl._api import nixl_agent, nixl_agent_config
+except ImportError as e:
+    raise ImportError(
+        "Please install NIXL by following the instructions at "
+        "https://github.com/ai-dynamo/nixl/blob/main/README.md "
+        "to use HiCacheNixl storage backend."
+    ) from e
+
+logger = logging.getLogger(__name__)
+
+
+class HiCacheNixl(HiCacheStorage):
+    """HiCacheNixl provides high-performance storage using NIXL plugins."""
+
+    def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
+        """Initialize NIXL storage connector."""
+        self.file_manager = (
+            NixlFileManager(file_path)
+            if plugin not in NixlBackendSelection.OBJ_PLUGINS
+            else None
+        )
+
+        agent_config = nixl_agent_config(backends=[])
+        self.agent_name = f"hicache_nixl_{str(uuid.uuid4())}"
+        self.agent = nixl_agent(self.agent_name, agent_config)
+
+        self.backend_selector = NixlBackendSelection(plugin)
+        if not self.backend_selector.create_backend(self.agent):
+            raise RuntimeError("Failed to create NIXL backend")
+
+        self.registration = NixlRegistration(self.agent)
+
+    def _execute_transfer(
+        self, tensors: List[torch.Tensor], keys: List[str], direction: str
+    ) -> bool:
+        if len(tensors) != len(keys):
+            logger.error("Mismatch between number of tensors and files/objects")
+            return False
+
+        if not self.registration.register_buffers(tensors):
+            logger.error("Failed to register tensors")
+            return False
+
+        # Get transfer tuples based on backend type
+        tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
+        if self.backend_selector.mem_type == "FILE":
+            file_tuples = self.file_manager.files_to_nixl_tuples(keys)
+            if not file_tuples or not self.registration.register_files(file_tuples):
+                logger.error("Failed to prepare files for transfer")
+                return False
+            transfer_tuples = [
+                (x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
+            ]
+        else:
+            if not self.registration.register_objects(keys, tensors):
+                logger.error("Failed to register objects")
+                return False
+            transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
+
+        try:
+            # Get transfer descriptors
+            if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
+                file_descs := self.agent.get_xfer_descs(
+                    transfer_tuples, self.backend_selector.mem_type
+                )
+            ) is None:
+                logger.error("Failed to get transfer descriptors")
+                return False
+
+            # Initialize and execute transfer
+            if (
+                xfer_req := self.agent.initialize_xfer(
+                    direction, tensor_descs, file_descs, self.agent_name
+                )
+            ) is None:
+                logger.error("Failed to create transfer request")
+                return False
+
+            state = self.agent.transfer(xfer_req)
+            while state != "DONE":
+                state = self.agent.check_xfer_state(xfer_req)
+                if state == "ERR":
+                    logger.error("Transfer failed")
+                    return False
+                time.sleep(0.0001)  # Can be changed to os.sched_yield() or parametrized
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to execute transfer: {e}")
+            import traceback
+
+            logger.error(f"Traceback: {traceback.format_exc()}")
+            return False
+
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        if not keys:
+            return True
+
+        if self.backend_selector.mem_type == "FILE":
+            file_paths = []
+            for key in keys:
+                tensor_path = self.file_manager.get_file_path(key)
+                if not self.file_manager.create_file(tensor_path):
+                    logger.error(f"Failed to create file {tensor_path}")
+                    return False
+                file_paths.append(tensor_path)
+            return self._execute_transfer(values, file_paths, "WRITE")
+        else:
+            return self._execute_transfer(values, keys, "WRITE")
+
+    def set(self, key: str, value: torch.Tensor) -> bool:
+        return self.batch_set([key], [value])
+
+    def get(
+        self, key: str, dst_tensor: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        if dst_tensor is None:  # To be removed, being compatible with the current API
+            return None
+        result = self.batch_get([key], [dst_tensor])
+        return result[0] if result else None
+
+    def batch_get(
+        self, keys: List[str], dst_tensors: List[torch.Tensor]
+    ) -> List[Optional[torch.Tensor]]:
+        if not keys:
+            return []
+
+        if self.backend_selector.mem_type == "FILE":
+            file_paths = [self.file_manager.get_file_path(key) for key in keys]
+            success = self._execute_transfer(dst_tensors, file_paths, "READ")
+        else:
+            success = self._execute_transfer(dst_tensors, keys, "READ")
+        return dst_tensors if success else [None] * len(keys)
+
+    def exists(self, key: str) -> bool:
+        tuples = self.registration.create_query_tuples(
+            key,
+            self.backend_selector.mem_type,
+            self.file_manager if self.backend_selector.mem_type == "FILE" else None,
+        )
+        if not tuples:
+            return False
+
+        query_res = self.agent.query_memory(
+            tuples,
+            self.backend_selector.backend_name,
+            mem_type=self.backend_selector.mem_type,
+        )
+        return query_res[0] is not None  # can be expanded to multiple keys
sglang/srt/mem_cache/storage/nixl/nixl_utils.py (new file)
@@ -0,0 +1,238 @@
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class NixlBackendSelection:
+    """Handles NIXL backend selection and creation."""
+
+    # Priority order for File-based plugins in case of auto selection
+    FILE_PLUGINS = ["3FS", "POSIX", "GDS_MT", "GDS"]
+    # Priority order for File-based plugins in case of auto selection (add more as needed)
+    OBJ_PLUGINS = ["OBJ"]  # Based on Amazon S3 SDK
+
+    def __init__(self, plugin: str = "auto"):
+        """Initialize backend selection.
+        Args:
+            plugin: Plugin to use (default "auto" selects best available).
+                Can be a file plugin (3FS, POSIX, GDS, GDS_MT) or
+                an object plugin (OBJ).
+        """
+        self.plugin = plugin
+        self.backend_name = None
+        self.mem_type = None
+
+    def set_bucket(self, bucket_name: str) -> None:
+        """Set AWS bucket name in environment variable."""
+        os.environ["AWS_DEFAULT_BUCKET"] = bucket_name
+        logger.debug(f"Set AWS bucket name to: {bucket_name}")
+
+    def create_backend(self, agent) -> bool:
+        """Create the appropriate NIXL backend based on configuration."""
+        try:
+            plugin_list = agent.get_plugin_list()
+            logger.debug(f"Available NIXL plugins: {plugin_list}")
+
+            # Handle explicit plugin selection or auto priority
+            if self.plugin == "auto":
+                # Try all file plugins first
+                for plugin in self.FILE_PLUGINS:
+                    if plugin in plugin_list:
+                        self.backend_name = plugin
+                        break
+                # If no file plugin found, try object plugins
+                if not self.backend_name:
+                    for plugin in self.OBJ_PLUGINS:
+                        if plugin in plugin_list:
+                            self.backend_name = plugin
+                            break
+            else:
+                # Use explicitly requested plugin
+                self.backend_name = self.plugin
+
+            if self.backend_name not in plugin_list:
+                logger.error(
+                    f"Backend {self.backend_name} not available in plugins: {plugin_list}"
+                )
+                return False
+
+            # Create backend and set memory type
+            if self.backend_name in self.OBJ_PLUGINS:
+                bucket = os.environ.get("AWS_DEFAULT_BUCKET")
+                if not bucket:
+                    logger.error(
+                        "AWS_DEFAULT_BUCKET environment variable must be set for object storage"
+                    )
+                    return False
+                agent.create_backend(self.backend_name, {"bucket": bucket})
+            else:
+                agent.create_backend(self.backend_name)
+
+            self.mem_type = "OBJ" if self.backend_name in self.OBJ_PLUGINS else "FILE"
+            logger.debug(
+                f"Created NIXL backend: {self.backend_name} with memory type: {self.mem_type}"
+            )
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to create NIXL backend: {e}")
+            return False
+
+
+class NixlRegistration:
+    """Handles NIXL memory registration."""
+
+    def __init__(self, agent):
+        self.agent = agent
+
+    def create_query_tuples(
+        self, key: str, mem_type: str, file_manager=None
+    ) -> List[Tuple]:
+        """Create NIXL tuples for querying memory.
+        Args:
+            key: Key to query (file path for FILE or object key for OBJ)
+            mem_type: Memory type ("FILE" or "OBJ")
+            file_manager: Optional NixlFileManager for FILE memory type
+        Returns:
+            List of NIXL tuples for querying
+        """
+        if mem_type == "FILE":
+            if file_manager is None:
+                logger.error("file_manager required for FILE memory type")
+                return []
+            return [(0, 0, 0, file_manager.get_file_path(key))]
+        else:  # OBJ
+            return [(0, 0, key)]
+
+    def _register_memory(
+        self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
+    ) -> Optional[Any]:
+        """Common registration logic for files, objects, and buffers.
+        Args:
+            items: List of tuples or tensors to register
+            mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
+            desc: Description for logging
+        """
+        try:
+            if not items:
+                return None
+
+            reg_descs = self.agent.get_reg_descs(items, mem_type)
+            if reg_descs is None:
+                logger.error("Failed to create registration descriptors")
+                return None
+
+            registered_memory = self.agent.register_memory(reg_descs)
+            if registered_memory:
+                return registered_memory
+            else:
+                logger.error("Failed to register with NIXL")
+                return None
+
+        except Exception as e:
+            logger.error(f"Failed to register {desc}: {e}")
+            return None
+
+    def register_buffers(
+        self, buffers: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Optional[Any]:
+        """Register tensors/buffers with NIXL."""
+        if isinstance(buffers, torch.Tensor):
+            buffers = [buffers]
+
+        if not buffers:
+            return None
+
+        # Determine memory type based on tensor device
+        mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
+        return self._register_memory(buffers, mem_type, "buffers")
+
+    def register_files(self, tuples: List[tuple]) -> Optional[Any]:
+        """Register files with NIXL using (0, 0, fd, file_path) tuples."""
+        return self._register_memory(tuples, "FILE", "files")
+
+    def register_objects(
+        self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
+    ) -> Optional[Any]:
+        """Register objects with NIXL."""
+        if not keys:
+            return None
+
+        # Create object tuples with proper sizes
+        tuples = [
+            (0, tensor.element_size() * tensor.numel() if tensor else 0, key)
+            for key, tensor in zip(keys, tensors or [None] * len(keys))
+        ]
+        return self._register_memory(tuples, "OBJ", "objects")
+
+
+class NixlFileManager:
+    """Handles file system operations for NIXL."""
+
+    def __init__(self, base_dir: str):
+        """
+        Initialize file manager.
+        Args:
+            base_dir: Base directory for storing tensor files
+        """
+        self.base_dir = base_dir
+        if base_dir == "":
+            logger.debug(f"Initialized file manager without a base directory")
+        else:
+            os.makedirs(base_dir, exist_ok=True)
+            logger.debug(f"Initialized file manager with base directory: {base_dir}")
+
+    def get_file_path(self, key: str) -> str:
+        """Get full file path for a given key."""
+        return os.path.join(self.base_dir, key)
+
+    def create_file(self, file_path: str) -> bool:
+        """Create a file if it doesn't exist."""
+        try:
+            os.makedirs(os.path.dirname(file_path), exist_ok=True)
+            if not os.path.exists(file_path):
+                with open(file_path, "wb") as f:
+                    pass  # Create empty file
+            return True
+        except Exception as e:
+            logger.error(f"Failed to create file {file_path}: {e}")
+            return False
+
+    def open_file(self, file_path: str) -> Optional[int]:
+        """Open a file and return its file descriptor."""
+        try:
+            fd = os.open(file_path, os.O_RDWR)
+            return fd
+        except Exception as e:
+            logger.error(f"Failed to open file {file_path}: {e}")
+            return None
+
+    def close_file(self, fd: int) -> bool:
+        """Close a file descriptor."""
+        try:
+            os.close(fd)
+            return True
+        except Exception as e:
+            logger.error(f"Failed to close file descriptor {fd}: {e}")
+            return False
+
+    def files_to_nixl_tuples(
+        self, file_paths: List[str], open_file: bool = True
+    ) -> List[Tuple[int, int, int, str]]:
+        """Create NIXL tuples (offset, length, fd, file_path) for given files."""
+        if not open_file:
+            return [(0, 0, 0, path) for path in file_paths]
+
+        tuples = []
+        for path in file_paths:
+            if (fd := self.open_file(path)) is None:
+                # Clean up on failure
+                for t in tuples:
+                    self.close_file(t[2])
+                return []
+            tuples.append((0, 0, fd, path))
+        return tuples