sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,238 @@
1
+ import logging
2
+ import os
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
class NixlBackendSelection:
    """Chooses and instantiates the NIXL backend used for storage I/O."""

    # File-oriented plugins, ordered by preference for "auto" selection
    FILE_PLUGINS = ["3FS", "POSIX", "GDS_MT", "GDS"]
    # Object-store plugins, ordered by preference for "auto" selection
    OBJ_PLUGINS = ["OBJ"]  # Based on Amazon S3 SDK

    def __init__(self, plugin: str = "auto"):
        """Remember the requested plugin; resolution happens in create_backend.

        Args:
            plugin: Plugin to use (default "auto" selects best available).
                Can be a file plugin (3FS, POSIX, GDS, GDS_MT) or
                an object plugin (OBJ).
        """
        self.plugin = plugin
        self.backend_name = None
        self.mem_type = None

    def set_bucket(self, bucket_name: str) -> None:
        """Expose the AWS bucket name to NIXL via the environment."""
        os.environ["AWS_DEFAULT_BUCKET"] = bucket_name
        logger.debug(f"Set AWS bucket name to: {bucket_name}")

    def create_backend(self, agent) -> bool:
        """Resolve the backend name and create it on *agent*; True on success."""
        try:
            plugin_list = agent.get_plugin_list()
            logger.debug(f"Available NIXL plugins: {plugin_list}")

            if self.plugin != "auto":
                # Honor the explicitly requested plugin as-is.
                self.backend_name = self.plugin
            else:
                # Prefer file plugins; fall back to object plugins only when
                # none of the file plugins is available.
                candidates = self.FILE_PLUGINS + self.OBJ_PLUGINS
                self.backend_name = next(
                    (name for name in candidates if name in plugin_list), None
                )

            if self.backend_name not in plugin_list:
                logger.error(
                    f"Backend {self.backend_name} not available in plugins: {plugin_list}"
                )
                return False

            if self.backend_name in self.OBJ_PLUGINS:
                # Object backends additionally require a bucket, passed via env.
                bucket = os.environ.get("AWS_DEFAULT_BUCKET")
                if not bucket:
                    logger.error(
                        "AWS_DEFAULT_BUCKET environment variable must be set for object storage"
                    )
                    return False
                agent.create_backend(self.backend_name, {"bucket": bucket})
            else:
                agent.create_backend(self.backend_name)

            self.mem_type = "OBJ" if self.backend_name in self.OBJ_PLUGINS else "FILE"
            logger.debug(
                f"Created NIXL backend: {self.backend_name} with memory type: {self.mem_type}"
            )
            return True

        except Exception as e:
            logger.error(f"Failed to create NIXL backend: {e}")
            return False
