TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import sys
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from unittest.mock import AsyncMock, Mock, patch
|
|
18
|
+
|
|
19
|
+
import pytest
|
|
20
|
+
import pytest_asyncio
|
|
21
|
+
import torch
|
|
22
|
+
import zmq
|
|
23
|
+
from tensordict import TensorDict
|
|
24
|
+
|
|
25
|
+
# Setup path
|
|
26
|
+
parent_dir = Path(__file__).resolve().parent.parent
|
|
27
|
+
sys.path.append(str(parent_dir))
|
|
28
|
+
|
|
29
|
+
from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta # noqa: E402
|
|
30
|
+
from transfer_queue.storage import AsyncSimpleStorageManager # noqa: E402
|
|
31
|
+
from transfer_queue.utils.utils import TransferQueueRole # noqa: E402
|
|
32
|
+
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo # noqa: E402
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest_asyncio.fixture
async def mock_async_storage_manager():
    """Provide an AsyncSimpleStorageManager that never touches ZMQ.

    The instance is allocated with ``__new__`` so ``__init__`` is skipped,
    and the attributes the tests rely on are filled in by hand. Two storage
    units (``storage_0`` and ``storage_1``) are configured, and the index
    mappings are simple round-robin lambdas:
    ``global_index -> unit[global_index % 2]`` and
    ``global_index -> global_index // 2``.

    ``TransferQueueStorageManager._connect_to_controller`` stays patched for
    the whole test, because the ``yield`` sits inside the ``with``. The same
    stub is also attached to the instance.
    """
    host = "127.0.0.1"

    # Two storage units, each exposing a single put/get port.
    storage_unit_infos = {
        unit_id: ZMQServerInfo(
            role=TransferQueueRole.STORAGE,
            id=unit_id,
            ip=host,
            ports={"put_get_socket": port},
        )
        for unit_id, port in (("storage_0", 12345), ("storage_1", 12346))
    }

    controller_info = ZMQServerInfo(
        role=TransferQueueRole.CONTROLLER,
        id="controller_0",
        ip=host,
        ports={"handshake_socket": 12347, "data_status_update_socket": 12348},
    )

    config = {"storage_unit_infos": storage_unit_infos, "controller_info": controller_info}

    unit_ids = list(storage_unit_infos)
    num_units = len(unit_ids)

    with patch(
        "transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"
    ) as connect_stub:
        manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager)
        manager.storage_manager_id = "test_storage_manager"
        manager.config = config
        manager.controller_info = controller_info
        manager.storage_unit_infos = storage_unit_infos
        # This hand-built instance owns no sockets or ZMQ context.
        manager.data_status_update_socket = None
        manager.controller_handshake_socket = None
        manager.zmq_context = None
        manager.global_index_storage_unit_mapping = lambda idx: unit_ids[idx % num_units]
        manager.global_index_local_index_mapping = lambda idx: idx // num_units
        manager._connect_to_controller = connect_stub

        yield manager
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@pytest.mark.asyncio
async def test_async_storage_manager_initialization(mock_async_storage_manager):
    """Check the fixture-built manager's registered units and index mappings."""
    manager = mock_async_storage_manager

    # Exactly the two configured storage units must be registered.
    assert set(manager.storage_unit_infos) == {"storage_0", "storage_1"}

    # Round-robin placement: (global_index, expected storage unit).
    for global_index, unit_id in ((0, "storage_0"), (1, "storage_1")):
        assert manager.global_index_storage_unit_mapping(global_index) == unit_id

    # Position inside the unit: (global_index, expected local index).
    for global_index, local_index in ((0, 0), (3, 1)):
        assert manager.global_index_local_index_mapping(global_index) == local_index
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@pytest.mark.asyncio
async def test_async_storage_manager_mock_operations(mock_async_storage_manager):
    """Exercise put/get/clear with the per-storage-unit I/O helpers mocked.

    The following are replaced with ``AsyncMock``s, so no real storage-unit
    traffic happens:

    - ``_put_to_single_storage_unit``
    - ``_get_from_single_storage_unit``
    - ``_clear_single_storage_unit``
    - ``notify_data_update``

    Besides "does not raise", the test checks that each public operation
    actually dispatched to its per-unit helper.
    """
    manager = mock_async_storage_manager

    # Two samples in partition "0". Under the fixture's round-robin mapping,
    # global index 0 maps to storage_0 and global index 1 to storage_1.
    sample_metas = [
        SampleMeta(
            partition_id="0",
            global_index=0,
            fields={
                "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)),
            },
        ),
        SampleMeta(
            partition_id="0",
            global_index=1,
            fields={
                "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)),
            },
        ),
    ]
    batch_meta = BatchMeta(samples=sample_metas)

    # One row per sample.
    test_data = TensorDict(
        {
            "test_field": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])],
        },
        batch_size=2,
    )

    manager._put_to_single_storage_unit = AsyncMock()
    manager._get_from_single_storage_unit = AsyncMock(
        return_value=([0, 1], ["test_field"], {"test_field": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]})
    )
    manager._clear_single_storage_unit = AsyncMock()
    manager.notify_data_update = AsyncMock()

    # put_data must write through the per-unit helper and notify exactly once.
    await manager.put_data(test_data, batch_meta)
    manager._put_to_single_storage_unit.assert_awaited()
    manager.notify_data_update.assert_awaited_once()

    # get_data must read through the per-unit helper.
    retrieved_data = await manager.get_data(batch_meta)
    manager._get_from_single_storage_unit.assert_awaited()
    assert "test_field" in retrieved_data

    # clear_data returns nothing useful. Without this check the test would
    # also pass if clear_data never reached any storage unit.
    await manager.clear_data(batch_meta)
    manager._clear_single_storage_unit.assert_awaited()
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@pytest.mark.asyncio
async def test_async_storage_manager_mapping_functions():
    """Verify round-robin placement of global indices across three storage units.

    A real ``AsyncSimpleStorageManager`` is constructed. Two things are
    patched so the controller handshake is answered by a canned
    ``HANDSHAKE_ACK`` instead of a live peer:

    - ``create_zmq_socket`` in ``transfer_queue.storage.managers.base``
    - ``zmq.Poller``
    """
    port_by_unit = {"storage_0": 12345, "storage_1": 12346, "storage_2": 12347}
    storage_unit_infos = {
        unit_id: ZMQServerInfo(
            role=TransferQueueRole.STORAGE,
            id=unit_id,
            ip="127.0.0.1",
            ports={"put_get_socket": port},
        )
        for unit_id, port in port_by_unit.items()
    }

    controller_info = ZMQServerInfo(
        role=TransferQueueRole.CONTROLLER,
        id="controller_0",
        ip="127.0.0.1",
        ports={"handshake_socket": 12348, "data_status_update_socket": 12349},
    )

    config = {
        "storage_unit_infos": storage_unit_infos,
        "controller_info": controller_info,
    }

    with (
        patch("transfer_queue.storage.managers.base.create_zmq_socket") as create_socket_stub,
        patch("zmq.Poller") as poller_cls_stub,
    ):
        # Every socket the manager creates is this synchronous stand-in.
        fake_socket = Mock()
        fake_socket.connect = Mock()
        fake_socket.send = Mock()
        create_socket_stub.return_value = fake_socket

        # The poller always reports the socket as readable, so the handshake
        # proceeds straight to recv().
        fake_poller = Mock()
        fake_poller.register = Mock()
        fake_poller.poll = Mock(return_value=[(fake_socket, zmq.POLLIN)])
        poller_cls_stub.return_value = fake_poller

        ack = ZMQMessage.create(
            request_type=ZMQRequestType.HANDSHAKE_ACK,
            sender_id="controller_0",
            body={"message": "Handshake successful"},
        )
        fake_socket.recv = Mock(return_value=ack.serialize())

        manager = AsyncSimpleStorageManager(config)

        # Each entry is (global_index, expected storage unit, expected local
        # index). Units cycle every 3 indices; local index = global_index // 3.
        expected_layout = [
            (0, "storage_0", 0),
            (1, "storage_1", 0),
            (2, "storage_2", 0),
            (3, "storage_0", 1),
            (4, "storage_1", 1),
            (5, "storage_2", 1),
        ]
        for global_index, unit_id, _ in expected_layout:
            assert manager.global_index_storage_unit_mapping(global_index) == unit_id
        for global_index, _, local_index in expected_layout:
            assert manager.global_index_local_index_mapping(global_index) == local_index
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@pytest.mark.asyncio
async def test_async_storage_manager_error_handling():
    """Check how put/get/clear surface failures from the per-unit helpers.

    Every per-unit helper is mocked to raise ``RuntimeError``.

    - ``put_data`` must propagate the helper's error.
    - ``get_data`` must propagate the helper's error.
    - ``clear_data`` is expected to swallow it, because it gathers with
      ``return_exceptions=True``. It must still have attempted the clear.
    """
    storage_unit_infos = {
        "storage_0": ZMQServerInfo(
            role=TransferQueueRole.STORAGE,
            id="storage_0",
            ip="127.0.0.1",
            ports={"put_get_socket": 12345},
        ),
    }

    # A single controller, named consistently with the other tests.
    controller_info = ZMQServerInfo(
        role=TransferQueueRole.CONTROLLER,
        id="controller_0",
        ip="127.0.0.1",
        ports={"handshake_socket": 12346, "data_status_update_socket": 12347},
    )

    config = {
        "storage_unit_infos": storage_unit_infos,
        "controller_info": controller_info,
    }

    # Patch socket creation and polling so construction completes the
    # handshake against a canned reply instead of a live controller.
    with (
        patch("transfer_queue.storage.managers.base.create_zmq_socket") as mock_create_socket,
        patch("zmq.Poller") as mock_poller,
    ):
        mock_socket = Mock()
        mock_socket.connect = Mock()  # sync method
        mock_socket.send = Mock()  # sync method
        mock_create_socket.return_value = mock_socket

        mock_poller_instance = Mock()
        mock_poller_instance.register = Mock()  # sync method
        # Report the socket as readable so the handshake goes straight to recv().
        mock_poller_instance.poll = Mock(return_value=[(mock_socket, zmq.POLLIN)])  # sync method
        mock_poller.return_value = mock_poller_instance

        handshake_response = ZMQMessage.create(
            request_type=ZMQRequestType.HANDSHAKE_ACK,
            sender_id="controller_0",
            body={"message": "Handshake successful"},
        )
        mock_socket.recv = Mock(return_value=handshake_response.serialize())

        manager = AsyncSimpleStorageManager(config)

    # Every per-unit operation fails.
    manager._put_to_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock PUT error"))
    manager._get_from_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock GET error"))
    manager._clear_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock CLEAR error"))
    manager.notify_data_update = AsyncMock()

    sample_metas = [
        SampleMeta(
            partition_id="0",
            global_index=0,
            fields={
                "test_field": FieldMeta(name="test_field", dtype=torch.float32, shape=(2,)),
            },
        ),
    ]
    batch_meta = BatchMeta(samples=sample_metas)

    test_data = TensorDict(
        {
            "test_field": [torch.tensor([1.0, 2.0])],
        },
        batch_size=1,
    )

    # put/get must propagate the helper's exception unchanged.
    with pytest.raises(RuntimeError, match="Mock PUT error"):
        await manager.put_data(test_data, batch_meta)

    with pytest.raises(RuntimeError, match="Mock GET error"):
        await manager.get_data(batch_meta)

    # clear_data gathers the per-unit clears with return_exceptions=True, so
    # the mocked error must not propagate. The clear must still have been
    # attempted.
    await manager.clear_data(batch_meta)
    manager._clear_single_storage_unit.assert_awaited()
|