TransferQueue 0.1.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
@@ -0,0 +1,339 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import sys
16
+ from pathlib import Path
17
+ from unittest.mock import AsyncMock, Mock, patch
18
+
19
+ import pytest
20
+ import pytest_asyncio
21
+ import torch
22
+ import zmq
23
+ from tensordict import TensorDict
24
+
25
+ # Setup path
26
+ parent_dir = Path(__file__).resolve().parent.parent
27
+ sys.path.append(str(parent_dir))
28
+
29
+ from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402
30
+ from transfer_queue.storage import AsyncSimpleStorageManager # noqa: E402
31
+ from transfer_queue.utils.utils import TransferQueueRole # noqa: E402
32
+ from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo # noqa: E402
33
+
34
+
35
@pytest_asyncio.fixture
async def mock_async_storage_manager():
    """Yield an AsyncSimpleStorageManager assembled without calling ``__init__``.

    The instance is allocated with ``__new__`` and every attribute is assigned
    by hand; the socket and context attributes are set to ``None``. Two storage
    units are registered and the index-mapping callables distribute global
    indexes across them round-robin.
    """
    # Storage unit descriptors keyed by unit id.
    storage_unit_infos = {
        unit_id: ZMQServerInfo(
            role=TransferQueueRole.STORAGE,
            id=unit_id,
            ip="127.0.0.1",
            ports={"put_get_socket": port},
        )
        for unit_id, port in (("storage_0", 12345), ("storage_1", 12346))
    }

    # Controller descriptor.
    controller_info = ZMQServerInfo(
        role=TransferQueueRole.CONTROLLER,
        id="controller_0",
        ip="127.0.0.1",
        ports={"handshake_socket": 12347, "data_status_update_socket": 12348},
    )

    config = {
        "storage_unit_infos": storage_unit_infos,
        "controller_info": controller_info,
    }

    # Replace the controller handshake for the lifetime of the fixture.
    handshake_target = "transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"
    with patch(handshake_target) as mock_connect:
        # Allocate the manager without running its constructor.
        manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager)
        manager.storage_manager_id = "test_storage_manager"
        manager.config = config
        manager.controller_info = controller_info
        manager.storage_unit_infos = storage_unit_infos
        manager.data_status_update_socket = None
        manager.controller_handshake_socket = None
        manager.zmq_context = None

        # Round-robin mapping helpers over the registered unit ids.
        unit_ids = list(storage_unit_infos)
        unit_count = len(unit_ids)

        def unit_for(global_index):
            return unit_ids[global_index % unit_count]

        def local_index_for(global_index):
            return global_index // unit_count

        manager.global_index_storage_unit_mapping = unit_for
        manager.global_index_local_index_mapping = local_index_for

        # Expose the patched handshake on the instance as well.
        manager._connect_to_controller = mock_connect

        yield manager
91
+
92
+
93
@pytest.mark.asyncio
async def test_async_storage_manager_initialization(mock_async_storage_manager):
    """Check the fixture's storage-unit registry and its index-mapping helpers."""
    manager = mock_async_storage_manager

    # Registry contents.
    units = manager.storage_unit_infos
    assert len(units) == 2
    for unit_id in ("storage_0", "storage_1"):
        assert unit_id in units

    # global index -> storage unit id
    for global_index, unit_id in ((0, "storage_0"), (1, "storage_1")):
        assert manager.global_index_storage_unit_mapping(global_index) == unit_id

    # global index -> local index inside the unit
    for global_index, local_index in ((0, 0), (3, 1)):
        assert manager.global_index_local_index_mapping(global_index) == local_index
108
+
109
+
110
@pytest.mark.asyncio
async def test_async_storage_manager_mock_operations(mock_async_storage_manager):
    """Run put/get/clear on the manager with the per-storage-unit calls mocked out."""
    manager = mock_async_storage_manager

    def sample_tensors():
        # Fresh tensor objects on every call, as in the hand-written literals.
        return [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]

    # Metadata for two samples, global indexes 0 and 1, same single field.
    batch_meta = BatchMeta(
        samples=[
            SampleMeta(
                partition_id="0",
                global_index=global_index,
                fields={
                    "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)),
                },
            )
            for global_index in range(2)
        ]
    )

    # Payload matching the metadata above.
    test_data = TensorDict({"test_field": sample_tensors()}, batch_size=2)

    # Stub every call that would reach a storage unit or the controller.
    canned_get_result = ([0, 1], ["test_field"], {"test_field": sample_tensors()})
    manager._put_to_single_storage_unit = AsyncMock()
    manager._get_from_single_storage_unit = AsyncMock(return_value=canned_get_result)
    manager._clear_single_storage_unit = AsyncMock()
    manager.notify_data_update = AsyncMock()

    # put_data must complete and notify the controller exactly once.
    await manager.put_data(test_data, batch_meta)
    manager.notify_data_update.assert_awaited_once()

    # get_data must return something containing the requested field.
    retrieved_data = await manager.get_data(batch_meta)
    assert "test_field" in retrieved_data

    # clear_data must complete without raising.
    await manager.clear_data(batch_meta)
