TransferQueue 0.0.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +307 -0
- recipe/simple_use_case/sync_demo.py +223 -0
- tests/test_client.py +390 -0
- tests/test_controller.py +268 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +479 -0
- transfer_queue/__init__.py +42 -0
- transfer_queue/client.py +663 -0
- transfer_queue/controller.py +772 -0
- transfer_queue/metadata.py +603 -0
- transfer_queue/storage.py +515 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +98 -0
- transfer_queue/utils/zmq_utils.py +175 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.0.1.dev0.dist-info/METADATA +15 -0
- transferqueue-0.0.1.dev0.dist-info/RECORD +21 -0
- transferqueue-0.0.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.0.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.0.1.dev0.dist-info/top_level.txt +4 -0
tests/test_client.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import sys
|
|
16
|
+
import time
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from threading import Thread
|
|
19
|
+
|
|
20
|
+
import pytest
|
|
21
|
+
import torch
|
|
22
|
+
import zmq
|
|
23
|
+
from tensordict import NonTensorStack, TensorDict
|
|
24
|
+
|
|
25
|
+
# Import your classes here
|
|
26
|
+
parent_dir = Path(__file__).resolve().parent.parent
|
|
27
|
+
sys.path.append(str(parent_dir))
|
|
28
|
+
|
|
29
|
+
from transfer_queue import TransferQueueClient # noqa: E402
|
|
30
|
+
from transfer_queue.metadata import ( # noqa: E402
|
|
31
|
+
BatchMeta,
|
|
32
|
+
FieldMeta,
|
|
33
|
+
SampleMeta,
|
|
34
|
+
)
|
|
35
|
+
from transfer_queue.utils.zmq_utils import ( # noqa: E402
|
|
36
|
+
ZMQMessage,
|
|
37
|
+
ZMQRequestType,
|
|
38
|
+
ZMQServerInfo,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
# Reference batch served by MockStorage: three samples, each with a fixed-length
# tensor field, a variable-length (nested) tensor field and a plain-string field.
TEST_DATA = TensorDict(
    {
        "log_probs": [
            torch.tensor([1.0, 2.0, 3.0]),
            torch.tensor([4.0, 5.0, 6.0]),
            torch.tensor([7.0, 8.0, 9.0]),
        ],
        "variable_length_sequences": torch.nested.as_nested_tensor(
            [
                torch.tensor([-0.5, -1.2, -0.8]),
                torch.tensor([-0.3, -1.5, -2.1, -0.9]),
                torch.tensor([-1.1, -0.7]),
            ]
        ),
        "prompt_text": [
            "Hello world!",
            "This is a longer sentence for testing",
            "Test case",
        ],
    },
    batch_size=[3],
)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Mock Controller for Client Unit Testing
|
|
58
|
+
class MockController:
    """In-process stand-in for a TransferQueue controller, used to unit-test the client.

    A ZMQ ROUTER socket is bound to a random port on 127.0.0.1 and served from a
    background daemon thread. GET_META, GET_CLEAR_META and CLEAR_META requests
    are answered with canned responses.
    """

    def __init__(self, controller_id="controller_0"):
        self.controller_id = controller_id
        self.context = zmq.Context()

        # Socket for data requests
        self.request_socket = self.context.socket(zmq.ROUTER)
        self.request_port = self._bind_to_random_port(self.request_socket)

        self.zmq_server_info = ZMQServerInfo.create(
            role="TransferQueueController",
            id=controller_id,
            ip="127.0.0.1",
            ports={
                "request_handle_socket": self.request_port,
            },
        )

        self.running = True
        self.request_thread = Thread(target=self._handle_requests, daemon=True)
        self.request_thread.start()

    def _bind_to_random_port(self, socket):
        """Bind ``socket`` to a random local TCP port and return the port number."""
        return socket.bind_to_random_port("tcp://127.0.0.1")

    def _build_response(self, request_msg):
        """Return ``(response_type, response_body)`` for a supported request.

        Raises:
            ValueError: if the request type is not one this mock answers.
        """
        request_type = request_msg.request_type
        if request_type == ZMQRequestType.GET_META:
            return ZMQRequestType.GET_META_RESPONSE, self._mock_batch_meta(request_msg.body)
        if request_type == ZMQRequestType.GET_CLEAR_META:
            return ZMQRequestType.GET_CLEAR_META_RESPONSE, self._mock_batch_meta(request_msg.body)
        if request_type == ZMQRequestType.CLEAR_META:
            return ZMQRequestType.CLEAR_META_RESPONSE, {"message": "clear ok"}
        # Fail loudly instead of falling through with unbound response variables.
        raise ValueError(f"MockController got unsupported request type: {request_type}")

    def _handle_requests(self):
        """Serve requests until ``self.running`` is cleared by ``stop()``."""
        poller = zmq.Poller()
        poller.register(self.request_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.request_socket in socks:
                    identity, serialized_msg = self.request_socket.recv_multipart()
                    request_msg = ZMQMessage.deserialize(serialized_msg)

                    # Determine response based on request type
                    response_type, response_body = self._build_response(request_msg)

                    # Send response
                    response_msg = ZMQMessage.create(
                        request_type=response_type,
                        sender_id=self.controller_id,
                        receiver_id=request_msg.sender_id,
                        body=response_body,
                    )
                    self.request_socket.send_multipart([identity, response_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                # The flag set in __init__/stop() is `running`; reading any other
                # attribute name here would raise AttributeError and hide `e`.
                if self.running:
                    print(f"MockController running exception: {e}")
                else:
                    print(f"MockController ERROR: {e}")
                    raise

    def _mock_batch_meta(self, request_body):
        """Build a ``{"metadata": BatchMeta}`` body from the request.

        One sample is produced per index in ``range(batch_size)``; each sample
        carries one placeholder FieldMeta per requested field name.
        """
        batch_size = request_body.get("batch_size", 1)
        data_fields = request_body.get("data_fields", [])

        samples = []
        for i in range(batch_size):
            fields = [
                FieldMeta(
                    name=field_name,
                    dtype=None,
                    shape=None,
                    production_status=0,
                )
                for field_name in data_fields
            ]
            samples.append(
                SampleMeta(
                    global_step=0,
                    global_index=i,
                    storage_id="storage_0",
                    local_index=i,
                    fields={field.name: field for field in fields},
                )
            )
        metadata = BatchMeta(samples=samples)

        return {"metadata": metadata}

    def stop(self):
        """Stop the serving thread and release the ZMQ socket and context."""
        self.running = False
        # The loop polls with a 100 ms timeout, so it sees the cleared flag quickly;
        # joining waits for it to actually exit before the socket is closed.
        self.request_thread.join(timeout=1.0)
        # linger=0 drops unsent replies so context.term() cannot block on them.
        self.request_socket.close(linger=0)
        self.context.term()
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# Mock Storage for Client Unit Testing
|
|
158
|
+
class MockStorage:
    """In-process stand-in for a TransferQueue storage unit, used to unit-test the client.

    A ZMQ ROUTER socket is bound to a random port on 127.0.0.1 and served from a
    background daemon thread. PUT_DATA and CLEAR_DATA are acknowledged without
    storing anything; GET_DATA is answered from the module-level ``TEST_DATA``.
    """

    def __init__(self, storage_id="storage_0"):
        self.storage_id = storage_id
        self.context = zmq.Context()

        # Socket for data operations
        self.data_socket = self.context.socket(zmq.ROUTER)
        self.data_port = self._bind_to_random_port(self.data_socket)

        self.zmq_server_info = ZMQServerInfo.create(
            role="TransferQueueStorage",
            id=storage_id,
            ip="127.0.0.1",
            ports={
                "put_get_socket": self.data_port,
            },
        )

        self.running = True
        self.data_thread = Thread(target=self._handle_data_requests, daemon=True)
        self.data_thread.start()

    def _bind_to_random_port(self, socket):
        """Bind ``socket`` to a random local TCP port and return the port number."""
        return socket.bind_to_random_port("tcp://127.0.0.1")

    def _build_response(self, msg):
        """Return ``(response_type, response_body)`` for a supported request.

        Raises:
            ValueError: if the request type is not one this mock answers.
        """
        request_type = msg.request_type
        if request_type == ZMQRequestType.PUT_DATA:
            return ZMQRequestType.PUT_DATA_RESPONSE, {"message": "Data stored successfully"}
        if request_type == ZMQRequestType.GET_DATA:
            return ZMQRequestType.GET_DATA_RESPONSE, self._handle_get_data(msg.body)
        if request_type == ZMQRequestType.CLEAR_DATA:
            return ZMQRequestType.CLEAR_DATA_RESPONSE, {"message": "Data cleared successfully"}
        # Fail loudly instead of falling through with unbound response variables.
        raise ValueError(f"MockStorage got unsupported request type: {request_type}")

    def _handle_data_requests(self):
        """Serve requests until ``self.running`` is cleared by ``stop()``."""
        poller = zmq.Poller()
        poller.register(self.data_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.data_socket in socks:
                    identity, msg_bytes = self.data_socket.recv_multipart()
                    msg = ZMQMessage.deserialize(msg_bytes)

                    # Handle different request types
                    response_type, response_body = self._build_response(msg)

                    # Send response
                    response_msg = ZMQMessage.create(
                        request_type=response_type,
                        sender_id=self.storage_id,
                        receiver_id=msg.sender_id,
                        body=response_body,
                    )
                    self.data_socket.send_multipart([identity, response_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                # The flag set in __init__/stop() is `running`; reading any other
                # attribute name here would raise AttributeError and hide `e`.
                if self.running:
                    print(f"MockStorage running exception: {e}")
                else:
                    print(f"MockStorage ERROR: {e}")
                    raise

    def _handle_get_data(self, request_body):
        """Handle GET_DATA request by gathering the requested rows from ``TEST_DATA``.

        For each requested field the items at ``local_indexes`` are collected;
        all-tensor results are packed as a nested tensor, anything else as a
        NonTensorStack. Fields with no gathered items are omitted.
        """
        local_indexes = request_body.get("local_indexes", [])
        fields = request_body.get("fields", [])

        result = {}
        for field in fields:
            gathered_items = [TEST_DATA[field][i] for i in local_indexes]
            if not gathered_items:
                continue

            if all(isinstance(x, torch.Tensor) for x in gathered_items):
                result[field] = torch.nested.as_nested_tensor(gathered_items)
            else:
                result[field] = NonTensorStack(*gathered_items)

        return {"data": TensorDict(result)}

    def stop(self):
        """Stop the serving thread and release the ZMQ socket and context."""
        self.running = False
        # The loop polls with a 100 ms timeout, so it sees the cleared flag quickly;
        # joining waits for it to actually exit before the socket is closed.
        self.data_thread.join(timeout=1.0)
        # linger=0 drops unsent replies so context.term() cannot block on them.
        self.data_socket.close(linger=0)
        self.context.term()
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
# Test Fixtures
|
|
249
|
+
@pytest.fixture
def mock_controller():
    """Provide a running MockController and stop it once the test is done."""
    server = MockController()
    yield server
    server.stop()
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@pytest.fixture
def mock_storage():
    """Provide a running MockStorage and stop it once the test is done."""
    server = MockStorage()
    yield server
    server.stop()
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
@pytest.fixture
def client_setup(mock_controller, mock_storage):
    """Yield ``(client, mock_controller, mock_storage)`` with the client wired to both mocks."""
    # Create client with mock controller and storage
    tq_client = TransferQueueClient(
        client_id="client_0",
        controller_infos={mock_controller.controller_id: mock_controller.zmq_server_info},
        storage_infos={mock_storage.storage_id: mock_storage.zmq_server_info},
    )

    # Give some time for connections to establish
    time.sleep(0.5)

    yield tq_client, mock_controller, mock_storage
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
# Test basic functionality
|
|
281
|
+
def test_client_initialization(client_setup):
    """The client has an id and has registered both the mock controller and the mock storage."""
    tq_client, controller, storage = client_setup

    assert tq_client.client_id is not None
    assert controller.controller_id in tq_client._controllers
    assert storage.storage_id in tq_client._storages
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def test_put_and_get_data(client_setup):
    """Put TEST_DATA, then fetch two samples back and compare every field."""
    client, _, _ = client_setup
    field_names = ["log_probs", "variable_length_sequences", "prompt_text"]

    # Test put operation
    client.put(data=TEST_DATA, global_step=0)

    # Get metadata for retrieving data
    metadata = client.get_meta(data_fields=field_names, batch_size=2, global_step=0)

    # Test get operation
    result = client.get_data(metadata)

    # Verify result structure
    for name in field_names:
        assert name in result

    expected_log_probs = [
        torch.tensor([1.0, 2.0, 3.0]),
        torch.tensor([4.0, 5.0, 6.0]),
    ]
    for idx, expected in enumerate(expected_log_probs):
        torch.testing.assert_close(result["log_probs"][idx], expected)

    expected_sequences = [
        torch.tensor([-0.5, -1.2, -0.8]),
        torch.tensor([-0.3, -1.5, -2.1, -0.9]),
    ]
    for idx, expected in enumerate(expected_sequences):
        torch.testing.assert_close(result["variable_length_sequences"][idx], expected)

    expected_prompts = ["Hello world!", "This is a longer sentence for testing"]
    for idx, expected in enumerate(expected_prompts):
        assert result["prompt_text"][idx] == expected
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def test_get_meta(client_setup):
    """get_meta returns an object exposing the expected attributes and ten global indexes."""
    client, _, _ = client_setup

    # Test get_meta operation
    metadata = client.get_meta(
        data_fields=["tokens", "labels"],
        batch_size=10,
        global_step=0,
    )

    # Verify metadata structure
    for attribute in ("storage_meta_groups", "global_indexes", "field_names", "size"):
        assert hasattr(metadata, attribute)
    assert len(metadata.global_indexes) == 10
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def test_clear_operation(client_setup):
    """clear() for global step 0 completes without raising against the mock servers."""
    tq_client, _controller, _storage = client_setup

    # Test clear operation
    tq_client.clear(global_step=0)
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
# Test with multiple controllers and storage units
|
|
342
|
+
def test_multiple_servers():
    """The client registers two controllers and three storages and can put a batch."""
    # Create multiple mock servers
    controllers = [MockController(f"controller_{i}") for i in range(2)]
    storages = [MockStorage(f"storage_{i}") for i in range(3)]

    try:
        # Create client with multiple servers
        client = TransferQueueClient(
            client_id="client_test_multiple_servers",
            controller_infos={ctrl.controller_id: ctrl.zmq_server_info for ctrl in controllers},
            storage_infos={unit.storage_id: unit.zmq_server_info for unit in storages},
        )

        # Give time for connections
        time.sleep(1.0)

        # Verify connections
        assert len(client._controllers) == 2
        assert len(client._storages) == 3

        # Test basic operation: put a small random batch
        tokens = torch.randint(0, 100, (5, 128))
        client.put(data=TensorDict({"tokens": tokens}, batch_size=5), global_step=0)

    finally:
        # Clean up
        for ctrl in controllers:
            ctrl.stop()
        for unit in storages:
            unit.stop()
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
# Test error handling
|
|
381
|
+
def test_put_without_required_params(client_setup):
    """put() without a global_step is expected to raise AssertionError."""
    client, _, _ = client_setup

    # Create test data
    tokens = torch.randint(0, 100, (5, 128))
    batch = TensorDict({"tokens": tokens}, batch_size=5)

    # Test put without global_step (should fail)
    with pytest.raises(AssertionError):
        client.put(data=batch)
|
tests/test_controller.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import math
|
|
17
|
+
import sys
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pytest
|
|
22
|
+
import ray
|
|
23
|
+
import torch
|
|
24
|
+
|
|
25
|
+
parent_dir = Path(__file__).resolve().parent.parent
|
|
26
|
+
sys.path.append(str(parent_dir))
|
|
27
|
+
|
|
28
|
+
from transfer_queue.controller import TQ_INIT_FIELD_NUM, TransferQueueController # noqa: E402
|
|
29
|
+
from transfer_queue.storage import TransferQueueStorageSimpleUnit # noqa: E402
|
|
30
|
+
|
|
31
|
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.fixture(scope="function")
def ray_setup():
    """Start a fresh Ray runtime for the test and shut it down afterwards."""
    # Drop any runtime left over from a previous test before initialising.
    if ray.is_initialized():
        ray.shutdown()

    env_vars = {"RAY_DEBUG": "1", "RAY_DEDUP_LOGS": "0"}
    ray.init(
        ignore_reinit_error=True,
        runtime_env={"env_vars": env_vars},
        log_to_driver=True,
    )
    yield
    if ray.is_initialized():
        ray.shutdown()
        logger.info("Ray has been shut down completely after test")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@pytest.fixture(scope="function")
def setup_teardown_transfer_queue_controller(ray_setup):
    """Create a controller actor, yield it with its sizing parameters, then clear step 0.

    Yields:
        ``(tq_controller, global_batch_size, num_global_batch, num_n_samples)``
    """
    # Used as the offset for the global index to distinguish which global step the data corresponds to
    sizing = {
        "num_storage_units": 2,
        "global_batch_size": 8,
        "num_global_batch": 2,
        "num_n_samples": 2,
    }

    tq_controller = TransferQueueController.remote(**sizing)
    yield (
        tq_controller,
        sizing["global_batch_size"],
        sizing["num_global_batch"],
        sizing["num_n_samples"],
    )
    ray.get(tq_controller.clear.remote(0))
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@pytest.fixture(scope="function")
def setup_teardown_register_controller_info(setup_teardown_transfer_queue_controller):
    """Create two storage-unit actors and register the controller's ZMQ info with each.

    Yields:
        ``(tq_controller, global_batch_size, num_n_samples, data_system_storage_units)``
    """
    tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller
    num_data_storage_units = 2
    # Split the total capacity evenly across the units, rounding up.
    per_unit_size = math.ceil(global_batch_size * num_global_batch * num_n_samples / num_data_storage_units)

    data_system_storage_units = {}
    for rank in range(num_data_storage_units):
        data_system_storage_units[rank] = TransferQueueStorageSimpleUnit.remote(storage_size=per_unit_size)
        logger.info(f"TransferQueueStorageSimpleUnit #{rank} has been created.")

    # Register controller info
    zmq_server_info = ray.get(tq_controller.get_zmq_server_info.remote())
    controller_infos = {zmq_server_info.id: zmq_server_info}

    pending = [unit.register_controller_info.remote(controller_infos) for unit in data_system_storage_units.values()]
    ray.get(pending)

    yield tq_controller, global_batch_size, num_n_samples, data_system_storage_units
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class TestTransferQueueController:
    """Tests for index mapping, status tracking and metadata generation of the controller actor."""

    @pytest.mark.parametrize("num_n_samples", [1, 2])
    @pytest.mark.parametrize("num_global_batch", [1, 2])
    def test_build_index_storage_mapping(self, num_n_samples, num_global_batch, ray_setup):
        """Check the global-index -> (storage unit, local index) mapping for each sizing combination."""
        # Used as the offset for the global index to distinguish which global step the data corresponds to
        global_batch_size = 8
        num_data_storage_units = 2

        tq_controller = TransferQueueController.remote(
            num_storage_units=num_data_storage_units,
            global_batch_size=global_batch_size,
            num_global_batch=num_global_batch,
            num_n_samples=num_n_samples,
        )

        global_index_storage_mapping, global_index_local_index_mapping = ray.get(
            tq_controller.get_global_index_mapping.remote()
        )

        # Keyed by (num_global_batch, num_n_samples); values are
        # (expected storage mapping, expected local-index mapping).
        expected = {
            (1, 1): (
                [0, 0, 0, 0, 1, 1, 1, 1],
                [0, 1, 2, 3, 0, 1, 2, 3],
            ),
            # The data of a single GBS will be distributed across different storage units
            (2, 1): (
                [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
                [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7],
            ),
            # When num_n_samples is larger than 1
            (1, 2): (
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
                [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7],
            ),
            (2, 2): (
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
                + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
                [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
                + [8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15],
            ),
        }
        expected_storage, expected_local = expected[(num_global_batch, num_n_samples)]

        assert np.array_equal(global_index_storage_mapping, np.array(expected_storage))
        assert np.array_equal(global_index_local_index_mapping, np.array(expected_local))

    def test_update_production_status(self, setup_teardown_transfer_queue_controller):
        """Updating production status registers the field and marks the given rows as produced."""
        tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller

        total_storage_size = global_batch_size * num_global_batch * num_n_samples
        # Initial state: all-zero production status and an empty field_name_mapping
        init_update_production_status = torch.zeros(total_storage_size, TQ_INIT_FIELD_NUM, dtype=torch.int8)
        assert torch.equal(ray.get(tq_controller.get_data_production_status.remote()), init_update_production_status)
        assert ray.get(tq_controller.get_field_name_mapping.remote()) == {}

        columns_list = ["test_prompts"]
        global_indexes = list(range(global_batch_size * num_n_samples))

        # update production status; ray.get waits for completion and surfaces any
        # exception raised inside the actor instead of silently dropping it
        ray.get(tq_controller._update_production_status.remote(global_indexes, columns_list))
        new_field_name_mapping = ray.get(tq_controller.get_field_name_mapping.remote())
        assert new_field_name_mapping["test_prompts"] == 0

        new_data_production_status = ray.get(tq_controller.get_data_production_status.remote())
        assert new_data_production_status[:, 0][: len(global_indexes)].sum() == len(global_indexes)

    def test_data_consumption_status(self, setup_teardown_transfer_queue_controller):
        """Requesting a task's consumption status creates an all-zero entry for that task."""
        tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller
        total_storage_size = global_batch_size * num_global_batch * num_n_samples

        init_data_consumption_status = {}
        assert ray.get(tq_controller.get_data_consumption_status.remote()) == init_data_consumption_status

        task_name = "test_task1"
        ray.get(tq_controller._get_consumption_status.remote(task_name))
        new_data_consumption_status = ray.get(tq_controller.get_data_consumption_status.remote())
        assert torch.equal(new_data_consumption_status[task_name], torch.zeros(total_storage_size, dtype=torch.int8))

    def test_get_prompt_metadata(self, setup_teardown_register_controller_info):
        """Insert-mode metadata for global step 5, reversed via reorder(), has the expected indexes."""
        tq_controller, global_batch_size, n_samples, _ = setup_teardown_register_controller_info

        data_fields = ["test_prompts"]
        global_step = 5

        metadata = ray.get(
            tq_controller._get_metadata.remote(
                data_fields=data_fields,
                batch_size=global_batch_size * n_samples,
                global_step=global_step,
                mode="insert",
            )
        )
        # Reverse the 16 samples
        metadata.reorder(list(range(15, -1, -1)))
        # 31, 30, ..., 16
        assert metadata.global_indexes == list(range(31, 15, -1))
        # 15, 14, ..., 8 repeated twice
        assert metadata.local_indexes == list(range(15, 7, -1)) * 2
        storage_ids = metadata.storage_ids
        # The first half of the reordered samples all live on one storage unit
        assert len(set(storage_ids[: len(storage_ids) // 2])) == 1
|
|
266
|
+
|
|
267
|
+
# TODO: Test case where multiple clients concurrently read metadata from a single controller,
|
|
268
|
+
# and each client receives the correct response
|