TransferQueue 0.1.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
@@ -0,0 +1,443 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import sys
16
+ import time
17
+ from pathlib import Path
18
+
19
+ import pytest
20
+ import ray
21
+ import tensordict
22
+ import torch
23
+ import zmq
24
+ from tensordict import TensorDict
25
+
26
# Make the project root (two directory levels above this file) importable.
# It is appended, so an already-installed ``transfer_queue`` keeps precedence.
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(f"{parent_dir}")
29
+
30
+ from transfer_queue import SimpleStorageUnit # noqa: E402
31
+ from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType # noqa: E402
32
+
33
+
34
class MockStorageClient:
    """Mock client for testing storage unit operations.

    Wraps a single ZMQ DEALER socket connected to the storage unit's put/get
    address. Each ``send_*`` call serializes one ``ZMQMessage``, sends it, and
    blocks for the reply, which is returned deserialized. Receives time out
    after 5 seconds (pyzmq raises ``zmq.Again`` in that case).

    The client can also be used as a context manager so the socket and context
    are released even when a test assertion fails.
    """

    def __init__(self, storage_put_get_address):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
        # Discard any undelivered messages on close. With ZMQ's default
        # (infinite) linger, ``context.term()`` in ``close()`` would block
        # forever if a request could not be delivered to the storage unit.
        self.socket.setsockopt(zmq.LINGER, 0)
        self.socket.connect(storage_put_get_address)

    def _request(self, request_type, client_id, body):
        """Send one request of ``request_type`` and return the deserialized reply."""
        msg = ZMQMessage.create(
            request_type=request_type,
            sender_id=f"mock_client_{client_id}",
            body=body,
        )
        self.socket.send(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv())

    def send_put(self, client_id, local_indexes, field_data):
        """Store ``field_data`` at ``local_indexes``; returns the PUT reply message."""
        return self._request(
            ZMQRequestType.PUT_DATA,
            client_id,
            {"local_indexes": local_indexes, "data": field_data},
        )

    def send_get(self, client_id, local_indexes, fields):
        """Fetch ``fields`` at ``local_indexes``; returns the GET reply message."""
        return self._request(
            ZMQRequestType.GET_DATA,
            client_id,
            {"local_indexes": local_indexes, "fields": fields},
        )

    def send_clear(self, client_id, local_indexes):
        """Clear the data stored at ``local_indexes``; returns the CLEAR reply message."""
        return self._request(
            ZMQRequestType.CLEAR_DATA,
            client_id,
            {"local_indexes": local_indexes},
        )

    def close(self):
        """Close the socket and terminate this client's ZMQ context."""
        self.socket.close()
        self.context.term()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
73
+
74
+
75
@pytest.fixture(scope="session")
def ray_setup():
    """Start Ray once for the whole test session and shut it down at the end.

    ``ignore_reinit_error=True`` makes this a no-op if Ray is already running.
    """
    ray.init(ignore_reinit_error=True)
    try:
        yield
    finally:
        ray.shutdown()
81
+
82
+
83
@pytest.fixture
def storage_setup(ray_setup):
    """Start a ``SimpleStorageUnit`` Ray actor and yield ``(actor, put_get_address)``.

    The actor is killed on teardown. It is also killed if fetching its ZMQ
    server info fails, so a broken setup does not leak a running actor.
    """
    storage_size = 10000
    # NOTE(review): this enables tensordict's list-to-stack mode via ``.set()``
    # and never restores the previous mode; it appears to be process-wide
    # state that outlives this fixture — confirm that is intended.
    tensordict.set_list_to_stack(True).set()

    # Start Ray actor for SimpleStorageUnit
    storage_actor = SimpleStorageUnit.options(max_concurrency=50, num_cpus=1).remote(storage_unit_size=storage_size)
    try:
        # Get ZMQ server info from storage unit
        zmq_info = ray.get(storage_actor.get_zmq_server_info.remote())
        put_get_address = zmq_info.to_addr("put_get_socket")
        time.sleep(1)  # Wait for socket to be ready

        yield storage_actor, put_get_address
    finally:
        # Cleanup runs on normal teardown and on setup failure alike.
        ray.kill(storage_actor)
101
+
102
+
103
def test_put_get_single_client(storage_setup):
    """Write three samples with one client, then read back the first two and compare."""
    _, address = storage_setup
    client = MockStorageClient(address)

    log_probs = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])]
    rewards = [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])]
    payload = TensorDict({"log_probs": log_probs, "rewards": rewards}, batch_size=[])

    put_reply = client.send_put(0, [0, 1, 2], payload)
    assert put_reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    get_reply = client.send_get(0, [0, 1], ["log_probs", "rewards"])
    assert get_reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    data = get_reply.body["data"]
    field_names = ("log_probs", "rewards")
    for name in field_names:
        assert name in data
    for name in field_names:
        assert data[name].size(0) == 2

    # Positions 0 and 1 of the reply correspond to the first two samples written.
    originals = {"log_probs": log_probs, "rewards": rewards}
    for name in field_names:
        for pos in (0, 1):
            torch.testing.assert_close(data[name][pos], originals[name][pos])

    client.close()