159
+
160
+
161
@pytest.mark.asyncio
async def test_async_storage_manager_mapping_functions():
    """Construct a manager over three storage units and check both index mappings."""
    # Storage unit descriptors keyed by unit id.
    storage_unit_infos = {
        unit_id: ZMQServerInfo(
            role=TransferQueueRole.STORAGE,
            id=unit_id,
            ip="127.0.0.1",
            ports={"put_get_socket": port},
        )
        for unit_id, port in (("storage_0", 12345), ("storage_1", 12346), ("storage_2", 12347))
    }

    # Controller descriptor.
    controller_info = ZMQServerInfo(
        role=TransferQueueRole.CONTROLLER,
        id="controller_0",
        ip="127.0.0.1",
        ports={"handshake_socket": 12348, "data_status_update_socket": 12349},
    )

    config = {
        "storage_unit_infos": storage_unit_infos,
        "controller_info": controller_info,
    }

    # Patch socket creation and the poller so construction needs no real ZMQ peer.
    with (
        patch("transfer_queue.storage.managers.base.create_zmq_socket") as mock_create_socket,
        patch("zmq.Poller") as mock_poller,
    ):
        # Serialized acknowledgement the fake socket hands back on recv().
        handshake_response = ZMQMessage.create(
            request_type=ZMQRequestType.HANDSHAKE_ACK,
            sender_id="controller_0",
            body={"message": "Handshake successful"},
        )

        # Fake socket: connect/send/recv are plain synchronous mocks.
        mock_socket = Mock(
            connect=Mock(),
            send=Mock(),
            recv=Mock(return_value=handshake_response.serialize()),
        )
        mock_create_socket.return_value = mock_socket

        # Fake poller: poll() reports the fake socket as readable.
        mock_poller_instance = Mock(
            register=Mock(),
            poll=Mock(return_value=[(mock_socket, zmq.POLLIN)]),
        )
        mock_poller.return_value = mock_poller_instance

        manager = AsyncSimpleStorageManager(config)

        # global index -> storage unit id, round-robin over the three units:
        # 0->storage_0, 1->storage_1, 2->storage_2, 3->storage_0, ...
        expected_units = ["storage_0", "storage_1", "storage_2", "storage_0", "storage_1", "storage_2"]
        for global_index, unit_id in enumerate(expected_units):
            assert manager.global_index_storage_unit_mapping(global_index) == unit_id

        # global index -> local index: global_index // num_storage_units
        expected_local_indexes = [0, 0, 0, 1, 1, 1]
        for global_index, local_index in enumerate(expected_local_indexes):
            assert manager.global_index_local_index_mapping(global_index) == local_index
246
+
247
+
248
@pytest.mark.asyncio
async def test_async_storage_manager_error_handling():
    """Check how put_data, get_data and clear_data react to failing storage-unit calls.

    The per-storage-unit coroutines are replaced with mocks that raise
    ``RuntimeError``. ``put_data`` and ``get_data`` are expected to propagate
    the error; ``clear_data`` is expected to complete without raising while
    still having attempted the per-unit clear.
    """
    # Single storage unit descriptor.
    storage_unit_infos = {
        "storage_0": ZMQServerInfo(
            role=TransferQueueRole.STORAGE,
            id="storage_0",
            ip="127.0.0.1",
            ports={"put_get_socket": 12345},
        ),
    }

    # Controller descriptor.
    controller_info = ZMQServerInfo(
        role=TransferQueueRole.CONTROLLER,
        id="controller_0",
        ip="127.0.0.1",
        ports={"handshake_socket": 12346, "data_status_update_socket": 12347},
    )

    config = {
        "storage_unit_infos": storage_unit_infos,
        "controller_info": controller_info,
    }

    # Patch socket creation and the poller so construction needs no real ZMQ peer.
    with (
        patch("transfer_queue.storage.managers.base.create_zmq_socket") as mock_create_socket,
        patch("zmq.Poller") as mock_poller,
    ):
        # Fake socket with synchronous connect/send.
        mock_socket = Mock()
        mock_socket.connect = Mock()
        mock_socket.send = Mock()
        mock_create_socket.return_value = mock_socket

        # Fake poller whose poll() reports the fake socket as readable.
        mock_poller_instance = Mock()
        mock_poller_instance.register = Mock()
        mock_poller_instance.poll = Mock(return_value=[(mock_socket, zmq.POLLIN)])
        mock_poller.return_value = mock_poller_instance

        # Serialized acknowledgement returned by the fake socket's recv().
        handshake_response = ZMQMessage.create(
            request_type=ZMQRequestType.HANDSHAKE_ACK,
            sender_id="controller_0",
            body={"message": "Handshake successful"},
        )
        mock_socket.recv = Mock(return_value=handshake_response.serialize())

        manager = AsyncSimpleStorageManager(config)

        # Make every per-storage-unit operation fail.
        manager._put_to_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock PUT error"))
        manager._get_from_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock GET error"))
        manager._clear_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock CLEAR error"))
        manager.notify_data_update = AsyncMock()

        # Metadata for one sample with a single field.
        sample_metas = [
            SampleMeta(
                partition_id="0",
                global_index=0,
                fields={
                    "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)),
                },
            ),
        ]
        batch_meta = BatchMeta(samples=sample_metas)

        # Payload matching the metadata above.
        test_data = TensorDict(
            {
                "test_field": [torch.tensor([1.0, 2.0])],
            },
            batch_size=1,
        )

        # put_data and get_data must surface the storage-unit failure.
        with pytest.raises(RuntimeError, match="Mock PUT error"):
            await manager.put_data(test_data, batch_meta)

        with pytest.raises(RuntimeError, match="Mock GET error"):
            await manager.get_data(batch_meta)

        # clear_data is expected not to raise (the original test notes it gathers
        # with return_exceptions=True), so confirm separately that the failing
        # per-unit clear was actually attempted rather than silently skipped.
        await manager.clear_data(batch_meta)
        manager._clear_single_storage_unit.assert_awaited()