sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +39 -674
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
- sglang/srt/layers/quantization/fp8.py +52 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +165 -67
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +90 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +60 -17
- sglang/srt/mem_cache/hiradix_cache.py +36 -8
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +418 -29
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/glm4_moe.py +6 -4
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +991 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +49 -18
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,238 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
logger = logging.getLogger(__name__)
|
8
|
+
|
9
|
+
|
10
|
+
class NixlBackendSelection:
    """Handles NIXL backend selection and creation.

    The selected plugin name is stored in ``backend_name`` and the matching
    memory type ("FILE" or "OBJ") in ``mem_type`` once ``create_backend``
    succeeds.
    """

    # Priority order for file-based plugins in case of auto selection
    FILE_PLUGINS = ["3FS", "POSIX", "GDS_MT", "GDS"]
    # Priority order for object-based plugins in case of auto selection (add more as needed)
    OBJ_PLUGINS = ["OBJ"]  # Based on Amazon S3 SDK

    def __init__(self, plugin: str = "auto"):
        """Initialize backend selection.

        Args:
            plugin: Plugin to use (default "auto" selects best available).
                   Can be a file plugin (3FS, POSIX, GDS, GDS_MT) or
                   an object plugin (OBJ).
        """
        self.plugin = plugin
        self.backend_name = None
        self.mem_type = None

    def set_bucket(self, bucket_name: str) -> None:
        """Set AWS bucket name in environment variable."""
        os.environ["AWS_DEFAULT_BUCKET"] = bucket_name
        logger.debug(f"Set AWS bucket name to: {bucket_name}")

    def _auto_select(self, plugin_list) -> Optional[str]:
        """Return the highest-priority known plugin present in ``plugin_list``.

        File plugins are tried first (in ``FILE_PLUGINS`` order), then object
        plugins (in ``OBJ_PLUGINS`` order). Returns None when none is available.
        """
        candidates = self.FILE_PLUGINS + self.OBJ_PLUGINS
        return next((name for name in candidates if name in plugin_list), None)

    def create_backend(self, agent) -> bool:
        """Create the appropriate NIXL backend based on configuration.

        Args:
            agent: Object providing ``get_plugin_list()`` and
                ``create_backend(name[, params])``.

        Returns:
            True when the backend was created; False when no usable plugin is
            available, when an object plugin is selected without
            ``AWS_DEFAULT_BUCKET`` being set, or when the agent raised.
        """
        try:
            plugin_list = agent.get_plugin_list()
            logger.debug(f"Available NIXL plugins: {plugin_list}")

            # Handle explicit plugin selection or auto priority
            if self.plugin == "auto":
                # Always recompute from scratch so a name left over from an
                # earlier call cannot block the object-plugin fallback.
                self.backend_name = self._auto_select(plugin_list)
                if self.backend_name is None:
                    logger.error(
                        f"No supported NIXL plugin found in available plugins: {plugin_list}"
                    )
                    return False
            else:
                # Use explicitly requested plugin
                self.backend_name = self.plugin
                if self.backend_name not in plugin_list:
                    logger.error(
                        f"Backend {self.backend_name} not available in plugins: {plugin_list}"
                    )
                    return False

            # Create backend and set memory type
            is_object_backend = self.backend_name in self.OBJ_PLUGINS
            if is_object_backend:
                bucket = os.environ.get("AWS_DEFAULT_BUCKET")
                if not bucket:
                    logger.error(
                        "AWS_DEFAULT_BUCKET environment variable must be set for object storage"
                    )
                    return False
                agent.create_backend(self.backend_name, {"bucket": bucket})
            else:
                agent.create_backend(self.backend_name)

            self.mem_type = "OBJ" if is_object_backend else "FILE"
            logger.debug(
                f"Created NIXL backend: {self.backend_name} with memory type: {self.mem_type}"
            )
            return True

        except Exception as e:
            # Best-effort boundary: callers only look at the boolean result.
            logger.error(f"Failed to create NIXL backend: {e}")
            return False