139
+
140
+
141
def test_put_get_multiple_clients(storage_setup):
    """Three clients write disjoint index ranges, then a fourth overwrites index 0.

    Each original client reads back its first two indexes: client 0 must see
    the overwrite at index 0, every other slot must hold what was written.
    """
    _, address = storage_setup
    clients = [MockStorageClient(address) for _ in range(3)]

    # Client ``cid`` owns indexes cid*10 .. cid*10 + 2.
    for cid, client in enumerate(clients):
        base = cid * 10
        payload = TensorDict(
            {
                "log_probs": [torch.tensor([cid + k, cid + k + 1, cid + k + 2]) for k in (0, 3, 6)],
                "rewards": [torch.tensor([base + step]) for step in (0, 10, 20)],
            }
        )
        reply = client.send_put(cid, [base, base + 1, base + 2], payload)
        assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # A separate client overwrites index 0 (owned by client 0).
    intruder = MockStorageClient(address)
    intruder_payload = TensorDict({"log_probs": [torch.tensor([999, 999, 999])], "rewards": [torch.tensor([999])]})
    reply = intruder.send_put(99, [0], intruder_payload)
    assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    for cid, client in enumerate(clients):
        base = cid * 10
        reply = client.send_get(cid, [base, base + 1], ["log_probs", "rewards"])
        assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

        data = reply.body["data"]
        assert data["log_probs"].size(0) == 2
        assert data["rewards"].size(0) == 2

        second_log_probs = torch.tensor([cid + 3, cid + 4, cid + 5])
        second_reward = torch.tensor([base + 10])
        if cid == 0:
            # Slot 0 was overwritten by the intruder; slot 1 is untouched.
            checks = [
                ("log_probs", 0, torch.tensor([999, 999, 999])),
                ("rewards", 0, torch.tensor([999])),
                ("log_probs", 1, second_log_probs),
                ("rewards", 1, second_reward),
            ]
        else:
            # Nothing overwrote this client's slots.
            checks = [
                ("log_probs", 0, torch.tensor([cid, cid + 1, cid + 2])),
                ("log_probs", 1, second_log_probs),
                ("rewards", 0, torch.tensor([base])),
                ("rewards", 1, second_reward),
            ]
        for field, pos, expected in checks:
            torch.testing.assert_close(data[field][pos], expected)

    for client in clients:
        client.close()
    intruder.close()
200
+
201
+
202
def test_performance_basic(storage_setup):
    """Basic performance test with larger data volume.

    Writes ``num_puts`` batches of ``batch_size`` samples, each sample holding
    two 100-element ``torch.randn`` tensors. It then reads the same batches
    back and asserts that the mean PUT and GET request latencies stay under a
    lenient threshold.
    """
    _, put_get_address = storage_setup

    num_puts = 10  # Reduced for faster testing
    batch_size = 16  # Reduced for faster testing
    # GETs read back exactly the batches written by the PUT phase, so no GET
    # targets an index range that was never written.
    num_gets = num_puts
    latency_threshold_ms = 1500  # Lenient threshold for testing environments

    client = MockStorageClient(put_get_address)
    try:
        # PUT performance test
        put_latencies = []
        for i in range(num_puts):
            # Use batch size and index mapping
            local_indexes = list(range(i * batch_size, (i + 1) * batch_size))

            log_probs_data = []
            rewards_data = []
            for _ in range(batch_size):
                log_probs_data.append(torch.randn(100))
                rewards_data.append(torch.randn(100))
            field_data = TensorDict({"log_probs": log_probs_data, "rewards": rewards_data}, batch_size=[batch_size])

            # Time only the request round-trip, not local tensor generation.
            # perf_counter() is monotonic and high-resolution, unlike time.time().
            start = time.perf_counter()
            response = client.send_put(0, local_indexes, field_data)
            put_latencies.append(time.perf_counter() - start)
            assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

        # GET performance test
        get_latencies = []
        for i in range(num_gets):
            local_indexes = list(range(i * batch_size, (i + 1) * batch_size))
            start = time.perf_counter()
            response = client.send_get(0, local_indexes, ["log_probs", "rewards"])
            get_latencies.append(time.perf_counter() - start)
            assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        avg_put_latency = sum(put_latencies) / len(put_latencies) * 1000  # ms
        avg_get_latency = sum(get_latencies) / len(get_latencies) * 1000  # ms

        assert avg_put_latency < latency_threshold_ms, f"Avg PUT latency {avg_put_latency}ms exceeds threshold"
        assert avg_get_latency < latency_threshold_ms, f"Avg GET latency {avg_get_latency}ms exceeds threshold"
    finally:
        # Release the ZMQ socket/context even if an assertion above fails.
        client.close()