84
+
85
+
86
class NixlRegistration:
    """Handles NIXL memory registration for files, objects, and RAM buffers."""

    def __init__(self, agent):
        # NIXL agent used to build descriptors and register memory.
        self.agent = agent

    def create_query_tuples(
        self, key: str, mem_type: str, file_manager=None
    ) -> List[Tuple]:
        """Create NIXL tuples for querying memory.

        Args:
            key: Key to query (file path for FILE or object key for OBJ)
            mem_type: Memory type ("FILE" or "OBJ")
            file_manager: Optional NixlFileManager for FILE memory type

        Returns:
            List of NIXL tuples for querying ([] if file_manager is missing).
        """
        if mem_type == "FILE":
            if file_manager is None:
                logger.error("file_manager required for FILE memory type")
                return []
            return [(0, 0, 0, file_manager.get_file_path(key))]
        else:  # OBJ keys carry no fd/path component
            return [(0, 0, key)]

    def _register_memory(
        self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
    ) -> Optional[Any]:
        """Common registration logic for files, objects, and buffers.

        Args:
            items: List of tuples or tensors to register
            mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
            desc: Description for logging

        Returns:
            The registered-memory handle, or None on empty input or failure.
        """
        try:
            if not items:
                return None

            reg_descs = self.agent.get_reg_descs(items, mem_type)
            if reg_descs is None:
                logger.error("Failed to create registration descriptors")
                return None

            registered_memory = self.agent.register_memory(reg_descs)
            if registered_memory:
                return registered_memory
            else:
                logger.error("Failed to register with NIXL")
                return None

        except Exception as e:
            logger.error(f"Failed to register {desc}: {e}")
            return None

    def register_buffers(
        self, buffers: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Optional[Any]:
        """Register tensors/buffers with NIXL (VRAM for CUDA tensors, else DRAM)."""
        if isinstance(buffers, torch.Tensor):
            buffers = [buffers]

        if not buffers:
            return None

        # Determine memory type based on tensor device
        mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
        return self._register_memory(buffers, mem_type, "buffers")

    def register_files(self, tuples: List[tuple]) -> Optional[Any]:
        """Register files with NIXL using (0, 0, fd, file_path) tuples."""
        return self._register_memory(tuples, "FILE", "files")

    def register_objects(
        self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
    ) -> Optional[Any]:
        """Register object keys with NIXL; sizes are taken from *tensors* if given."""
        if not keys:
            return None

        # BUG FIX: compare against None explicitly. The previous `if tensor`
        # invoked Tensor.__bool__, which raises RuntimeError ("Boolean value
        # of Tensor with more than one element is ambiguous") for any real
        # multi-element tensor passed in.
        tuples = [
            (
                0,
                tensor.element_size() * tensor.numel() if tensor is not None else 0,
                key,
            )
            for key, tensor in zip(keys, tensors or [None] * len(keys))
        ]
        return self._register_memory(tuples, "OBJ", "objects")
171
+
172
+
173
class NixlFileManager:
    """Handles file system operations for NIXL."""

    def __init__(self, base_dir: str):
        """
        Initialize file manager.

        Args:
            base_dir: Base directory for storing tensor files. An empty
                string means keys are used as-is (relative/absolute paths).
        """
        self.base_dir = base_dir
        if base_dir == "":
            logger.debug("Initialized file manager without a base directory")
        else:
            os.makedirs(base_dir, exist_ok=True)
            logger.debug(f"Initialized file manager with base directory: {base_dir}")

    def get_file_path(self, key: str) -> str:
        """Return the full file path for *key* under the base directory."""
        return os.path.join(self.base_dir, key)

    def create_file(self, file_path: str) -> bool:
        """Create an empty file (and its parent directories) if it doesn't exist.

        Returns:
            True on success, False on error (e.g. empty path).
        """
        try:
            parent = os.path.dirname(file_path)
            # BUG FIX: a bare filename has no parent component; the previous
            # unconditional os.makedirs("") raised and made create_file fail
            # for files in the current directory.
            if parent:
                os.makedirs(parent, exist_ok=True)
            if not os.path.exists(file_path):
                with open(file_path, "wb"):
                    pass  # Create empty file
            return True
        except Exception as e:
            logger.error(f"Failed to create file {file_path}: {e}")
            return False

    def open_file(self, file_path: str) -> Optional[int]:
        """Open *file_path* read-write and return its file descriptor (None on error)."""
        try:
            return os.open(file_path, os.O_RDWR)
        except Exception as e:
            logger.error(f"Failed to open file {file_path}: {e}")
            return None

    def close_file(self, fd: int) -> bool:
        """Close a file descriptor; True on success."""
        try:
            os.close(fd)
            return True
        except Exception as e:
            logger.error(f"Failed to close file descriptor {fd}: {e}")
            return False

    def files_to_nixl_tuples(
        self, file_paths: List[str], open_file: bool = True
    ) -> List[Tuple[int, int, int, str]]:
        """Create NIXL tuples (offset, length, fd, file_path) for given files.

        When *open_file* is False the fd slot is 0 and no files are opened.
        On any open failure, all fds opened so far are closed and [] returned.
        """
        if not open_file:
            return [(0, 0, 0, path) for path in file_paths]

        tuples = []
        for path in file_paths:
            if (fd := self.open_file(path)) is None:
                # Clean up fds opened so far before reporting failure
                for t in tuples:
                    self.close_file(t[2])
                return []
            tuples.append((0, 0, fd, path))
        return tuples
@@ -0,0 +1,216 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import unittest
5
+ from typing import List, Optional
6
+ from unittest.mock import MagicMock
7
+
8
+ import torch
9
+
10
+ from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
11
+ from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
12
+
13
+
14
class TestNixlUnified(unittest.TestCase):
    """Unified test suite for all NIXL components."""

    # Covers three layers: NixlFileManager (real filesystem ops),
    # NixlRegistration (against a mocked NIXL agent), and HiCacheNixl
    # end-to-end set/get (skipped when the NIXL runtime is absent).

    def setUp(self):
        """Set up test environment."""
        # Create test directories
        self.test_dir = "/tmp/test_nixl_unified"
        os.makedirs(self.test_dir, exist_ok=True)

        # Mock NIXL agent for registration tests
        # NOTE(review): sentinel strings stand in for opaque NIXL handles;
        # registration code only checks them for truthiness/None.
        self.mock_agent = MagicMock()
        self.mock_agent.get_reg_descs.return_value = "mock_reg_descs"
        self.mock_agent.register_memory.return_value = "mock_registered_memory"

        # Create instances
        self.file_manager = NixlFileManager(self.test_dir)
        self.registration = NixlRegistration(self.mock_agent)
        try:
            # presumably HiCacheNixl raises ImportError when the NIXL
            # runtime is not installed — verify against its constructor
            self.hicache = HiCacheNixl(file_path=self.test_dir, plugin="POSIX")
        except ImportError:
            self.skipTest("NIXL not available, skipping NIXL storage tests")

    def tearDown(self):
        """Clean up test directories."""
        if os.path.exists(self.test_dir):
            import shutil

            shutil.rmtree(self.test_dir)

    def delete_test_file(self, file_path: str) -> bool:
        """Helper method to delete a test file.

        Args:
            file_path: Path to the file to delete

        Returns:
            bool: True if file was deleted or didn't exist, False on error
        """
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
            return True
        except Exception as e:
            return False

    def verify_tensors_equal(self, expected: torch.Tensor, actual: torch.Tensor):
        """Helper to verify tensor equality."""
        self.assertIsNotNone(actual, "Retrieved tensor is None")
        self.assertTrue(
            torch.allclose(expected, actual, atol=1e-6),
            f"Tensors not equal:\nExpected: {expected}\nActual: {actual}",
        )

    def verify_tensor_lists_equal(
        self, expected: List[torch.Tensor], actual: List[torch.Tensor]
    ):
        """Helper to verify lists of tensors are equal."""
        self.assertEqual(len(expected), len(actual), "Lists have different lengths")
        for exp, act in zip(expected, actual):
            self.verify_tensors_equal(exp, act)

    # ============================================================================
    # HiCache Integration Tests
    # ============================================================================

    def test_single_set_get(self):
        """Test single tensor set/get operations."""
        key = "test_key"
        value = torch.randn(10, 10, device="cpu")
        dst_tensor = torch.zeros_like(value, device="cpu")

        # Test set
        self.assertTrue(self.hicache.set(key, value))
        self.assertTrue(self.hicache.exists(key))

        # Test get
        retrieved = self.hicache.get(key, dst_tensor)
        self.verify_tensors_equal(value, retrieved)

    def test_batch_set_get(self):
        """Test batch tensor set/get operations."""
        keys = ["key1", "key2", "key3"]
        values = [
            torch.randn(5, 5, device="cpu"),
            torch.randn(3, 3, device="cpu"),
            torch.randn(7, 7, device="cpu"),
        ]
        dst_tensors = [torch.zeros_like(v, device="cpu") for v in values]

        # Test batch set
        self.assertTrue(self.hicache.batch_set(keys, values))
        self.assertTrue(all(self.hicache.exists(key) for key in keys))

        # Test batch get
        retrieved = self.hicache.batch_get(keys, dst_tensors)
        self.verify_tensor_lists_equal(values, retrieved)

    def test_mixed_operations(self):
        """Test mixing single and batch operations."""
        # Test interleaved set/get operations
        key1, key2 = "key1", "key2"
        value1 = torch.randn(4, 4, device="cpu")
        value2 = torch.randn(6, 6, device="cpu")
        dst1 = torch.zeros_like(value1)
        dst2 = torch.zeros_like(value2)

        # Single set/get
        self.assertTrue(self.hicache.set(key1, value1))
        retrieved1 = self.hicache.get(key1, dst1)
        self.verify_tensors_equal(value1, retrieved1)

        # Batch set/get
        self.assertTrue(self.hicache.batch_set([key2], [value2]))
        retrieved2 = self.hicache.batch_get([key2], [dst2])
        self.verify_tensors_equal(value2, retrieved2[0])

    def test_data_integrity(self):
        """Test data integrity across operations."""
        # Test with various tensor types and sizes
        test_cases = [
            ("float32", torch.randn(10, 10, dtype=torch.float32)),
            ("float64", torch.randn(5, 5, dtype=torch.float64)),
            ("int32", torch.randint(-100, 100, (8, 8), dtype=torch.int32)),
            ("int64", torch.randint(-100, 100, (6, 6), dtype=torch.int64)),
            ("bool", torch.randint(0, 2, (4, 4)).bool()),
        ]

        for name, tensor in test_cases:
            with self.subTest(tensor_type=name):
                key = f"test_{name}"
                dst_tensor = torch.zeros_like(tensor)

                # Set and immediately get
                self.assertTrue(self.hicache.set(key, tensor))
                retrieved1 = self.hicache.get(key, dst_tensor)
                self.verify_tensors_equal(tensor, retrieved1)

                # Get again to verify persistence
                dst_tensor.zero_()
                retrieved2 = self.hicache.get(key, dst_tensor)
                self.verify_tensors_equal(tensor, retrieved2)

    def test_basic_file_operations(self):
        """Test basic file operations."""
        test_file = os.path.join(self.test_dir, "test_file.bin")
        self.file_manager.create_file(test_file)
        self.assertTrue(os.path.exists(test_file))
        self.assertEqual(os.path.getsize(test_file), 0)  # Empty file

        # Test file deletion
        self.assertTrue(self.delete_test_file(test_file))
        self.assertFalse(os.path.exists(test_file))

    def test_create_nixl_tuples(self):
        """Test creation of NIXL tuples."""
        test_file = os.path.join(self.test_dir, "test_file.bin")
        self.file_manager.create_file(test_file)

        # Test tuple creation
        tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
        self.assertIsNotNone(tuples)
        self.assertTrue(len(tuples) > 0)

    def test_error_handling(self):
        """Test error handling in file operations."""
        # Test non-existent file
        self.assertTrue(
            self.delete_test_file("nonexistent_file.bin")
        )  # Returns True if file doesn't exist

        # Test invalid file path
        self.assertFalse(self.file_manager.create_file(""))  # Empty path should fail

    def test_register_buffers(self):
        """Test registration of memory buffers."""
        # Create test tensor
        tensor = torch.randn(10, 10)

        # Test buffer registration
        self.assertIsNotNone(self.registration.register_buffers(tensor))

        # Test batch registration
        tensors = [torch.randn(5, 5) for _ in range(3)]
        self.assertIsNotNone(self.registration.register_buffers(tensors))

    def test_register_files_with_tuples(self):
        """Test registration of files using NIXL tuples."""
        files = [os.path.join(self.test_dir, f"test_file_{i}.bin") for i in range(3)]
        for file in files:
            self.file_manager.create_file(file)

        # Create tuples and register
        # (open_file=False: registration here only needs paths, not live fds)
        tuples = self.file_manager.files_to_nixl_tuples(files, False)
        self.registration.register_files(tuples)

        # Verify tuples
        self.assertEqual(len(tuples), len(files))
        for t, f in zip(tuples, files):
            self.assertEqual(t[3], f)  # Check file path
213
+
214
+
215
# Allow running this test module directly: `python test_hicache_nixl_storage.py`.
if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,183 @@
1
+ import logging
2
+ import multiprocessing
3
+ import os
4
+ import threading
5
+ from functools import wraps
6
+ from pathlib import Path
7
+ from typing import List
8
+
9
+ import torch
10
+ from torch.utils.cpp_extension import load
11
+
12
+ root = Path(__file__).parent.resolve()
13
+ hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ HF3FS_AVAILABLE = True
18
+ try:
19
+ from hf3fs_fuse.io import (
20
+ deregister_fd,
21
+ extract_mount_point,
22
+ make_ioring,
23
+ make_iovec,
24
+ register_fd,
25
+ )
26
+ except ImportError:
27
+ HF3FS_AVAILABLE = False
28
+
29
+
30
def rsynchronized():
    """Decorator factory: serialize a method under the instance's ``rlock``."""

    def _decorator(func):
        @wraps(func)
        def _locked(self, *args, **kwargs):
            # Hold the per-instance read lock for the duration of the call.
            self.rlock.acquire()
            try:
                return func(self, *args, **kwargs)
            finally:
                self.rlock.release()

        return _locked

    return _decorator
40
+
41
+
42
def wsynchronized():
    """Decorator factory: serialize a method under the instance's ``wlock``."""

    def _decorator(func):
        @wraps(func)
        def _locked(self, *args, **kwargs):
            # Hold the per-instance write lock for the duration of the call.
            self.wlock.acquire()
            try:
                return func(self, *args, **kwargs)
            finally:
                self.wlock.release()

        return _locked

    return _decorator
52
+
53
+
54
class Hf3fsClient:
    """Synchronized batch I/O client for a file on an HF3FS mount.

    Data is staged through two shared-memory regions (one for reads, one for
    writes), each sized ``entries * bytes_per_page`` bytes, and moved with
    pre-built io rings. ``batch_read``/``batch_write`` are serialized by
    per-direction re-entrant locks.
    """

    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
        """Open and size the backing file; build io rings, iovecs, and locks.

        Args:
            path: Backing file path on an HF3FS mount.
            size: Backing file size in bytes (file is truncated to this).
            bytes_per_page: Maximum bytes per single operation.
            entries: Maximum operations per batch (also the shm page count).

        Raises:
            ImportError: If the hf3fs_fuse package is not installed.
        """
        if not HF3FS_AVAILABLE:
            raise ImportError(
                "hf3fs_fuse.io is not available. Please install the hf3fs_fuse package."
            )

        self.path = path
        self.size = size
        self.bytes_per_page = bytes_per_page
        self.entries = entries

        self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
        os.ftruncate(self.file, size)
        register_fd(self.file)

        self.hf3fs_mount_point = extract_mount_point(path)
        self.bs = self.bytes_per_page
        # One staging page per batch entry, for each direction.
        self.shm_r = multiprocessing.shared_memory.SharedMemory(
            size=self.bs * self.entries, create=True
        )
        self.shm_w = multiprocessing.shared_memory.SharedMemory(
            size=self.bs * self.entries, create=True
        )

        # Zero-copy tensor views over the shared-memory staging buffers.
        self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8)
        self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8)

        self.numa = -1  # -1: no NUMA affinity
        self.ior_r = make_ioring(
            self.hf3fs_mount_point,
            self.entries,
            for_read=True,
            timeout=1,
            numa=self.numa,
        )
        self.ior_w = make_ioring(
            self.hf3fs_mount_point,
            self.entries,
            for_read=False,
            timeout=1,
            numa=self.numa,
        )
        self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
        self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)

        self.rlock = threading.RLock()
        self.wlock = threading.RLock()

    @rsynchronized()
    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        """Read one file region per (offset, tensor) pair; return per-op results."""
        self.check(offsets, tensors)

        # prepare: queue one read per tensor into the read staging area
        current = 0
        for offset, tensor in zip(offsets, tensors):
            size = tensor.numel() * tensor.itemsize
            self.ior_r.prepare(
                self.iov_r[current : current + size], True, self.file, offset
            )
            current += size

        # submit and wait for the whole batch
        ionum = len(offsets)
        resv = self.ior_r.submit().wait(min_results=ionum)

        # results: scatter staged bytes into the destination tensors
        hf3fs_utils.read_shm(self.shm_r_tensor, tensors)
        results = [res.result for res in resv]

        return results

    @wsynchronized()
    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        """Write one file region per (offset, tensor) pair; return per-op results."""
        self.check(offsets, tensors)

        # prepare: gather tensor bytes into the write staging area
        hf3fs_utils.write_shm(tensors, self.shm_w_tensor)
        current = 0
        for offset, tensor in zip(offsets, tensors):
            size = tensor.numel() * tensor.itemsize
            self.ior_w.prepare(
                self.iov_w[current : current + size], False, self.file, offset
            )
            current += size

        # submit and wait for the whole batch
        ionum = len(offsets)
        resv = self.ior_w.submit().wait(min_results=ionum)

        # results
        results = [res.result for res in resv]

        return results

    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
        """Validate a batch; on any violation close the client and raise.

        Raises:
            ValueError: If the batch exceeds ``entries``, lengths mismatch,
                any region falls outside the file, or any single op exceeds
                ``bytes_per_page``.
        """
        sizes = [t.numel() * t.itemsize for t in tensors]
        if any(
            [
                len(offsets) > self.entries,
                len(offsets) != len(sizes),
                # BUG FIX: the two conditions below used all(), which only
                # rejected a batch when EVERY op was invalid — a single
                # out-of-range offset or oversized op slipped through.
                any(
                    [
                        offset < 0 or offset + size > self.size
                        for offset, size in zip(offsets, sizes)
                    ]
                ),
                any([size > self.bytes_per_page for size in sizes]),
            ]
        ):
            self.close()
            raise ValueError(f"Hf3fsClient.check: {offsets=}, {sizes=}")

    def get_size(self) -> int:
        """Return the configured backing-file size in bytes."""
        return self.size

    def close(self) -> None:
        """Release fd registration, io rings, iovecs, and shared memory."""
        deregister_fd(self.file)
        os.close(self.file)
        del self.ior_r
        del self.ior_w
        del self.iov_r
        del self.iov_w
        self.shm_r.close()
        self.shm_w.close()
        self.shm_r.unlink()
        self.shm_w.unlink()

    def flush(self) -> None:
        """Force file contents to stable storage."""
        os.fsync(self.file)