transferqueue-0.1.1.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
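The new tests double as usage documentation for the public client API. Below is a minimal sketch distilled from tests/test_put.py (shown in full after this note): it reuses only the actor classes, helpers, and call signatures that appear in that file (AsyncTransferQueueClient, SimpleStorageUnit, TransferQueueController, process_zmq_server_info, get_placement_group). The function name put_and_read_back and the "demo_*" identifiers are illustrative; defaults and additional options of the real API may differ.

import asyncio

import ray
import torch
from tensordict import TensorDict

from transfer_queue import (
    AsyncTransferQueueClient,
    SimpleStorageUnit,
    TransferQueueController,
    process_zmq_server_info,
)
from transfer_queue.utils.utils import get_placement_group


async def put_and_read_back():
    # Start one storage unit and a controller as Ray actors, as the tests do.
    ray.init(ignore_reinit_error=True)
    pg = get_placement_group(1, num_cpus_per_actor=1)
    storage = SimpleStorageUnit.options(
        placement_group=pg, placement_group_bundle_index=0
    ).remote(storage_unit_size=10000)
    controller = TransferQueueController.remote()

    config = {
        "controller_info": process_zmq_server_info(controller),
        "storage_unit_infos": process_zmq_server_info({0: storage}),
    }

    # Create a client bound to the controller and attach the simple storage backend.
    client = AsyncTransferQueueClient(client_id="DemoClient", controller_info=config["controller_info"])
    client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
    await asyncio.sleep(1)  # give the ZMQ connections time to establish

    # Put a batch, fetch its metadata, read the data back, then clear the partition.
    data = TensorDict({"demo": torch.randn(4, 8)}, batch_size=[4])
    await client.async_put(data=data, partition_id="demo_partition")
    metadata = await client.async_get_meta(
        data_fields=["demo"],
        batch_size=4,
        partition_id="demo_partition",
        mode="fetch",
        task_name="demo_task",
    )
    retrieved = await client.async_get_data(metadata)
    await client.async_clear("demo_partition")

    client.close()
    ray.kill(controller)
    ray.kill(storage)
    return retrieved


if __name__ == "__main__":
    asyncio.run(put_and_read_back())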
tests/test_put.py ADDED
@@ -0,0 +1,327 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import asyncio
+ import logging
+ import sys
+ from pathlib import Path
+
+ import pytest
+ import ray
+ import torch
+ from tensordict import TensorDict
+
+ parent_dir = Path(__file__).resolve().parent.parent
+ sys.path.append(str(parent_dir))
+
+ from transfer_queue import (  # noqa: E402
+     AsyncTransferQueueClient,
+     SimpleStorageUnit,
+     TransferQueueController,
+     process_zmq_server_info,
+ )
+ from transfer_queue.utils.utils import get_placement_group  # noqa: E402
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ @pytest.fixture(scope="module")
+ def ray_setup():
+     """Initialize Ray for testing."""
+     ray.init(ignore_reinit_error=True)
+     yield
+     ray.shutdown()
+
+
+ @pytest.fixture(scope="module")
+ def data_system_setup(ray_setup):
+     """Set up data system for testing."""
+     # Initialize storage units
+     num_storage_units = 2
+     storage_size = 10000
+     storage_units = {}
+
+     storage_placement_group = get_placement_group(num_storage_units, num_cpus_per_actor=1)
+     for storage_unit_rank in range(num_storage_units):
+         storage_node = SimpleStorageUnit.options(
+             placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
+         ).remote(storage_unit_size=storage_size)
+         storage_units[storage_unit_rank] = storage_node
+         logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")
+
+     # Initialize controller
+     controller = TransferQueueController.remote()
+     logger.info("TransferQueueController has been created.")
+
+     # Prepare connection info
+     controller_info = process_zmq_server_info(controller)
+     storage_unit_infos = process_zmq_server_info(storage_units)
+
+     # Create config
+     config = {
+         "controller_info": controller_info,
+         "storage_unit_infos": storage_unit_infos,
+     }
+
+     yield controller, storage_units, config
+
+     # Cleanup
+     ray.kill(controller)
+     for storage_unit in storage_units.values():
+         ray.kill(storage_unit)
+
+
+ @pytest.fixture
+ async def client_setup(data_system_setup):
+     """Set up client for testing."""
+     controller, storage_units, config = data_system_setup
+
+     client = AsyncTransferQueueClient(
+         client_id="TestClient",
+         controller_info=config["controller_info"],
+     )
+
+     client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
+
+     # Wait a bit for connections to establish
+     await asyncio.sleep(1)
+
+     return client
+
+
+ class TestMultipleAsyncPut:
+     """Test class for multiple async_put operations."""
+
+     def __init__(self):
+         self.controller = None
+         self.storage_units = None
+         self.config = None
+
+     async def setup(self):
+         """Setup for the test class."""
+         if not ray.is_initialized():
+             ray.init(ignore_reinit_error=True)
+
+         # Initialize data system
+         num_storage_units = 2
+         self.storage_units = {}
+         storage_placement_group = get_placement_group(num_storage_units, num_cpus_per_actor=1)
+
+         for i in range(num_storage_units):
+             self.storage_units[i] = SimpleStorageUnit.options(
+                 placement_group=storage_placement_group, placement_group_bundle_index=i
+             ).remote(storage_unit_size=10000)
+
+         self.controller = TransferQueueController.remote()
+
+         # Wait for initialization
+         await asyncio.sleep(2)
+
+         controller_info = process_zmq_server_info(self.controller)
+         storage_unit_infos = process_zmq_server_info(self.storage_units)
+
+         self.config = {
+             "controller_info": controller_info,
+             "storage_unit_infos": storage_unit_infos,
+         }
+         logger.info("TestMultipleAsyncPut setup completed")
+
+     async def teardown(self):
+         """Teardown for the test class."""
+         if self.controller:
+             ray.kill(self.controller)
+         if self.storage_units:
+             for storage in self.storage_units.values():
+                 ray.kill(storage)
+         if ray.is_initialized():
+             ray.shutdown()
+
+     async def create_client(self, client_id="TestClient"):
+         """Create a new client instance."""
+         if self.config is None:
+             await self.setup()
+
+         client = AsyncTransferQueueClient(
+             client_id=client_id,
+             controller_info=self.config["controller_info"],
+         )
+         client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config)
+         await asyncio.sleep(1)  # Wait for connections
+         return client
+
+
+ @pytest.mark.asyncio
+ async def test_concurrent_async_put():
+     """Test concurrent async_put operations."""
+     test_instance = TestMultipleAsyncPut()
+     await test_instance.setup()
+
+     try:
+         client = await test_instance.create_client("ConcurrentTestClient")
+
+         async def put_operation(partition_id, data):
+             await client.async_put(data=data, partition_id=partition_id)
+             return partition_id
+
+         # Create multiple put operations
+         tasks = []
+         for i in range(5):
+             data = TensorDict({f"concurrent_data_{i}": torch.randn(4, 8)}, batch_size=[4])
+             task = put_operation(f"concurrent_partition_{i}", data)
+             tasks.append(task)
+
+         # Execute concurrently
+         results = await asyncio.gather(*tasks)
+
+         assert len(results) == 5
+         logger.info(f"Completed {len(results)} concurrent put operations")
+
+         # Cleanup
+         for partition in results:
+             try:
+                 await client.async_clear(partition)
+             except Exception as e:
+                 logger.warning(f"Failed to clear partition {partition}: {e}")
+
+         client.close()
+     finally:
+         await test_instance.teardown()
+
+
+ @pytest.mark.asyncio
+ async def test_sequential_async_put_with_verification():
+     """Test sequential async_put operations with data verification."""
+     test_instance = TestMultipleAsyncPut()
+     await test_instance.setup()
+
+     try:
+         client = await test_instance.create_client("SequentialTestClient")
+
+         # Test data
+         test_cases = [
+             ("sequential_1", torch.randn(4, 5)),
+             ("sequential_2", torch.randn(4, 10)),
+             ("sequential_3", torch.randn(4, 15)),
+         ]
+
+         for partition_id, tensor_data in test_cases:
+             data = TensorDict({"sequential_data": tensor_data}, batch_size=[4])
+
+             # Put data
+             await client.async_put(data=data, partition_id=partition_id)
+             logger.info(f"Put data to {partition_id}")
+
+             # Verify by reading back
+             metadata = await client.async_get_meta(
+                 data_fields=["sequential_data"],
+                 batch_size=4,
+                 partition_id=partition_id,
+                 mode="fetch",
+                 task_name="verification_task",
+             )
+
+             retrieved_data = await client.async_get_data(metadata)
+
+             # Verify shape of the retrieved data
+             assert retrieved_data["sequential_data"].shape == tensor_data.shape
+             logger.info(f"Verified data in {partition_id}")
+
+             # Cleanup
+             await client.async_clear(partition_id)
+
+         client.close()
+     finally:
+         await test_instance.teardown()
+
+
+ @pytest.mark.asyncio
+ async def test_multiple_puts_same_partition():
+     """Test multiple puts to the same partition."""
+     test_instance = TestMultipleAsyncPut()
+     await test_instance.setup()
+
+     try:
+         client = await test_instance.create_client("SamePartitionTestClient")
+
+         partition_id = "same_partition_test"
+         batch_size = 4
+
+         # First put
+         data1 = TensorDict({"first_batch": torch.randn(batch_size, 8)}, batch_size=[batch_size])
+
+         await client.async_put(data=data1, partition_id=partition_id)
+         logger.info("First put completed")
+
+         # Second put to same partition
+         data2 = TensorDict({"second_batch": torch.randn(batch_size, 8) * 2}, batch_size=[batch_size])
+
+         await client.async_put(data=data2, partition_id=partition_id)
+         logger.info("Second put completed")
+
+         # Third put to same partition
+         data3 = TensorDict({"third_batch": torch.randn(batch_size, 8) * 3}, batch_size=[batch_size])
+
+         await client.async_put(data=data3, partition_id=partition_id)
+         logger.info("Third put completed")
+
+         # Verify the last data is what we get
+         metadata = await client.async_get_meta(
+             data_fields=["third_batch"],
+             batch_size=batch_size,
+             partition_id=partition_id,
+             mode="fetch",
+             task_name="verification_task",
+         )
+
+         retrieved_data = await client.async_get_data(metadata)
+
+         assert "third_batch" in retrieved_data
+         assert retrieved_data["third_batch"].shape == (batch_size, 8)
+         logger.info("Verified multiple puts to same partition")
+
+         # Cleanup
+         await client.async_clear(partition_id)
+
+         client.close()
+     finally:
+         await test_instance.teardown()
+
+
+ @pytest.mark.asyncio
+ async def test_simple_multiple_async_put():
+     """Simple test of multiple async_put operations across partitions."""
+     test_instance = TestMultipleAsyncPut()
+     await test_instance.setup()
+
+     try:
+         client = await test_instance.create_client("SimpleTestClient")
+
+         # Test basic multiple puts
+         for i in range(3):
+             data = TensorDict({f"simple_data_{i}": torch.randn(2, 4)}, batch_size=[2])
+
+             partition_id = f"simple_partition_{i}"
+             await client.async_put(data=data, partition_id=partition_id)
+             logger.info(f"Put data to {partition_id}")
+
+             # Cleanup each partition
+             await client.async_clear(partition_id)
+
+         logger.info("Simple multiple async_put test completed")
+
+         client.close()
+     finally:
+         await test_instance.teardown()