258
+
259
+
260
def test_put_get_nested_tensor(storage_setup):
    """Store variable-length per-sample tensors and read back samples 0 and 2."""
    _, address = storage_setup
    client = MockStorageClient(address)

    seqs = [
        torch.tensor([-0.5, -1.2, -0.8]),
        torch.tensor([-0.3, -1.5, -2.1, -0.9]),
        torch.tensor([-1.1, -0.7]),
    ]
    masks = [torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1, 1]), torch.tensor([1, 1])]
    payload = TensorDict({"variable_length_sequences": seqs, "attention_mask": masks}, batch_size=[])

    reply = client.send_put(0, [0, 1, 2], payload)
    assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Request a non-contiguous subset; reply positions follow this order.
    wanted = [0, 2]
    reply = client.send_get(0, wanted, ["variable_length_sequences", "attention_mask"])
    assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    data = reply.body["data"]
    sources = {"variable_length_sequences": seqs, "attention_mask": masks}
    for name in sources:
        assert name in data
    for name in sources:
        assert data[name].size(0) == 2

    for name, source in sources.items():
        for slot, idx in enumerate(wanted):
            torch.testing.assert_close(data[name][slot], source[idx])

    client.close()
300
+
301
+
302
def test_put_get_non_tensor_data(storage_setup):
    """Round-trip plain Python strings (non-tensor fields) through the storage unit."""
    _, address = storage_setup
    client = MockStorageClient(address)

    prompts = ["Hello world!", "This is a longer sentence for testing", "Test case"]
    responses = ["Hi there!", "This is the response to the longer sentence", "Test response"]
    payload = TensorDict({"prompt_text": prompts, "response_text": responses}, batch_size=[])

    reply = client.send_put(0, [0, 1, 2], payload)
    assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    reply = client.send_get(0, [0, 1, 2], ["prompt_text", "response_text"])
    assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    data = reply.body["data"]
    assert "prompt_text" in data
    assert "response_text" in data

    # Strings must come back as ``str``, not wrapped in another container type.
    assert isinstance(data["prompt_text"][0], str)
    assert isinstance(data["response_text"][0], str)

    for key, expected in (("prompt_text", prompts), ("response_text", responses)):
        for idx, text in enumerate(expected):
            assert data[key][idx] == text

    client.close()
341
+
342
+
343
def test_put_get_single_item(storage_setup):
    """Round-trip a batch holding exactly one sample with a string and a tensor field."""
    _, address = storage_setup
    client = MockStorageClient(address)

    fields = ["prompt_text", "attention_mask"]
    payload = TensorDict(
        {"prompt_text": ["Hello world!"], "attention_mask": [torch.tensor([1, 1, 1])]},
        batch_size=[],
    )

    reply = client.send_put(0, [0], payload)
    assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    reply = client.send_get(0, [0], fields)
    assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    data = reply.body["data"]
    for name in fields:
        assert name in data

    # A single sample must still come back with a leading batch dimension of 1.
    assert data["prompt_text"][0] == "Hello world!"
    assert data["attention_mask"].shape == (1, 3)
    torch.testing.assert_close(data["attention_mask"][0], torch.tensor([1, 1, 1]))

    client.close()
374
+
375
+
376
def test_clear_data(storage_setup):
    """Clear indexes 0 and 2 and check that index 1 keeps its value."""
    _, address = storage_setup
    client = MockStorageClient(address)

    payload = TensorDict(
        {
            "log_probs": [torch.tensor([v]) for v in (1.0, 2.0, 3.0)],
            "rewards": [torch.tensor([v]) for v in (10.0, 20.0, 30.0)],
        },
        batch_size=[],
    )
    reply = client.send_put(0, [0, 1, 2], payload)
    assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Sanity check: all three samples are readable before clearing.
    reply = client.send_get(0, [0, 1, 2], ["log_probs"])
    assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE
    assert reply.body["data"]["log_probs"].size(0) == 3

    reply = client.send_clear(0, [0, 2])
    assert reply.request_type == ZMQRequestType.CLEAR_DATA_RESPONSE

    # NOTE(review): indexes 0 and 2 are never read back, so this test does not
    # verify that clearing removed anything — only that index 1 survived.
    reply = client.send_get(0, [1], ["log_probs"])
    assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE
    survivors = reply.body["data"]["log_probs"]
    assert survivors.size(0) == 1
    torch.testing.assert_close(survivors[0], torch.tensor([2.0]))

    client.close()
