TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff shows the content of a publicly released package version as published to a supported registry. It is provided for informational purposes only and reflects the package exactly as it appears in its public registry.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
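
For orientation before the test sources below, the client-side flow those tests exercise reduces to a few calls. The following is a minimal sketch reconstructed from the calls in tests/test_client.py, not official documentation; `controller_info` and `storage_unit_infos` stand in for the ZMQServerInfo endpoints of an already-running controller and storage units:

    import torch
    from tensordict import TensorDict
    from transfer_queue import TransferQueueClient

    # Connect to an existing controller via its ZMQServerInfo endpoint.
    client = TransferQueueClient(client_id="client_0", controller_info=controller_info)

    # Attach a storage manager; the config carries controller and storage-unit endpoints.
    client.initialize_storage_manager(
        manager_type="AsyncSimpleStorageManager",
        config={"controller_info": controller_info, "storage_unit_infos": storage_unit_infos},
    )

    # Write a batch, ask the controller for metadata, read the batch back, then clear.
    data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)
    client.put(data=data, partition_id="0")
    meta = client.get_meta(data_fields=["tokens"], batch_size=5, partition_id="0")
    batch = client.get_data(meta)
    client.clear(partition_id="0")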
tests/test_client.py
ADDED
@@ -0,0 +1,423 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
from pathlib import Path
from threading import Thread
from unittest.mock import patch

import pytest
import torch
import zmq
from tensordict import NonTensorStack, TensorDict

# Import your classes here
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import TransferQueueClient  # noqa: E402
from transfer_queue.metadata import (  # noqa: E402
    BatchMeta,
    FieldMeta,
    SampleMeta,
)
from transfer_queue.utils.utils import TransferQueueRole  # noqa: E402
from transfer_queue.utils.zmq_utils import (  # noqa: E402
    ZMQMessage,
    ZMQRequestType,
    ZMQServerInfo,
)

TEST_DATA = TensorDict(
    {
        "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])],
        "variable_length_sequences": torch.nested.as_nested_tensor(
            [
                torch.tensor([-0.5, -1.2, -0.8]),
                torch.tensor([-0.3, -1.5, -2.1, -0.9]),
                torch.tensor([-1.1, -0.7]),
            ]
        ),
        "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"],
    },
    batch_size=[3],
)


# Mock Controller for Client Unit Testing
class MockController:
    def __init__(self, controller_id="controller_0"):
        self.controller_id = controller_id
        self.context = zmq.Context()

        # Socket for data requests
        self.request_socket = self.context.socket(zmq.ROUTER)
        self.request_port = self._bind_to_random_port(self.request_socket)

        self.zmq_server_info = ZMQServerInfo(
            role=TransferQueueRole.CONTROLLER,
            id=controller_id,
            ip="127.0.0.1",
            ports={
                "request_handle_socket": self.request_port,
            },
        )

        self.running = True
        self.request_thread = Thread(target=self._handle_requests, daemon=True)
        self.request_thread.start()

    def _bind_to_random_port(self, socket):
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _handle_requests(self):
        poller = zmq.Poller()
        poller.register(self.request_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.request_socket in socks:
                    identity, serialized_msg = self.request_socket.recv_multipart()
                    request_msg = ZMQMessage.deserialize(serialized_msg)

                    # Determine response based on request type
                    if request_msg.request_type == ZMQRequestType.GET_META:
                        response_body = self._mock_batch_meta(request_msg.body)
                        response_type = ZMQRequestType.GET_META_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
                        response_body = self._mock_batch_meta(request_msg.body)
                        response_type = ZMQRequestType.GET_CLEAR_META_RESPONSE
                    elif request_msg.request_type == ZMQRequestType.CLEAR_META:
                        response_body = {"message": "clear ok"}
                        response_type = ZMQRequestType.CLEAR_META_RESPONSE

                    # Send response
                    response_msg = ZMQMessage.create(
                        request_type=response_type,
                        sender_id=self.controller_id,
                        receiver_id=request_msg.sender_id,
                        body=response_body,
                    )
                    self.request_socket.send_multipart([identity, response_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                print(f"MockController ERROR: {e}")
                raise

    def _mock_batch_meta(self, request_body):
        batch_size = request_body.get("batch_size", 1)
        data_fields = request_body.get("data_fields", [])

        samples = []
        for i in range(batch_size):
            fields = []
            for field_name in data_fields:
                field_meta = FieldMeta(
                    name=field_name,
                    dtype=None,
                    shape=None,
                    production_status=0,
                )
                fields.append(field_meta)
            sample = SampleMeta(
                partition_id="0",
                global_index=i,
                fields={field.name: field for field in fields},
            )
            samples.append(sample)
        metadata = BatchMeta(samples=samples)

        return {"metadata": metadata}

    def stop(self):
        self.running = False
        time.sleep(0.2)  # Give thread time to stop
        self.request_socket.close()
        self.context.term()


# Mock Storage for Client Unit Testing
class MockStorage:
    def __init__(self, storage_id="storage_0"):
        self.storage_id = storage_id
        self.context = zmq.Context()

        # Socket for data operations
        self.data_socket = self.context.socket(zmq.ROUTER)
        self.data_port = self._bind_to_random_port(self.data_socket)

        self.zmq_server_info = ZMQServerInfo(
            role=TransferQueueRole.STORAGE,
            id=storage_id,
            ip="127.0.0.1",
            ports={
                "put_get_socket": self.data_port,
            },
        )

        self.running = True
        self.data_thread = Thread(target=self._handle_data_requests, daemon=True)
        self.data_thread.start()

    def _bind_to_random_port(self, socket):
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _handle_data_requests(self):
        poller = zmq.Poller()
        poller.register(self.data_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.data_socket in socks:
                    identity, msg_bytes = self.data_socket.recv_multipart()
                    msg = ZMQMessage.deserialize(msg_bytes)

                    # Handle different request types
                    if msg.request_type == ZMQRequestType.PUT_DATA:
                        response_body = {"message": "Data stored successfully"}
                        response_type = ZMQRequestType.PUT_DATA_RESPONSE
                    elif msg.request_type == ZMQRequestType.GET_DATA:
                        response_body = self._handle_get_data(msg.body)
                        response_type = ZMQRequestType.GET_DATA_RESPONSE
                    elif msg.request_type == ZMQRequestType.CLEAR_DATA:
                        response_body = {"message": "Data cleared successfully"}
                        response_type = ZMQRequestType.CLEAR_DATA_RESPONSE

                    # Send response
                    response_msg = ZMQMessage.create(
                        request_type=response_type,
                        sender_id=self.storage_id,
                        receiver_id=msg.sender_id,
                        body=response_body,
                    )
                    self.data_socket.send_multipart([identity, response_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                if self.running:
                    print(f"MockStorage running exception: {e}")
                else:
                    print(f"MockStorage ERROR: {e}")
                    raise

    def _handle_get_data(self, request_body):
        """Handle GET_DATA request by retrieving stored data"""
        local_indexes = request_body.get("local_indexes", [])
        fields = request_body.get("fields", [])

        result: dict[str, list] = {}
        for field in fields:
            gathered_items = [TEST_DATA[field][i] for i in local_indexes]

            if gathered_items:
                all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items)
                if all_tensors:
                    result[field] = torch.nested.as_nested_tensor(gathered_items)
                else:
                    result[field] = NonTensorStack(*gathered_items)

        return {"data": TensorDict(result)}

    def stop(self):
        self.running = False
        time.sleep(0.2)  # Give thread time to stop
        self.data_socket.close()
        self.context.term()


# Test Fixtures
@pytest.fixture
def mock_controller():
    controller = MockController()
    yield controller
    controller.stop()


@pytest.fixture
def mock_storage():
    storage = MockStorage()
    yield storage
    storage.stop()


@pytest.fixture
def client_setup(mock_controller, mock_storage):
    # Create client with mock controller and storage
    client_id = "client_0"

    client = TransferQueueClient(
        client_id=client_id,
        controller_info=mock_controller.zmq_server_info,
    )

    # Mock the storage manager to avoid handshake issues but mock all data operations
    with patch(
        "transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller"
    ):
        config = {
            "controller_info": mock_controller.zmq_server_info,
            "storage_unit_infos": {mock_storage.storage_id: mock_storage.zmq_server_info},
        }
        client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)

    # Mock all storage manager methods to avoid real ZMQ operations
    async def mock_put_data(data, metadata):
        pass  # Just pretend to store the data

    async def mock_get_data(metadata):
        # Return the test data when requested
        return TEST_DATA

    async def mock_clear_data(metadata):
        pass  # Just pretend to clear the data

    client.storage_manager.put_data = mock_put_data
    client.storage_manager.get_data = mock_get_data
    client.storage_manager.clear_data = mock_clear_data

    yield client, mock_controller, mock_storage


# Test basic functionality
def test_client_initialization(client_setup):
    """Test client initialization and connection setup"""
    client, mock_controller, mock_storage = client_setup

    assert client.client_id is not None
    assert client._controller is not None
    assert client._controller.id == mock_controller.controller_id


def test_put_and_get_data(client_setup):
    """Test basic put and get operations"""
    client, _, _ = client_setup

    # Test put operation
    client.put(data=TEST_DATA, partition_id="0")

    # Get metadata for retrieving data
    metadata = client.get_meta(
        data_fields=["log_probs", "variable_length_sequences", "prompt_text"], batch_size=2, partition_id="0"
    )

    # Test get operation
    result = client.get_data(metadata)

    # Verify result structure
    assert "log_probs" in result
    assert "variable_length_sequences" in result
    assert "prompt_text" in result

    torch.testing.assert_close(result["log_probs"][0], torch.tensor([1.0, 2.0, 3.0]))
    torch.testing.assert_close(result["log_probs"][1], torch.tensor([4.0, 5.0, 6.0]))
    torch.testing.assert_close(result["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8]))
    torch.testing.assert_close(result["variable_length_sequences"][1], torch.tensor([-0.3, -1.5, -2.1, -0.9]))
    assert result["prompt_text"][0] == "Hello world!"
    assert result["prompt_text"][1] == "This is a longer sentence for testing"


def test_get_meta(client_setup):
    """Test metadata retrieval"""
    client, _, _ = client_setup

    # Test get_meta operation
    metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=10, partition_id="0")

    # Verify metadata structure
    assert hasattr(metadata, "global_indexes")
    assert hasattr(metadata, "field_names")
    assert hasattr(metadata, "size")
    assert len(metadata.global_indexes) == 10


def test_clear_operation(client_setup):
    """Test clear operation"""
    client, _, _ = client_setup

    # Test clear operation
    client.clear(partition_id="0")


# Test with single controller and multiple storage units
def test_single_controller_multiple_storages():
    """Test client with single controller and multiple storage units"""
    # Create single controller and multiple storage units
    controller = MockController("controller_0")
    storages = [MockStorage(f"storage_{i}") for i in range(3)]

    try:
        # Create client with single controller
        client_id = "client_test_single_controller"

        client = TransferQueueClient(client_id=client_id, controller_info=controller.zmq_server_info)

        # Mock the storage manager to avoid handshake issues but mock all data operations
        with patch(
            "transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller"
        ):
            config = {
                "controller_info": controller.zmq_server_info,
                "storage_unit_infos": {s.storage_id: s.zmq_server_info for s in storages},
            }
            client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)

        # Mock all storage manager methods to avoid real ZMQ operations
        async def mock_put_data(data, metadata):
            pass  # Just pretend to store the data

        async def mock_get_data(metadata):
            # Return some test data when requested
            return TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)

        async def mock_clear_data(metadata):
            pass  # Just pretend to clear the data

        client.storage_manager.put_data = mock_put_data
        client.storage_manager.get_data = mock_get_data
        client.storage_manager.clear_data = mock_clear_data

        # Verify controller is set
        assert client._controller is not None
        assert client._controller.id == controller.controller_id

        # Test basic operation
        test_data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)

        # Test put operation
        client.put(data=test_data, partition_id="0")

    finally:
        # Clean up
        controller.stop()
        for s in storages:
            s.stop()


# Test error handling
def test_put_without_required_params(client_setup):
    """Test put operation without required parameters"""
    client, _, _ = client_setup

    # Create test data
    test_data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)

    # Test put without partition id (should fail)
    with pytest.raises(ValueError):
        client.put(data=test_data)
tests/test_controller.py
ADDED
@@ -0,0 +1,274 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys
from pathlib import Path

import pytest
import ray
import torch

parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

from transfer_queue import TransferQueueController  # noqa: E402
from transfer_queue.controller import TQ_INIT_FIELD_NUM  # noqa: E402
from transfer_queue.utils.utils import ProductionStatus  # noqa: E402


@pytest.fixture(scope="function")
def ray_setup():
    if ray.is_initialized():
        ray.shutdown()
    ray.init(
        ignore_reinit_error=True,
        runtime_env={"env_vars": {"RAY_DEBUG": "1", "RAY_DEDUP_LOGS": "0"}},
        log_to_driver=True,
    )
    yield
    if ray.is_initialized():
        ray.shutdown()
        logger.info("Ray has been shut down completely after test")


class TestTransferQueueController:
    def test_controller_with_single_partition(self, ray_setup):
        gbs = 8
        num_n_samples = 4

        tq_controller = TransferQueueController.remote()

        # Test get metadata in insert mode
        partition_id = "train_0"
        data_fields = ["prompt_ids", "attention_mask"]
        metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=gbs * num_n_samples,
                partition_id=partition_id,
                mode="insert",
            )
        )

        assert metadata.global_indexes == list(range(gbs * num_n_samples))
        assert metadata.samples[0].partition_id == "train_0"
        assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in metadata.samples]) == int(
            ProductionStatus.NOT_PRODUCED
        )
        assert sum([int(sample.fields.get("attention_mask").production_status) for sample in metadata.samples]) == int(
            ProductionStatus.NOT_PRODUCED
        )
        partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
        assert partition_index_range == set(range(gbs * num_n_samples))

        print("✓ Initial get metadata correct")

        # Test update production status
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id,
                global_indexes=metadata.global_indexes,
                field_names=metadata.field_names,
            )
        )
        assert success
        partition = ray.get(tq_controller.get_partition.remote(partition_id))
        assert partition.production_status is not None
        assert partition.production_status.size(0) == gbs * num_n_samples
        assert partition.production_status.size(1) == TQ_INIT_FIELD_NUM
        assert torch.equal(
            sum(partition.production_status[:, : len(data_fields)]),
            torch.Tensor([gbs * num_n_samples, gbs * num_n_samples]),
        )
        assert torch.equal(
            sum(partition.production_status[:, len(data_fields) :]),
            torch.zeros(1 * (TQ_INIT_FIELD_NUM - len(data_fields))),
        )

        print(f"✓ Updated production status for partition {partition_id}")

        # Test get metadata in fetch mode
        gen_meta = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=["prompt_ids"],
                batch_size=gbs * num_n_samples,
                partition_id=partition_id,
                mode="fetch",
                task_name="generate_sequences",
            )
        )
        assert gen_meta.global_indexes == list(range(gbs * num_n_samples))
        assert gen_meta.samples[0].partition_id == "train_0"
        assert gen_meta.field_names == ["prompt_ids"]
        partition = ray.get(tq_controller.get_partition.remote(partition_id))
        assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(gbs * num_n_samples))
        print("✓ Get metadata in fetch mode correct")

        # Test get clear meta
        clear_meta = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=[],
                partition_id=partition_id,
                mode="insert",
            )
        )
        assert clear_meta.global_indexes == list(range(gbs * num_n_samples))
        assert [sample.fields for sample in clear_meta.samples] == [{}] * (gbs * num_n_samples)
        print("✓ Clear metadata correct")

        # Test clear
        ray.get(tq_controller.clear.remote(partition_id))
        partition = ray.get(tq_controller.get_partition.remote(partition_id))
        partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
        assert partition_index_range == set()
        assert torch.all(partition.production_status == 0)
        assert torch.all(partition.consumption_status["generate_sequences"] == 0)
        print("✓ Clear correct")

    def test_controller_with_multi_partitions(self, ray_setup):
        gbs_1 = 8
        num_n_samples_1 = 4
        partition_id_1 = "train_0"

        gbs_2 = 16
        num_n_samples_2 = 1
        partition_id_2 = "val_0"

        gbs_3 = 32
        num_n_samples_3 = 2
        partition_id_3 = "train_1"

        tq_controller = TransferQueueController.remote()

        # Test get metadata in insert mode
        data_fields = ["prompt_ids", "attention_mask"]
        metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=gbs_1 * num_n_samples_1,
                partition_id=partition_id_1,
                mode="insert",
            )
        )

        # Test update production status
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id_1,
                global_indexes=metadata.global_indexes,
                field_names=metadata.field_names,
            )
        )
        assert success

        # Test get metadata in fetch mode
        gen_meta = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=["prompt_ids"],
                batch_size=gbs_1 * num_n_samples_1,
                partition_id=partition_id_1,
                mode="fetch",
                task_name="generate_sequences",
            )
        )
        assert gen_meta

        # Test get clear meta
        clear_meta = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=[],
                partition_id=partition_id_1,
                mode="insert",
            )
        )
        assert clear_meta

        # =========================partition 2=============================#
        data_fields = ["prompt_ids", "attention_mask"]
        val_metadata = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=gbs_2 * num_n_samples_2,
                partition_id=partition_id_2,
                mode="insert",
            )
        )

        part1_index_range = gbs_1 * num_n_samples_1
        part2_index_range = gbs_2 * num_n_samples_2
        assert val_metadata.global_indexes == list(range(part1_index_range, part2_index_range + part1_index_range))
        assert val_metadata.samples[0].partition_id == "val_0"
        assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in val_metadata.samples]) == int(
            ProductionStatus.NOT_PRODUCED
        )
        assert sum(
            [int(sample.fields.get("attention_mask").production_status) for sample in val_metadata.samples]
        ) == int(ProductionStatus.NOT_PRODUCED)
        partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
        assert partition_index_range == set(range(part1_index_range, part2_index_range + part1_index_range))

        # Update production status
        success = ray.get(
            tq_controller.update_production_status.remote(
                partition_id=partition_id_2,
                global_indexes=val_metadata.global_indexes,
                field_names=val_metadata.field_names,
            )
        )
        assert success

        # Clear partition 1
        partition_index_range_1 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
        assert partition_index_range_1
        ray.get(tq_controller.clear.remote(partition_id_1))
        partition_1_after_clear = ray.get(tq_controller.get_partition.remote(partition_id_1))
        partition_index_range_1_after_clear = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))

        assert not partition_index_range_1_after_clear
        assert torch.all(partition_1_after_clear.production_status[list(partition_index_range_1), :] == 0)
        assert torch.all(partition_1_after_clear.consumption_status["generate_sequences"] == 0)

        partition_2 = ray.get(tq_controller.get_partition.remote(partition_id_2))
        partition_index_range_2 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
        assert partition_index_range_2 == set([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
        assert torch.all(
            partition_2.production_status[list(partition_index_range_2), : len(val_metadata.field_names)] == 1
        )
        print("✓ Only clear partition 1 correct")

        # =========================partition 3=============================#
        metadata_2 = ray.get(
            tq_controller.get_metadata.remote(
                data_fields=data_fields,
                batch_size=gbs_3 * num_n_samples_3,
                partition_id=partition_id_3,
                mode="insert",
            )
        )
        assert metadata_2.global_indexes == list(range(32)) + list(range(48, 80))
        assert metadata_2.samples[0].partition_id == "train_1"
        assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in metadata_2.samples]) == int(
            ProductionStatus.NOT_PRODUCED
        )
        assert sum(
            [int(sample.fields.get("attention_mask").production_status) for sample in metadata_2.samples]
        ) == int(ProductionStatus.NOT_PRODUCED)
        partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3))
        assert partition_index_range == set(list(range(32)) + list(range(48, 80)))
        print("✓ Correctly assign partition_3")