TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,443 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import sys
|
|
16
|
+
import time
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
|
|
19
|
+
import pytest
|
|
20
|
+
import ray
|
|
21
|
+
import tensordict
|
|
22
|
+
import torch
|
|
23
|
+
import zmq
|
|
24
|
+
from tensordict import TensorDict
|
|
25
|
+
|
|
26
|
+
# Setup path
|
|
27
|
+
parent_dir = Path(__file__).resolve().parent.parent
|
|
28
|
+
sys.path.append(str(parent_dir))
|
|
29
|
+
|
|
30
|
+
from transfer_queue import SimpleStorageUnit # noqa: E402
|
|
31
|
+
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType # noqa: E402
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class MockStorageClient:
    """Mock ZMQ DEALER client for exercising storage unit operations.

    The original implementation repeated the same create/serialize/send/
    receive round-trip in every public method; that boilerplate now lives
    in the private ``_request`` helper, so each public method only states
    its request type and body. Public interface is unchanged.
    """

    # Milliseconds to wait for a reply before zmq raises (was a bare 5000).
    _RECV_TIMEOUT_MS = 5000

    def __init__(self, storage_put_get_address):
        """Connect a DEALER socket to the storage unit's put/get endpoint."""
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.setsockopt(zmq.RCVTIMEO, self._RECV_TIMEOUT_MS)
        self.socket.connect(storage_put_get_address)

    def _request(self, request_type, client_id, body):
        """Send one request and block for its deserialized response."""
        msg = ZMQMessage.create(
            request_type=request_type,
            sender_id=f"mock_client_{client_id}",
            body=body,
        )
        self.socket.send(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv())

    def send_put(self, client_id, local_indexes, field_data):
        """Store ``field_data`` rows at the given per-unit ``local_indexes``."""
        return self._request(
            ZMQRequestType.PUT_DATA,
            client_id,
            {"local_indexes": local_indexes, "data": field_data},
        )

    def send_get(self, client_id, local_indexes, fields):
        """Fetch the requested ``fields`` at ``local_indexes``."""
        return self._request(
            ZMQRequestType.GET_DATA,
            client_id,
            {"local_indexes": local_indexes, "fields": fields},
        )

    def send_clear(self, client_id, local_indexes):
        """Drop whatever is stored at ``local_indexes``."""
        return self._request(
            ZMQRequestType.CLEAR_DATA,
            client_id,
            {"local_indexes": local_indexes},
        )

    def close(self):
        """Release the ZMQ socket and terminate its context."""
        self.socket.close()
        self.context.term()
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@pytest.fixture(scope="session")
def ray_setup():
    """Start a shared Ray runtime once for the whole test session.

    Teardown after the final test shuts the runtime down again.
    """
    # ignore_reinit_error lets this coexist with an already-running Ray.
    ray.init(ignore_reinit_error=True)
    yield
    ray.shutdown()
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@pytest.fixture
def storage_setup(ray_setup):
    """Spin up a SimpleStorageUnit actor and yield it with its ZMQ address."""
    capacity = 10000
    # Lists assigned into TensorDicts should be stacked into tensors.
    tensordict.set_list_to_stack(True).set()

    # Launch the storage unit as a Ray actor.
    actor = SimpleStorageUnit.options(max_concurrency=50, num_cpus=1).remote(storage_unit_size=capacity)

    # Ask the actor where its put/get ZMQ endpoint is listening.
    server_info = ray.get(actor.get_zmq_server_info.remote())
    address = server_info.to_addr("put_get_socket")
    time.sleep(1)  # give the socket a moment to come up

    yield actor, address

    # Teardown: terminate the actor.
    ray.kill(actor)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def test_put_get_single_client(storage_setup):
    """Round-trip a small batch through a single mock client."""
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    # Store three rows of two fields at indexes 0..2.
    log_probs = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])]
    rewards = [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])]
    payload = TensorDict({"log_probs": log_probs, "rewards": rewards}, batch_size=[])

    put_reply = client.send_put(0, [0, 1, 2], payload)
    assert put_reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Read back the first two rows.
    get_reply = client.send_get(0, [0, 1], ["log_probs", "rewards"])
    assert get_reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    fetched = get_reply.body["data"]
    for field in ("log_probs", "rewards"):
        assert field in fetched
        assert fetched[field].size(0) == 2

    # The stored values must come back unchanged.
    for row in range(2):
        torch.testing.assert_close(fetched["log_probs"][row], log_probs[row])
        torch.testing.assert_close(fetched["rewards"][row], rewards[row])

    client.close()
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def test_put_get_multiple_clients(storage_setup):
    """Test put and get operations with multiple clients."""
    # NOTE: this test is order-sensitive — the overlapping put below must
    # land after client 0's put so that index 0 ends up overwritten.
    _, put_get_address = storage_setup

    num_clients = 3
    clients = [MockStorageClient(put_get_address) for _ in range(num_clients)]

    # Each client puts unique data using different local_indexes
    # (client i owns indexes i*10 .. i*10+2, so the ranges are disjoint).
    for i, client in enumerate(clients):
        local_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2]
        field_data = TensorDict(
            {
                "log_probs": [
                    torch.tensor([i, i + 1, i + 2]),
                    torch.tensor([i + 3, i + 4, i + 5]),
                    torch.tensor([i + 6, i + 7, i + 8]),
                ],
                "rewards": [torch.tensor([i * 10]), torch.tensor([i * 10 + 10]), torch.tensor([i * 10 + 20])],
            }
        )

        response = client.send_put(i, local_indexes, field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Test overlapping local indexes
    # A later write to an already-used index silently replaces that row.
    overlapping_client = MockStorageClient(put_get_address)
    overlap_local_indexes = [0]  # Overlaps with first client's index 0
    overlap_field_data = TensorDict({"log_probs": [torch.tensor([999, 999, 999])], "rewards": [torch.tensor([999])]})
    response = overlapping_client.send_put(99, overlap_local_indexes, overlap_field_data)
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Each original client gets its own data (except for index 0 which was overwritten)
    for i, client in enumerate(clients):
        response = client.send_get(i, [i * 10 + 0, i * 10 + 1], ["log_probs", "rewards"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved_data = response.body["data"]
        assert retrieved_data["log_probs"].size(0) == 2
        assert retrieved_data["rewards"].size(0) == 2

        # For index 0, expect data from overlapping_client; others from original client
        if i == 0:
            # Index 0 was overwritten
            torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([999, 999, 999]))
            torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([999]))
            # Index 1 remains original
            torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([3, 4, 5]))
            torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([10]))
        else:
            # All data remains original
            torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([i, i + 1, i + 2]))
            torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([i + 3, i + 4, i + 5]))
            torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([i * 10]))
            torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([i * 10 + 10]))

    # Cleanup
    for client in clients:
        client.close()
    overlapping_client.close()
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def test_performance_basic(storage_setup):
    """Smoke-test PUT/GET latency with a modest data volume."""
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    num_rounds = 10  # small round count keeps the test quick
    batch_size = 16

    # PUT phase: each round stores one batch and records its wall time
    # (batch construction is intentionally inside the timed region, as
    # in the original measurement).
    put_latencies = []
    for round_idx in range(num_rounds):
        tic = time.time()
        indexes = list(range(round_idx * batch_size, (round_idx + 1) * batch_size))
        batch = TensorDict(
            {
                "log_probs": [torch.randn(100) for _ in range(batch_size)],
                "rewards": [torch.randn(100) for _ in range(batch_size)],
            },
            batch_size=[batch_size],
        )
        reply = client.send_put(0, indexes, batch)
        put_latencies.append(time.time() - tic)
        assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET phase: read every stored batch back, timing each request.
    get_latencies = []
    for round_idx in range(num_rounds):
        tic = time.time()
        indexes = list(range(round_idx * batch_size, (round_idx + 1) * batch_size))
        reply = client.send_get(0, indexes, ["log_probs", "rewards"])
        get_latencies.append(time.time() - tic)
        assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    avg_put_latency = sum(put_latencies) / len(put_latencies) * 1000  # ms
    avg_get_latency = sum(get_latencies) / len(get_latencies) * 1000  # ms

    # Generous thresholds: this only guards against pathological slowness.
    assert avg_put_latency < 1500, f"Avg PUT latency {avg_put_latency}ms exceeds threshold"
    assert avg_get_latency < 1500, f"Avg GET latency {avg_get_latency}ms exceeds threshold"

    client.close()
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def test_put_get_nested_tensor(storage_setup):
    """Round-trip variable-length (nested) tensors through the storage unit."""
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    # Rows have different sequence lengths (3, 4, 2).
    sequences = [
        torch.tensor([-0.5, -1.2, -0.8]),
        torch.tensor([-0.3, -1.5, -2.1, -0.9]),
        torch.tensor([-1.1, -0.7]),
    ]
    masks = [torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1, 1]), torch.tensor([1, 1])]
    payload = TensorDict(
        {"variable_length_sequences": sequences, "attention_mask": masks},
        batch_size=[],
    )

    put_reply = client.send_put(0, [0, 1, 2], payload)
    assert put_reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Fetch a non-contiguous subset: rows 0 and 2.
    get_reply = client.send_get(0, [0, 2], ["variable_length_sequences", "attention_mask"])
    assert get_reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    fetched = get_reply.body["data"]
    for field in ("variable_length_sequences", "attention_mask"):
        assert field in fetched
        assert fetched[field].size(0) == 2

    # Both fields must survive the trip for each requested row.
    for out_row, src_row in enumerate((0, 2)):
        torch.testing.assert_close(fetched["variable_length_sequences"][out_row], sequences[src_row])
        torch.testing.assert_close(fetched["attention_mask"][out_row], masks[src_row])

    client.close()
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def test_put_get_non_tensor_data(storage_setup):
    """Round-trip plain Python strings through the storage unit."""
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    prompts = ["Hello world!", "This is a longer sentence for testing", "Test case"]
    responses = ["Hi there!", "This is the response to the longer sentence", "Test response"]
    payload = TensorDict(
        {"prompt_text": prompts, "response_text": responses},
        batch_size=[],
    )

    put_reply = client.send_put(0, [0, 1, 2], payload)
    assert put_reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    get_reply = client.send_get(0, [0, 1, 2], ["prompt_text", "response_text"])
    assert get_reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    fetched = get_reply.body["data"]
    assert "prompt_text" in fetched
    assert "response_text" in fetched

    # Strings must come back as str instances with their content intact.
    assert isinstance(fetched["prompt_text"][0], str)
    assert isinstance(fetched["response_text"][0], str)
    for row in range(3):
        assert fetched["prompt_text"][row] == prompts[row]
        assert fetched["response_text"][row] == responses[row]

    client.close()
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def test_put_get_single_item(storage_setup):
    """Store and fetch a batch containing exactly one row."""
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    payload = TensorDict(
        {"prompt_text": ["Hello world!"], "attention_mask": [torch.tensor([1, 1, 1])]},
        batch_size=[],
    )

    put_reply = client.send_put(0, [0], payload)
    assert put_reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    get_reply = client.send_get(0, [0], ["prompt_text", "attention_mask"])
    assert get_reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    fetched = get_reply.body["data"]
    assert "prompt_text" in fetched
    assert "attention_mask" in fetched

    assert fetched["prompt_text"][0] == "Hello world!"
    # A single stored row still comes back with a leading batch dimension.
    assert fetched["attention_mask"].shape == (1, 3)
    torch.testing.assert_close(fetched["attention_mask"][0], torch.tensor([1, 1, 1]))

    client.close()
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def test_clear_data(storage_setup):
    """Clearing selected indexes must leave untouched rows readable."""
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    # Seed three rows of data at indexes 0..2.
    payload = TensorDict(
        {
            "log_probs": [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0])],
            "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])],
        },
        batch_size=[],
    )
    put_reply = client.send_put(0, [0, 1, 2], payload)
    assert put_reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Sanity check: all three rows are retrievable before the clear.
    get_reply = client.send_get(0, [0, 1, 2], ["log_probs"])
    assert get_reply.request_type == ZMQRequestType.GET_DATA_RESPONSE
    assert get_reply.body["data"]["log_probs"].size(0) == 3

    # Clear only rows 0 and 2.
    clear_reply = client.send_clear(0, [0, 2])
    assert clear_reply.request_type == ZMQRequestType.CLEAR_DATA_RESPONSE

    # Row 1 was not cleared and should still hold its original value.
    survivor = client.send_get(0, [1], ["log_probs"])
    assert survivor.request_type == ZMQRequestType.GET_DATA_RESPONSE
    assert survivor.body["data"]["log_probs"].size(0) == 1
    torch.testing.assert_close(survivor.body["data"]["log_probs"][0], torch.tensor([2.0]))

    client.close()
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def test_storage_unit_data_direct():
    """Exercise StorageUnitData in-process, bypassing the ZMQ layer."""
    from transfer_queue.storage import StorageUnitData

    unit = StorageUnitData(storage_size=10)

    # Store two rows of two fields.
    payload = TensorDict(
        {
            "log_probs": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])],
            "rewards": [torch.tensor([10.0]), torch.tensor([20.0])],
        },
        batch_size=[],
    )
    unit.put_data(payload, [0, 1])

    # Both fields should come back with one row per requested index.
    fetched = unit.get_data(["log_probs", "rewards"], [0, 1])
    for field in ("log_probs", "rewards"):
        assert field in fetched
        assert fetched[field].size(0) == 2

    # A single-index read keeps the leading batch dimension.
    single = unit.get_data(["log_probs"], [0])
    assert single["log_probs"].shape == (1, 2)

    # After clearing index 0, its slot reads back as None.
    unit.clear([0])
    cleared = unit.get_data(["log_probs"], [0])
    assert cleared["log_probs"][0] is None
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from importlib.util import find_spec
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from transfer_queue.storage.clients.factory import StorageClientFactory
|
|
8
|
+
from transfer_queue.storage.clients.yuanrong_client import YRStorageClient
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Test(unittest.TestCase):
    """Tests for StorageClientFactory registration and the Yuanrong client."""

    def setUp(self):
        # Connection settings for a locally running datasystem instance.
        self.cfg = {"host": "127.0.0.1", "port": 31501, "device_id": 0}

    @pytest.mark.skipif(find_spec("datasystem") is None, reason="datasystem is not available")
    def test_create_client(self):
        """Known backend names resolve; unknown names raise ValueError."""
        registry = StorageClientFactory._registry
        self.assertIn("Yuanrong", registry)
        self.assertIs(registry["Yuanrong"], YRStorageClient)
        StorageClientFactory.create("Yuanrong", self.cfg)

        with self.assertRaises(ValueError) as cm:
            StorageClientFactory.create("abc", self.cfg)
        self.assertIn("Unknown StorageClient", str(cm.exception))

    @pytest.mark.skipif(
        find_spec("torch_npu") is None or find_spec("datasystem") is None, reason="torch_npu is not available"
    )
    def test_client_create_empty_tensorlist(self):
        """Empty tensors mirror the reference tensors' shapes and dtypes."""
        reference = [torch.Tensor([2, 1]), torch.Tensor([1, 5]), torch.Tensor([0]), torch.Tensor([-1.5])]
        shapes = [t.shape for t in reference]
        dtypes = [t.dtype for t in reference]

        client = StorageClientFactory.create("Yuanrong", self.cfg)
        produced = client._create_empty_tensorlist(shapes, dtypes)

        self.assertEqual(len(reference), len(produced))
        for expected, actual in zip(reference, produced, strict=False):
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected.dtype, actual.dtype)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Allow running this test module directly (outside a pytest session).
if __name__ == "__main__":
    unittest.main()
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
|
|
17
|
+
from .client import (
|
|
18
|
+
AsyncTransferQueueClient,
|
|
19
|
+
TransferQueueClient,
|
|
20
|
+
process_zmq_server_info,
|
|
21
|
+
)
|
|
22
|
+
from .controller import TransferQueueController
|
|
23
|
+
from .metadata import BatchMeta
|
|
24
|
+
from .sampler import BaseSampler
|
|
25
|
+
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
|
|
26
|
+
from .sampler.sequential_sampler import SequentialSampler
|
|
27
|
+
from .storage import SimpleStorageUnit
|
|
28
|
+
from .utils.utils import get_placement_group
|
|
29
|
+
from .utils.zmq_utils import ZMQServerInfo
|
|
30
|
+
|
|
31
|
+
# Public API of the transfer_queue package (names re-exported above).
__all__ = [
    "AsyncTransferQueueClient",
    "BatchMeta",
    "TransferQueueClient",
    "TransferQueueController",
    "SimpleStorageUnit",
    "ZMQServerInfo",
    "process_zmq_server_info",
    "get_placement_group",
    "BaseSampler",
    "GRPOGroupNSampler",
    "SequentialSampler",
]
|
|
44
|
+
|
|
45
|
+
# Directory containing this package, used to locate the bundled version file.
# The one-argument os.path.join(...) the original wrapped around abspath is a
# no-op and has been removed.
version_folder = os.path.dirname(os.path.abspath(__file__))

# Read the package version shipped in version/version next to this module.
with open(os.path.join(version_folder, "version/version")) as f:
    __version__ = f.read().strip()
|