411
+
412
+
413
def test_storage_unit_data_direct():
    """Exercise ``StorageUnitData`` put/get/clear in-process, without ZMQ or Ray."""
    from transfer_queue.storage import StorageUnitData

    store = StorageUnitData(storage_size=10)

    store.put_data(
        TensorDict(
            {
                "log_probs": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])],
                "rewards": [torch.tensor([10.0]), torch.tensor([20.0])],
            },
            batch_size=[],
        ),
        [0, 1],
    )

    both = store.get_data(["log_probs", "rewards"], [0, 1])
    for name in ("log_probs", "rewards"):
        assert name in both
    for name in ("log_probs", "rewards"):
        assert both[name].size(0) == 2

    # Fetching a single index keeps a leading batch dimension of 1.
    assert store.get_data(["log_probs"], [0])["log_probs"].shape == (1, 2)

    # After clearing, the slot reads back as None.
    store.clear([0])
    assert store.get_data(["log_probs"], [0])["log_probs"][0] is None
@@ -0,0 +1,45 @@
1
+ import unittest
2
+ from importlib.util import find_spec
3
+
4
+ import pytest
5
+ import torch
6
+
7
+ from transfer_queue.storage.clients.factory import StorageClientFactory
8
+ from transfer_queue.storage.clients.yuanrong_client import YRStorageClient
9
+
10
+
11
class Test(unittest.TestCase):
    """Tests for ``StorageClientFactory`` and the Yuanrong ``YRStorageClient``."""

    def setUp(self):
        # Settings passed to the factory; presumably a local datasystem service
        # is expected at this host/port — TODO confirm for CI.
        self.cfg = {"host": "127.0.0.1", "port": 31501, "device_id": 0}

    # ``unittest.skipIf`` (not ``pytest.mark.skipif``) so the skip is honoured
    # both under pytest and when this file is run via ``unittest.main()``.
    @unittest.skipIf(find_spec("datasystem") is None, "datasystem is not available")
    def test_create_client(self):
        """"Yuanrong" maps to ``YRStorageClient``; unknown names raise ``ValueError``."""
        self.assertIn("Yuanrong", StorageClientFactory._registry)
        self.assertIs(StorageClientFactory._registry["Yuanrong"], YRStorageClient)
        StorageClientFactory.create("Yuanrong", self.cfg)

        with self.assertRaises(ValueError) as cm:
            StorageClientFactory.create("abc", self.cfg)
        self.assertIn("Unknown StorageClient", str(cm.exception))

    @unittest.skipIf(
        find_spec("torch_npu") is None or find_spec("datasystem") is None,
        "torch_npu or datasystem is not available",
    )
    def test_client_create_empty_tensorlist(self):
        """``_create_empty_tensorlist`` returns one tensor per (shape, dtype) pair."""
        tensors = [torch.Tensor([2, 1]), torch.Tensor([1, 5]), torch.Tensor([0]), torch.Tensor([-1.5])]
        shapes = [t.shape for t in tensors]
        dtypes = [t.dtype for t in tensors]
        client = StorageClientFactory.create("Yuanrong", self.cfg)

        empty_tensors = client._create_empty_tensorlist(shapes, dtypes)
        self.assertEqual(len(tensors), len(empty_tensors))
        # strict=True turns any length mismatch into an error instead of silent truncation.
        for t, et in zip(tensors, empty_tensors, strict=True):
            self.assertEqual(t.shape, et.shape)
            self.assertEqual(t.dtype, et.dtype)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,48 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ from .client import (
18
+ AsyncTransferQueueClient,
19
+ TransferQueueClient,
20
+ process_zmq_server_info,
21
+ )
22
+ from .controller import TransferQueueController
23
+ from .metadata import BatchMeta
24
+ from .sampler import BaseSampler
25
+ from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
26
+ from .sampler.sequential_sampler import SequentialSampler
27
+ from .storage import SimpleStorageUnit
28
+ from .utils.utils import get_placement_group
29
+ from .utils.zmq_utils import ZMQServerInfo
30
+
31
__all__ = [
    "AsyncTransferQueueClient",
    "BatchMeta",
    "TransferQueueClient",
    "TransferQueueController",
    "SimpleStorageUnit",
    "ZMQServerInfo",
    "process_zmq_server_info",
    "get_placement_group",
    "BaseSampler",
    "GRPOGroupNSampler",
    "SequentialSampler",
]

# Directory containing this package's ``__init__.py``; the version string is
# read from the ``version/version`` file inside it. (The previous
# ``os.path.join`` around a single argument was a no-op and has been dropped.)
version_folder = os.path.dirname(os.path.abspath(__file__))

# Read with an explicit encoding so the result does not depend on the locale.
with open(os.path.join(version_folder, "version", "version"), encoding="utf-8") as f:
    __version__ = f.read().strip()