TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries, and is provided for informational purposes only.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
tests/test_put.py
ADDED
@@ -0,0 +1,327 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import sys
from pathlib import Path

import pytest
import ray
import torch
from tensordict import TensorDict

parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import (  # noqa: E402
    AsyncTransferQueueClient,
    SimpleStorageUnit,
    TransferQueueController,
    process_zmq_server_info,
)
from transfer_queue.utils.utils import get_placement_group  # noqa: E402

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def ray_setup():
    """Initialize Ray for testing."""
    ray.init(ignore_reinit_error=True)
    yield
    ray.shutdown()


@pytest.fixture(scope="module")
def data_system_setup(ray_setup):
    """Set up data system for testing."""
    # Initialize storage units
    num_storage_units = 2
    storage_size = 10000
    storage_units = {}

    storage_placement_group = get_placement_group(num_storage_units, num_cpus_per_actor=1)
    for storage_unit_rank in range(num_storage_units):
        storage_node = SimpleStorageUnit.options(
            placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
        ).remote(storage_unit_size=storage_size)
        storage_units[storage_unit_rank] = storage_node
        logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")

    # Initialize controller
    controller = TransferQueueController.remote()
    logger.info("TransferQueueController has been created.")

    # Prepare connection info
    controller_info = process_zmq_server_info(controller)
    storage_unit_infos = process_zmq_server_info(storage_units)

    # Create config
    config = {
        "controller_info": controller_info,
        "storage_unit_infos": storage_unit_infos,
    }

    yield controller, storage_units, config

    # Cleanup
    ray.kill(controller)
    for storage_unit in storage_units.values():
        ray.kill(storage_unit)


# NOTE: this fixture is unused by the tests below (they build their clients
# through TestMultipleAsyncPut instead); as an async fixture it also needs a
# pytest-asyncio mode that resolves async fixtures declared via @pytest.fixture.
@pytest.fixture
async def client_setup(data_system_setup):
    """Set up client for testing."""
    controller, storage_units, config = data_system_setup

    client = AsyncTransferQueueClient(
        client_id="TestClient",
        controller_info=config["controller_info"],
    )

    client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)

    # Wait a bit for connections to establish
    await asyncio.sleep(1)

    return client


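# pytest does not collect TestMultipleAsyncPut as a test class because it
# defines __init__; it acts as a manually driven harness, and each module-level
# test below instantiates it and calls setup() / create_client() / teardown()
# around its own Ray data system.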
class TestMultipleAsyncPut:
    """Test class for multiple async_put operations."""

    def __init__(self):
        self.controller = None
        self.storage_units = None
        self.config = None

    async def setup(self):
        """Setup for the test class."""
        if not ray.is_initialized():
            ray.init(ignore_reinit_error=True)

        # Initialize data system
        num_storage_units = 2
        self.storage_units = {}
        storage_placement_group = get_placement_group(num_storage_units, num_cpus_per_actor=1)

        for i in range(num_storage_units):
            self.storage_units[i] = SimpleStorageUnit.options(
                placement_group=storage_placement_group, placement_group_bundle_index=i
            ).remote(storage_unit_size=10000)

        self.controller = TransferQueueController.remote()

        # Wait for initialization
        await asyncio.sleep(2)

        controller_info = process_zmq_server_info(self.controller)
        storage_unit_infos = process_zmq_server_info(self.storage_units)

        self.config = {
            "controller_info": controller_info,
            "storage_unit_infos": storage_unit_infos,
        }
        logger.info("TestMultipleAsyncPut setup completed")

    async def teardown(self):
        """Teardown for the test class."""
        if self.controller:
            ray.kill(self.controller)
        if self.storage_units:
            for storage in self.storage_units.values():
                ray.kill(storage)
        if ray.is_initialized():
            ray.shutdown()

    async def create_client(self, client_id="TestClient"):
        """Create a new client instance."""
        if self.config is None:
            await self.setup()

        client = AsyncTransferQueueClient(
            client_id=client_id,
            controller_info=self.config["controller_info"],
        )
        client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config)
        await asyncio.sleep(1)  # Wait for connections
        return client


@pytest.mark.asyncio
async def test_concurrent_async_put():
    """Test concurrent async_put operations."""
    test_instance = TestMultipleAsyncPut()
    await test_instance.setup()

    try:
        client = await test_instance.create_client("ConcurrentTestClient")

        async def put_operation(partition_id, data):
            await client.async_put(data=data, partition_id=partition_id)
            return partition_id

        # Create multiple put operations
        tasks = []
        for i in range(5):
            data = TensorDict({f"concurrent_data_{i}": torch.randn(4, 8)}, batch_size=[4])
            task = put_operation(f"concurrent_partition_{i}", data)
            tasks.append(task)

        # Execute concurrently
        results = await asyncio.gather(*tasks)

        assert len(results) == 5
        logger.info(f"Completed {len(results)} concurrent put operations")

        # Cleanup
        for partition in results:
            try:
                await client.async_clear(partition)
            except Exception as e:
                logger.warning(f"Failed to clear partition {partition}: {e}")

        client.close()
    finally:
        await test_instance.teardown()


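# The remaining tests verify writes by reading them back: async_get_meta(mode="fetch")
# returns metadata for the requested fields of a partition, and
# async_get_data(metadata) then pulls the corresponding tensors from storage.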
@pytest.mark.asyncio
async def test_sequential_async_put_with_verification():
    """Test sequential async_put operations with data verification."""
    test_instance = TestMultipleAsyncPut()
    await test_instance.setup()

    try:
        client = await test_instance.create_client("SequentialTestClient")

        # Test data
        test_cases = [
            ("sequential_1", torch.randn(4, 5)),
            ("sequential_2", torch.randn(4, 10)),
            ("sequential_3", torch.randn(4, 15)),
        ]

        for partition_id, tensor_data in test_cases:
            data = TensorDict({"sequential_data": tensor_data}, batch_size=[4])

            # Put data
            await client.async_put(data=data, partition_id=partition_id)
            logger.info(f"Put data to {partition_id}")

            # Verify by reading back
            metadata = await client.async_get_meta(
                data_fields=["sequential_data"],
                batch_size=4,
                partition_id=partition_id,
                mode="fetch",
                task_name="verification_task",
            )

            retrieved_data = await client.async_get_data(metadata)

            # Verify shape and content
            assert retrieved_data["sequential_data"].shape == tensor_data.shape
            logger.info(f"Verified data in {partition_id}")

            # Cleanup
            await client.async_clear(partition_id)

        client.close()
    finally:
        await test_instance.teardown()


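# The next test issues three puts with distinct field names into a single
# partition; only the field from the final put is fetched and checked.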
@pytest.mark.asyncio
async def test_multiple_puts_same_partition():
    """Test multiple puts to the same partition."""
    test_instance = TestMultipleAsyncPut()
    await test_instance.setup()

    try:
        client = await test_instance.create_client("SamePartitionTestClient")

        partition_id = "same_partition_test"
        batch_size = 4

        # First put
        data1 = TensorDict({"first_batch": torch.randn(batch_size, 8)}, batch_size=[batch_size])
        await client.async_put(data=data1, partition_id=partition_id)
        logger.info("First put completed")

        # Second put to same partition
        data2 = TensorDict({"second_batch": torch.randn(batch_size, 8) * 2}, batch_size=[batch_size])
        await client.async_put(data=data2, partition_id=partition_id)
        logger.info("Second put completed")

        # Third put to same partition
        data3 = TensorDict({"third_batch": torch.randn(batch_size, 8) * 3}, batch_size=[batch_size])
        await client.async_put(data=data3, partition_id=partition_id)
        logger.info("Third put completed")

        # Verify that the data from the last put can be fetched
        metadata = await client.async_get_meta(
            data_fields=["third_batch"],
            batch_size=batch_size,
            partition_id=partition_id,
            mode="fetch",
            task_name="verification_task",
        )

        retrieved_data = await client.async_get_data(metadata)

        assert "third_batch" in retrieved_data
        assert retrieved_data["third_batch"].shape == (batch_size, 8)
        logger.info("Verified multiple puts to same partition")

        # Cleanup
        await client.async_clear(partition_id)

        client.close()
    finally:
        await test_instance.teardown()


@pytest.mark.asyncio
async def test_simple_multiple_async_put():
    """Simple test of multiple async_put operations across partitions."""
    test_instance = TestMultipleAsyncPut()
    await test_instance.setup()

    try:
        client = await test_instance.create_client("SimpleTestClient")

        # Test basic multiple puts
        for i in range(3):
            data = TensorDict({f"simple_data_{i}": torch.randn(2, 4)}, batch_size=[2])

            partition_id = f"simple_partition_{i}"
            await client.async_put(data=data, partition_id=partition_id)
            logger.info(f"Put data to {partition_id}")

            # Cleanup each partition
            await client.async_clear(partition_id)

        logger.info("Simple multiple async_put test completed")

        client.close()
    finally:
        await test_instance.teardown()
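
For reference, the put/fetch round trip these tests exercise reduces to the sketch below. It is assembled only from calls that appear in tests/test_put.py above and assumes a controller and storage units are already running and wired up as in data_system_setup; the partition and task names are illustrative:

    import torch
    from tensordict import TensorDict

    from transfer_queue import AsyncTransferQueueClient

    async def demo_round_trip(controller_info, config):
        # Attach a client to an existing controller, then initialize the
        # async storage manager used throughout the tests above.
        client = AsyncTransferQueueClient(client_id="DemoClient", controller_info=controller_info)
        client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)

        # Write a 4-sample batch into a partition.
        data = TensorDict({"tokens": torch.randn(4, 8)}, batch_size=[4])
        await client.async_put(data=data, partition_id="demo_partition")

        # Read it back: fetch metadata first, then the tensors themselves.
        meta = await client.async_get_meta(
            data_fields=["tokens"],
            batch_size=4,
            partition_id="demo_partition",
            mode="fetch",
            task_name="demo_task",
        )
        batch = await client.async_get_data(meta)
        assert batch["tokens"].shape == (4, 8)

        # Release the partition and shut the client down.
        await client.async_clear("demo_partition")
        client.close()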