|
84
|
+
|
85
|
+
|
86
|
+
class NixlRegistration:
    """Handles NIXL memory registration."""

    def __init__(self, agent):
        """Store the agent used for all registration calls.

        Args:
            agent: Object providing ``get_reg_descs(items, mem_type)`` and
                ``register_memory(reg_descs)``.
        """
        self.agent = agent

    def create_query_tuples(
        self, key: str, mem_type: str, file_manager=None
    ) -> List[Tuple]:
        """Create NIXL tuples for querying memory.

        Args:
            key: Key to query (file path for FILE or object key for OBJ)
            mem_type: Memory type ("FILE" or "OBJ")
            file_manager: Optional NixlFileManager for FILE memory type

        Returns:
            List of NIXL tuples for querying; empty when ``mem_type`` is
            "FILE" and no ``file_manager`` was supplied.
        """
        if mem_type == "FILE":
            if file_manager is None:
                logger.error("file_manager required for FILE memory type")
                return []
            return [(0, 0, 0, file_manager.get_file_path(key))]
        # Any other memory type is treated as OBJ.
        return [(0, 0, key)]

    def _register_memory(
        self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
    ) -> Optional[Any]:
        """Common registration logic for files, objects, and buffers.

        Args:
            items: List of tuples or tensors to register
            mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
            desc: Description for logging

        Returns:
            Whatever ``agent.register_memory`` returned, or None when ``items``
            is empty, descriptors could not be built, registration returned a
            falsy value, or the agent raised.
        """
        try:
            if not items:
                return None

            reg_descs = self.agent.get_reg_descs(items, mem_type)
            if reg_descs is None:
                logger.error("Failed to create registration descriptors")
                return None

            registered_memory = self.agent.register_memory(reg_descs)
            if not registered_memory:
                logger.error("Failed to register with NIXL")
                return None
            return registered_memory

        except Exception as e:
            logger.error(f"Failed to register {desc}: {e}")
            return None

    def register_buffers(
        self, buffers: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Optional[Any]:
        """Register tensors/buffers with NIXL.

        A single tensor is wrapped into a one-element list. The memory type is
        derived from the first buffer only: "VRAM" for a CUDA tensor, "DRAM"
        otherwise.
        """
        if isinstance(buffers, torch.Tensor):
            buffers = [buffers]

        if not buffers:
            return None

        # Determine memory type based on tensor device
        mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
        return self._register_memory(buffers, mem_type, "buffers")

    def register_files(self, tuples: List[tuple]) -> Optional[Any]:
        """Register files with NIXL using (0, 0, fd, file_path) tuples."""
        return self._register_memory(tuples, "FILE", "files")

    def register_objects(
        self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
    ) -> Optional[Any]:
        """Register objects with NIXL.

        Args:
            keys: Object keys to register.
            tensors: Optional tensors paired positionally with ``keys``; each
                contributes its byte size to the tuple. Without tensors (or for
                a None entry) the size is 0.

        Returns:
            The registration result, or None when ``keys`` is empty or
            registration failed.
        """
        if not keys:
            return None

        paired_tensors = tensors or [None] * len(keys)

        # Create object tuples with proper sizes. Compare against None
        # explicitly: truth-testing a tensor raises for more than one element
        # and is False for a single zero-valued element.
        tuples = [
            (
                0,
                tensor.element_size() * tensor.numel() if tensor is not None else 0,
                key,
            )
            for key, tensor in zip(keys, paired_tensors)
        ]
        return self._register_memory(tuples, "OBJ", "objects")
|
171
|
+
|
172
|
+
|
173
|
+
class NixlFileManager:
    """Handles file system operations for NIXL."""

    def __init__(self, base_dir: str):
        """
        Initialize file manager.
        Args:
            base_dir: Base directory for storing tensor files. An empty string
                means no base directory; otherwise the directory is created
                if it does not exist.
        """
        self.base_dir = base_dir
        if base_dir == "":
            logger.debug("Initialized file manager without a base directory")
        else:
            os.makedirs(base_dir, exist_ok=True)
            logger.debug(f"Initialized file manager with base directory: {base_dir}")

    def get_file_path(self, key: str) -> str:
        """Get full file path for a given key."""
        return os.path.join(self.base_dir, key)

    def create_file(self, file_path: str) -> bool:
        """Create a file if it doesn't exist.

        Missing parent directories are created first. An existing path is left
        untouched.

        Returns:
            True when the path exists afterwards, False when creation failed
            (the error is logged).
        """
        try:
            parent_dir = os.path.dirname(file_path)
            # A bare filename has no directory component; os.makedirs("")
            # would raise, so only create parents when there are any.
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            if not os.path.exists(file_path):
                # Append mode creates the file but never truncates one that
                # appeared between the existence check and the open.
                with open(file_path, "ab"):
                    pass  # Create empty file
            return True
        except Exception as e:
            logger.error(f"Failed to create file {file_path}: {e}")
            return False

    def open_file(self, file_path: str) -> Optional[int]:
        """Open a file read/write and return its file descriptor, or None on failure."""
        try:
            return os.open(file_path, os.O_RDWR)
        except Exception as e:
            logger.error(f"Failed to open file {file_path}: {e}")
            return None

    def close_file(self, fd: int) -> bool:
        """Close a file descriptor; return False (and log) if closing failed."""
        try:
            os.close(fd)
            return True
        except Exception as e:
            logger.error(f"Failed to close file descriptor {fd}: {e}")
            return False

    def files_to_nixl_tuples(
        self, file_paths: List[str], open_file: bool = True
    ) -> List[Tuple[int, int, int, str]]:
        """Create NIXL tuples (offset, length, fd, file_path) for given files.

        Args:
            file_paths: Paths to describe.
            open_file: When False, no file is opened and fd is 0 in every
                tuple. When True, each file is opened and its descriptor is
                placed in the tuple; the caller owns those descriptors.

        Returns:
            One tuple per path, or an empty list if any file could not be
            opened (descriptors opened so far are closed first).
        """
        if not open_file:
            return [(0, 0, 0, path) for path in file_paths]

        tuples = []
        for path in file_paths:
            if (fd := self.open_file(path)) is None:
                # Clean up on failure
                for t in tuples:
                    self.close_file(t[2])
                return []
            tuples.append((0, 0, fd, path))
        return tuples
|
@@ -0,0 +1,216 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
import os
|
4
|
+
import unittest
|
5
|
+
from typing import List, Optional
|
6
|
+
from unittest.mock import MagicMock
|
7
|
+
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
|
11
|
+
from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
|
12
|
+
|
13
|
+
|
14
|
+
class TestNixlUnified(unittest.TestCase):
    """Unified test suite for all NIXL components.

    Covers HiCacheNixl set/get round trips, NixlFileManager file helpers and
    NixlRegistration against a mocked NIXL agent.
    """

    def setUp(self):
        """Set up test environment."""
        # Create test directories
        self.test_dir = "/tmp/test_nixl_unified"
        os.makedirs(self.test_dir, exist_ok=True)
        # tearDown() is not run when setUp() raises (including the skip
        # below), but registered cleanups are, so the directory never leaks.
        self.addCleanup(self._remove_test_dir)

        # Mock NIXL agent for registration tests
        self.mock_agent = MagicMock()
        self.mock_agent.get_reg_descs.return_value = "mock_reg_descs"
        self.mock_agent.register_memory.return_value = "mock_registered_memory"

        # Create instances
        self.file_manager = NixlFileManager(self.test_dir)
        self.registration = NixlRegistration(self.mock_agent)
        try:
            self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX")
        except ImportError:
            self.skipTest("NIXL not available, skipping NIXL storage tests")

    def _remove_test_dir(self):
        """Remove the test directory if it is still present."""
        if os.path.exists(self.test_dir):
            import shutil

            shutil.rmtree(self.test_dir)

    def tearDown(self):
        """Clean up test directories."""
        self._remove_test_dir()

    def delete_test_file(self, file_path: str) -> bool:
        """Helper method to delete a test file.

        Args:
            file_path: Path to the file to delete

        Returns:
            bool: True if file was deleted or didn't exist, False on error
        """
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
            return True
        except Exception:
            return False

    def verify_tensors_equal(self, expected: torch.Tensor, actual: torch.Tensor):
        """Helper to verify tensor equality."""
        self.assertIsNotNone(actual, "Retrieved tensor is None")
        self.assertTrue(
            torch.allclose(expected, actual, atol=1e-6),
            f"Tensors not equal:\nExpected: {expected}\nActual: {actual}",
        )

    def verify_tensor_lists_equal(
        self, expected: List[torch.Tensor], actual: List[torch.Tensor]
    ):
        """Helper to verify lists of tensors are equal."""
        self.assertEqual(len(expected), len(actual), "Lists have different lengths")
        for exp, act in zip(expected, actual):
            self.verify_tensors_equal(exp, act)

    # ============================================================================
    # HiCache Integration Tests
    # ============================================================================

    def test_single_set_get(self):
        """Test single tensor set/get operations."""
        key = "test_key"
        value = torch.randn(10, 10, device="cpu")
        dst_tensor = torch.zeros_like(value, device="cpu")

        # Test set
        self.assertTrue(self.hicache.set(key, value))
        self.assertTrue(self.hicache.exists(key))

        # Test get
        retrieved = self.hicache.get(key, dst_tensor)
        self.verify_tensors_equal(value, retrieved)

    def test_batch_set_get(self):
        """Test batch tensor set/get operations."""
        keys = ["key1", "key2", "key3"]
        values = [
            torch.randn(5, 5, device="cpu"),
            torch.randn(3, 3, device="cpu"),
            torch.randn(7, 7, device="cpu"),
        ]
        dst_tensors = [torch.zeros_like(v, device="cpu") for v in values]

        # Test batch set
        self.assertTrue(self.hicache.batch_set(keys, values))
        self.assertTrue(all(self.hicache.exists(key) for key in keys))

        # Test batch get
        retrieved = self.hicache.batch_get(keys, dst_tensors)
        self.verify_tensor_lists_equal(values, retrieved)

    def test_mixed_operations(self):
        """Test mixing single and batch operations."""
        # Test interleaved set/get operations
        key1, key2 = "key1", "key2"
        value1 = torch.randn(4, 4, device="cpu")
        value2 = torch.randn(6, 6, device="cpu")
        dst1 = torch.zeros_like(value1)
        dst2 = torch.zeros_like(value2)

        # Single set/get
        self.assertTrue(self.hicache.set(key1, value1))
        retrieved1 = self.hicache.get(key1, dst1)
        self.verify_tensors_equal(value1, retrieved1)

        # Batch set/get
        self.assertTrue(self.hicache.batch_set([key2], [value2]))
        retrieved2 = self.hicache.batch_get([key2], [dst2])
        self.verify_tensors_equal(value2, retrieved2[0])

    def test_data_integrity(self):
        """Test data integrity across operations."""
        # Test with various tensor types and sizes
        test_cases = [
            ("float32", torch.randn(10, 10, dtype=torch.float32)),
            ("float64", torch.randn(5, 5, dtype=torch.float64)),
            ("int32", torch.randint(-100, 100, (8, 8), dtype=torch.int32)),
            ("int64", torch.randint(-100, 100, (6, 6), dtype=torch.int64)),
            ("bool", torch.randint(0, 2, (4, 4)).bool()),
        ]

        for name, tensor in test_cases:
            with self.subTest(tensor_type=name):
                key = f"test_{name}"
                dst_tensor = torch.zeros_like(tensor)

                # Set and immediately get
                self.assertTrue(self.hicache.set(key, tensor))
                retrieved1 = self.hicache.get(key, dst_tensor)
                self.verify_tensors_equal(tensor, retrieved1)

                # Get again to verify persistence
                dst_tensor.zero_()
                retrieved2 = self.hicache.get(key, dst_tensor)
                self.verify_tensors_equal(tensor, retrieved2)

    def test_basic_file_operations(self):
        """Test basic file operations."""
        test_file = os.path.join(self.test_dir, "test_file.bin")
        self.assertTrue(self.file_manager.create_file(test_file))
        self.assertTrue(os.path.exists(test_file))
        self.assertEqual(os.path.getsize(test_file), 0)  # Empty file

        # Test file deletion
        self.assertTrue(self.delete_test_file(test_file))
        self.assertFalse(os.path.exists(test_file))

    def test_create_nixl_tuples(self):
        """Test creation of NIXL tuples."""
        test_file = os.path.join(self.test_dir, "test_file.bin")
        self.assertTrue(self.file_manager.create_file(test_file))

        # Test tuple creation
        tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
        self.assertIsNotNone(tuples)
        self.assertGreater(len(tuples), 0)

    def test_error_handling(self):
        """Test error handling in file operations."""
        # Test non-existent file
        self.assertTrue(
            self.delete_test_file("nonexistent_file.bin")
        )  # Returns True if file doesn't exist

        # Test invalid file path
        self.assertFalse(self.file_manager.create_file(""))  # Empty path should fail

    def test_register_buffers(self):
        """Test registration of memory buffers."""
        # Create test tensor
        tensor = torch.randn(10, 10)

        # Test buffer registration
        self.assertIsNotNone(self.registration.register_buffers(tensor))

        # Test batch registration
        tensors = [torch.randn(5, 5) for _ in range(3)]
        self.assertIsNotNone(self.registration.register_buffers(tensors))

    def test_register_files_with_tuples(self):
        """Test registration of files using NIXL tuples."""
        files = [os.path.join(self.test_dir, f"test_file_{i}.bin") for i in range(3)]
        for file in files:
            self.assertTrue(self.file_manager.create_file(file))

        # Create tuples and register
        tuples = self.file_manager.files_to_nixl_tuples(files, False)
        self.registration.register_files(tuples)

        # Verify tuples
        self.assertEqual(len(tuples), len(files))
        for t, f in zip(tuples, files):
            self.assertEqual(t[3], f)  # Check file path
|
213
|
+
|
214
|
+
|
215
|
+
if __name__ == "__main__":
|
216
|
+
unittest.main()
|
@@ -0,0 +1,183 @@
|
|
1
|
+
import logging
|
2
|
+
import multiprocessing
|
3
|
+
import os
|
4
|
+
import threading
|
5
|
+
from functools import wraps
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import List
|
8
|
+
|
9
|
+
import torch
|
10
|
+
from torch.utils.cpp_extension import load
|
11
|
+
|
12
|
+
root = Path(__file__).parent.resolve()
|
13
|
+
hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
HF3FS_AVAILABLE = True
|
18
|
+
try:
|
19
|
+
from hf3fs_fuse.io import (
|
20
|
+
deregister_fd,
|
21
|
+
extract_mount_point,
|
22
|
+
make_ioring,
|
23
|
+
make_iovec,
|
24
|
+
register_fd,
|
25
|
+
)
|
26
|
+
except ImportError:
|
27
|
+
HF3FS_AVAILABLE = False
|
28
|
+
|
29
|
+
|
30
|
+
def rsynchronized():
    """Build a method decorator that runs the method while holding ``self.rlock``.

    The decorated method's metadata (name, docstring) is preserved via
    ``functools.wraps``. ``self.rlock`` is used as a context manager, so it is
    released whether the method returns or raises.
    """

    def decorate(method):
        @wraps(method)
        def locked_call(self, *args, **kwargs):
            guard = self.rlock
            with guard:
                outcome = method(self, *args, **kwargs)
            return outcome

        return locked_call

    return decorate
|
40
|
+
|
41
|
+
|
42
|
+
def wsynchronized():
    """Build a method decorator that runs the method while holding ``self.wlock``.

    The decorated method's metadata (name, docstring) is preserved via
    ``functools.wraps``. ``self.wlock`` is used as a context manager, so it is
    released whether the method returns or raises.
    """

    def decorate(method):
        @wraps(method)
        def locked_call(self, *args, **kwargs):
            guard = self.wlock
            with guard:
                outcome = method(self, *args, **kwargs)
            return outcome

        return locked_call

    return decorate
|
52
|
+
|
53
|
+
|
54
|
+
class Hf3fsClient:
    """Batched tensor reads/writes against a single file through hf3fs io rings.

    Data is staged through two shared-memory buffers (one for reads, one for
    writes), each ``bytes_per_page * entries`` bytes long. Reads are serialized
    by ``rlock`` and writes by ``wlock``.
    """

    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
        """Open (creating if needed) ``path``, size it and set up the io rings.

        Args:
            path: File to read from and write to.
            size: Length in bytes the file is truncated/extended to.
            bytes_per_page: Maximum byte size of a single tensor in a batch.
            entries: Maximum number of operations in a single batch.

        Raises:
            ImportError: If ``hf3fs_fuse.io`` could not be imported.
        """
        if not HF3FS_AVAILABLE:
            raise ImportError(
                "hf3fs_fuse.io is not available. Please install the hf3fs_fuse package."
            )

        # `import multiprocessing` alone does not load the shared_memory
        # submodule, so import it explicitly before use.
        from multiprocessing import shared_memory

        self.path = path
        self.size = size
        self.bytes_per_page = bytes_per_page
        self.entries = entries
        self._closed = False

        self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
        os.ftruncate(self.file, size)
        register_fd(self.file)

        self.hf3fs_mount_point = extract_mount_point(path)
        self.bs = self.bytes_per_page
        self.shm_r = shared_memory.SharedMemory(
            size=self.bs * self.entries, create=True
        )
        self.shm_w = shared_memory.SharedMemory(
            size=self.bs * self.entries, create=True
        )

        # Byte views over the staging buffers, handed to hf3fs_utils.
        self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8)
        self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8)

        self.numa = -1
        self.ior_r = make_ioring(
            self.hf3fs_mount_point,
            self.entries,
            for_read=True,
            timeout=1,
            numa=self.numa,
        )
        self.ior_w = make_ioring(
            self.hf3fs_mount_point,
            self.entries,
            for_read=False,
            timeout=1,
            numa=self.numa,
        )
        self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
        self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)

        self.rlock = threading.RLock()
        self.wlock = threading.RLock()

    @rsynchronized()
    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        """Read one region of the file into each tensor.

        Args:
            offsets: File offset for each read.
            tensors: Destination tensors; tensor ``i`` receives its own byte
                size worth of data starting at ``offsets[i]``.

        Returns:
            The ``result`` field of each completed operation, as returned by
            the read ring; an empty list for an empty batch.

        Raises:
            ValueError: If the batch fails validation (see ``check``).
        """
        self.check(offsets, tensors)
        if not offsets:
            # Nothing to submit.
            return []

        # prepare: each request lands in consecutive slices of the read iovec
        current = 0
        for offset, tensor in zip(offsets, tensors):
            size = tensor.numel() * tensor.itemsize
            self.ior_r.prepare(
                self.iov_r[current : current + size], True, self.file, offset
            )
            current += size

        # submit
        ionum = len(offsets)
        resv = self.ior_r.submit().wait(min_results=ionum)

        # results: copy staged bytes out of shared memory into the tensors
        hf3fs_utils.read_shm(self.shm_r_tensor, tensors)
        results = [res.result for res in resv]

        return results

    @wsynchronized()
    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        """Write each tensor to its region of the file.

        Args:
            offsets: File offset for each write.
            tensors: Source tensors; tensor ``i`` is written at ``offsets[i]``.

        Returns:
            The ``result`` field of each completed operation, as returned by
            the write ring; an empty list for an empty batch.

        Raises:
            ValueError: If the batch fails validation (see ``check``).
        """
        self.check(offsets, tensors)
        if not offsets:
            # Nothing to submit.
            return []

        # prepare: stage tensor bytes in shared memory, then queue one write
        # per tensor from consecutive slices of the write iovec
        hf3fs_utils.write_shm(tensors, self.shm_w_tensor)
        current = 0
        for offset, tensor in zip(offsets, tensors):
            size = tensor.numel() * tensor.itemsize
            self.ior_w.prepare(
                self.iov_w[current : current + size], False, self.file, offset
            )
            current += size

        # submit
        ionum = len(offsets)
        resv = self.ior_w.submit().wait(min_results=ionum)

        # results
        results = [res.result for res in resv]

        return results

    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
        """Validate a batch; on failure close the client and raise.

        A batch is rejected when it has more operations than ``entries``, when
        ``offsets`` and ``tensors`` differ in length, when any region falls
        outside ``[0, size]``, or when any tensor exceeds ``bytes_per_page``.
        Every element is checked, which keeps the total staged size within
        the ``bytes_per_page * entries`` shared-memory buffers.

        Raises:
            ValueError: If the batch is invalid. The client is closed first.
        """
        sizes = [t.numel() * t.itemsize for t in tensors]
        out_of_range = any(
            offset < 0 or offset + size > self.size
            for offset, size in zip(offsets, sizes)
        )
        oversized = any(size > self.bytes_per_page for size in sizes)
        if (
            len(offsets) > self.entries
            or len(offsets) != len(sizes)
            or out_of_range
            or oversized
        ):
            self.close()
            raise ValueError(f"Hf3fsClient.check: {offsets=}, {sizes=}")

    def get_size(self) -> int:
        """Return the file size passed at construction."""
        return self.size

    def close(self) -> None:
        """Release the file descriptor, io rings, iovecs and shared memory.

        Safe to call more than once: ``check`` closes the client before
        raising, and callers commonly call ``close`` again during cleanup.
        """
        if getattr(self, "_closed", False):
            return
        self._closed = True

        deregister_fd(self.file)
        os.close(self.file)
        del self.ior_r
        del self.ior_w
        del self.iov_r
        del self.iov_w
        self.shm_r.close()
        self.shm_w.close()
        self.shm_r.unlink()
        self.shm_w.unlink()

    def flush(self) -> None:
        """fsync the underlying file descriptor."""
        os.fsync(self.file)
|