transferqueue-0.1.1.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
tests/test_client.py ADDED
@@ -0,0 +1,423 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import sys
+ import time
+ from pathlib import Path
+ from threading import Thread
+ from unittest.mock import patch
+
+ import pytest
+ import torch
+ import zmq
+ from tensordict import NonTensorStack, TensorDict
+
+ # Make the package under test importable when running from a source checkout
+ parent_dir = Path(__file__).resolve().parent.parent
+ sys.path.append(str(parent_dir))
+
+ from transfer_queue import TransferQueueClient  # noqa: E402
+ from transfer_queue.metadata import (  # noqa: E402
+     BatchMeta,
+     FieldMeta,
+     SampleMeta,
+ )
+ from transfer_queue.utils.utils import TransferQueueRole  # noqa: E402
+ from transfer_queue.utils.zmq_utils import (  # noqa: E402
+     ZMQMessage,
+     ZMQRequestType,
+     ZMQServerInfo,
+ )
+
+ TEST_DATA = TensorDict(
+     {
+         "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])],
+         "variable_length_sequences": torch.nested.as_nested_tensor(
+             [
+                 torch.tensor([-0.5, -1.2, -0.8]),
+                 torch.tensor([-0.3, -1.5, -2.1, -0.9]),
+                 torch.tensor([-1.1, -0.7]),
+             ]
+         ),
+         "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"],
+     },
+     batch_size=[3],
+ )
+
+
+ # Mock Controller for Client Unit Testing
+ class MockController:
+     def __init__(self, controller_id="controller_0"):
+         self.controller_id = controller_id
+         self.context = zmq.Context()
+
+         # Socket for data requests
+         self.request_socket = self.context.socket(zmq.ROUTER)
+         self.request_port = self._bind_to_random_port(self.request_socket)
+
+         self.zmq_server_info = ZMQServerInfo(
+             role=TransferQueueRole.CONTROLLER,
+             id=controller_id,
+             ip="127.0.0.1",
+             ports={
+                 "request_handle_socket": self.request_port,
+             },
+         )
+
+         self.running = True
+         self.request_thread = Thread(target=self._handle_requests, daemon=True)
+         self.request_thread.start()
+
+     def _bind_to_random_port(self, socket):
+         port = socket.bind_to_random_port("tcp://127.0.0.1")
+         return port
+
+     def _handle_requests(self):
+         poller = zmq.Poller()
+         poller.register(self.request_socket, zmq.POLLIN)
+
+         while self.running:
+             try:
+                 socks = dict(poller.poll(100))  # 100ms timeout
+                 if self.request_socket in socks:
+                     identity, serialized_msg = self.request_socket.recv_multipart()
+                     request_msg = ZMQMessage.deserialize(serialized_msg)
+
+                     # Determine response based on request type
+                     if request_msg.request_type == ZMQRequestType.GET_META:
+                         response_body = self._mock_batch_meta(request_msg.body)
+                         response_type = ZMQRequestType.GET_META_RESPONSE
+                     elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
+                         response_body = self._mock_batch_meta(request_msg.body)
+                         response_type = ZMQRequestType.GET_CLEAR_META_RESPONSE
+                     elif request_msg.request_type == ZMQRequestType.CLEAR_META:
+                         response_body = {"message": "clear ok"}
+                         response_type = ZMQRequestType.CLEAR_META_RESPONSE
+
+                     # Send response
+                     response_msg = ZMQMessage.create(
+                         request_type=response_type,
+                         sender_id=self.controller_id,
+                         receiver_id=request_msg.sender_id,
+                         body=response_body,
+                     )
+                     self.request_socket.send_multipart([identity, response_msg.serialize()])
+             except zmq.Again:
+                 continue
+             except Exception as e:
+                 print(f"MockController ERROR: {e}")
+                 raise
+
+     def _mock_batch_meta(self, request_body):
+         batch_size = request_body.get("batch_size", 1)
+         data_fields = request_body.get("data_fields", [])
+
+         samples = []
+         for i in range(batch_size):
+             fields = []
+             for field_name in data_fields:
+                 field_meta = FieldMeta(
+                     name=field_name,
+                     dtype=None,
+                     shape=None,
+                     production_status=0,
+                 )
+                 fields.append(field_meta)
+             sample = SampleMeta(
+                 partition_id="0",
+                 global_index=i,
+                 fields={field.name: field for field in fields},
+             )
+             samples.append(sample)
+         metadata = BatchMeta(samples=samples)
+
+         return {"metadata": metadata}
+
+     def stop(self):
+         self.running = False
+         time.sleep(0.2)  # Give thread time to stop
+         self.request_socket.close()
+         self.context.term()
+
+
+ # Mock Storage for Client Unit Testing
+ class MockStorage:
+     def __init__(self, storage_id="storage_0"):
+         self.storage_id = storage_id
+         self.context = zmq.Context()
+
+         # Socket for data operations
+         self.data_socket = self.context.socket(zmq.ROUTER)
+         self.data_port = self._bind_to_random_port(self.data_socket)
+
+         self.zmq_server_info = ZMQServerInfo(
+             role=TransferQueueRole.STORAGE,
+             id=storage_id,
+             ip="127.0.0.1",
+             ports={
+                 "put_get_socket": self.data_port,
+             },
+         )
+
+         self.running = True
+         self.data_thread = Thread(target=self._handle_data_requests, daemon=True)
+         self.data_thread.start()
+
+     def _bind_to_random_port(self, socket):
+         port = socket.bind_to_random_port("tcp://127.0.0.1")
+         return port
+
+     def _handle_data_requests(self):
+         poller = zmq.Poller()
+         poller.register(self.data_socket, zmq.POLLIN)
+
+         while self.running:
+             try:
+                 socks = dict(poller.poll(100))  # 100ms timeout
+                 if self.data_socket in socks:
+                     identity, msg_bytes = self.data_socket.recv_multipart()
+                     msg = ZMQMessage.deserialize(msg_bytes)
+
+                     # Handle different request types
+                     if msg.request_type == ZMQRequestType.PUT_DATA:
+                         response_body = {"message": "Data stored successfully"}
+                         response_type = ZMQRequestType.PUT_DATA_RESPONSE
+                     elif msg.request_type == ZMQRequestType.GET_DATA:
+                         response_body = self._handle_get_data(msg.body)
+                         response_type = ZMQRequestType.GET_DATA_RESPONSE
+                     elif msg.request_type == ZMQRequestType.CLEAR_DATA:
+                         response_body = {"message": "Data cleared successfully"}
+                         response_type = ZMQRequestType.CLEAR_DATA_RESPONSE
+
+                     # Send response
+                     response_msg = ZMQMessage.create(
+                         request_type=response_type,
+                         sender_id=self.storage_id,
+                         receiver_id=msg.sender_id,
+                         body=response_body,
+                     )
+                     self.data_socket.send_multipart([identity, response_msg.serialize()])
+             except zmq.Again:
+                 continue
+             except Exception as e:
+                 if self.running:
+                     print(f"MockStorage running exception: {e}")
+                 else:
+                     print(f"MockStorage ERROR: {e}")
+                     raise
+
+     def _handle_get_data(self, request_body):
+         """Handle GET_DATA request by retrieving stored data"""
+         local_indexes = request_body.get("local_indexes", [])
+         fields = request_body.get("fields", [])
+
+         result: dict[str, list] = {}
+         for field in fields:
+             gathered_items = [TEST_DATA[field][i] for i in local_indexes]
+
+             if gathered_items:
+                 all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items)
+                 if all_tensors:
+                     result[field] = torch.nested.as_nested_tensor(gathered_items)
+                 else:
+                     result[field] = NonTensorStack(*gathered_items)
+
+         return {"data": TensorDict(result)}
+
+     def stop(self):
+         self.running = False
+         time.sleep(0.2)  # Give thread time to stop
+         self.data_socket.close()
+         self.context.term()
+
+
+ # Test Fixtures
+ @pytest.fixture
+ def mock_controller():
+     controller = MockController()
+     yield controller
+     controller.stop()
+
+
+ @pytest.fixture
+ def mock_storage():
+     storage = MockStorage()
+     yield storage
+     storage.stop()
+
+
+ @pytest.fixture
+ def client_setup(mock_controller, mock_storage):
+     # Create client with mock controller and storage
+     client_id = "client_0"
+
+     client = TransferQueueClient(
+         client_id=client_id,
+         controller_info=mock_controller.zmq_server_info,
+     )
+
+     # Mock the storage manager to avoid handshake issues but mock all data operations
+     with patch(
+         "transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller"
+     ):
+         config = {
+             "controller_info": mock_controller.zmq_server_info,
+             "storage_unit_infos": {mock_storage.storage_id: mock_storage.zmq_server_info},
+         }
+         client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
+
+         # Mock all storage manager methods to avoid real ZMQ operations
+         async def mock_put_data(data, metadata):
+             pass  # Just pretend to store the data
+
+         async def mock_get_data(metadata):
+             # Return the test data when requested
+             return TEST_DATA
+
+         async def mock_clear_data(metadata):
+             pass  # Just pretend to clear the data
+
+         client.storage_manager.put_data = mock_put_data
+         client.storage_manager.get_data = mock_get_data
+         client.storage_manager.clear_data = mock_clear_data
+
+         yield client, mock_controller, mock_storage
+
+
+ # Test basic functionality
+ def test_client_initialization(client_setup):
+     """Test client initialization and connection setup"""
+     client, mock_controller, mock_storage = client_setup
+
+     assert client.client_id is not None
+     assert client._controller is not None
+     assert client._controller.id == mock_controller.controller_id
+
+
+ def test_put_and_get_data(client_setup):
+     """Test basic put and get operations"""
+     client, _, _ = client_setup
+
+     # Test put operation
+     client.put(data=TEST_DATA, partition_id="0")
+
+     # Get metadata for retrieving data
+     metadata = client.get_meta(
+         data_fields=["log_probs", "variable_length_sequences", "prompt_text"], batch_size=2, partition_id="0"
+     )
+
+     # Test get operation
+     result = client.get_data(metadata)
+
+     # Verify result structure
+     assert "log_probs" in result
+     assert "variable_length_sequences" in result
+     assert "prompt_text" in result
+
+     torch.testing.assert_close(result["log_probs"][0], torch.tensor([1.0, 2.0, 3.0]))
+     torch.testing.assert_close(result["log_probs"][1], torch.tensor([4.0, 5.0, 6.0]))
+     torch.testing.assert_close(result["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8]))
+     torch.testing.assert_close(result["variable_length_sequences"][1], torch.tensor([-0.3, -1.5, -2.1, -0.9]))
+     assert result["prompt_text"][0] == "Hello world!"
+     assert result["prompt_text"][1] == "This is a longer sentence for testing"
+
+
+ def test_get_meta(client_setup):
+     """Test metadata retrieval"""
+     client, _, _ = client_setup
+
+     # Test get_meta operation
+     metadata = client.get_meta(data_fields=["tokens", "labels"], batch_size=10, partition_id="0")
+
+     # Verify metadata structure
+     assert hasattr(metadata, "global_indexes")
+     assert hasattr(metadata, "field_names")
+     assert hasattr(metadata, "size")
+     assert len(metadata.global_indexes) == 10
+
+
+ def test_clear_operation(client_setup):
+     """Test clear operation"""
+     client, _, _ = client_setup
+
+     # Test clear operation
+     client.clear(partition_id="0")
+
+
+ # Test with single controller and multiple storage units
+ def test_single_controller_multiple_storages():
+     """Test client with single controller and multiple storage units"""
+     # Create single controller and multiple storage units
+     controller = MockController("controller_0")
+     storages = [MockStorage(f"storage_{i}") for i in range(3)]
+
+     try:
+         # Create client with single controller
+         client_id = "client_test_single_controller"
+
+         client = TransferQueueClient(client_id=client_id, controller_info=controller.zmq_server_info)
+
+         # Mock the storage manager to avoid handshake issues but mock all data operations
+         with patch(
+             "transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller"
+         ):
+             config = {
+                 "controller_info": controller.zmq_server_info,
+                 "storage_unit_infos": {s.storage_id: s.zmq_server_info for s in storages},
+             }
+             client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
+
+             # Mock all storage manager methods to avoid real ZMQ operations
+             async def mock_put_data(data, metadata):
+                 pass  # Just pretend to store the data
+
+             async def mock_get_data(metadata):
+                 # Return some test data when requested
+                 return TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)
+
+             async def mock_clear_data(metadata):
+                 pass  # Just pretend to clear the data
+
+             client.storage_manager.put_data = mock_put_data
+             client.storage_manager.get_data = mock_get_data
+             client.storage_manager.clear_data = mock_clear_data
+
+             # Verify controller is set
+             assert client._controller is not None
+             assert client._controller.id == controller.controller_id
+
+             # Test basic operation
+             test_data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)
+
+             # Test put operation
+             client.put(data=test_data, partition_id="0")
+
+     finally:
+         # Clean up
+         controller.stop()
+         for s in storages:
+             s.stop()
+
+
+ # Test error handling
+ def test_put_without_required_params(client_setup):
+     """Test put operation without required parameters"""
+     client, _, _ = client_setup
+
+     # Create test data
+     test_data = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)
+
+     # Test put without partition id (should fail)
+     with pytest.raises(ValueError):
+         client.put(data=test_data)
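The fixture and tests above exercise the full client round trip: put writes a TensorDict into a partition, get_meta asks the controller for batch metadata, get_data resolves that metadata into data, and clear releases the partition. As a reading aid, here is a minimal sketch of that flow distilled from client_setup and test_put_and_get_data; the mock classes are test doubles defined in this file, not part of the package, and the real fixture additionally patches _connect_to_controller during initialization:

    controller = MockController("controller_0")
    storage = MockStorage("storage_0")

    client = TransferQueueClient(client_id="client_0", controller_info=controller.zmq_server_info)
    client.initialize_storage_manager(
        manager_type="AsyncSimpleStorageManager",
        config={
            "controller_info": controller.zmq_server_info,
            "storage_unit_infos": {storage.storage_id: storage.zmq_server_info},
        },
    )

    client.put(data=TEST_DATA, partition_id="0")  # produce a batch
    meta = client.get_meta(  # ask the controller which samples/fields to read
        data_fields=["log_probs", "prompt_text"], batch_size=2, partition_id="0"
    )
    batch = client.get_data(meta)  # fetch the actual data
    client.clear(partition_id="0")  # release the partition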
tests/test_controller.py ADDED
@@ -0,0 +1,274 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import sys
+ from pathlib import Path
+
+ import pytest
+ import ray
+ import torch
+
+ parent_dir = Path(__file__).resolve().parent.parent
+ sys.path.append(str(parent_dir))
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ from transfer_queue import TransferQueueController  # noqa: E402
+ from transfer_queue.controller import TQ_INIT_FIELD_NUM  # noqa: E402
+ from transfer_queue.utils.utils import ProductionStatus  # noqa: E402
+
+
+ @pytest.fixture(scope="function")
+ def ray_setup():
+     if ray.is_initialized():
+         ray.shutdown()
+     ray.init(
+         ignore_reinit_error=True,
+         runtime_env={"env_vars": {"RAY_DEBUG": "1", "RAY_DEDUP_LOGS": "0"}},
+         log_to_driver=True,
+     )
+     yield
+     if ray.is_initialized():
+         ray.shutdown()
+     logger.info("Ray has been shut down completely after test")
+
+
+ class TestTransferQueueController:
+     def test_controller_with_single_partition(self, ray_setup):
+         gbs = 8
+         num_n_samples = 4
+
+         tq_controller = TransferQueueController.remote()
+
+         # Test get metadata in insert mode
+         partition_id = "train_0"
+         data_fields = ["prompt_ids", "attention_mask"]
+         metadata = ray.get(
+             tq_controller.get_metadata.remote(
+                 data_fields=data_fields,
+                 batch_size=gbs * num_n_samples,
+                 partition_id=partition_id,
+                 mode="insert",
+             )
+         )
+
+         assert metadata.global_indexes == list(range(gbs * num_n_samples))
+         assert metadata.samples[0].partition_id == "train_0"
+         assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in metadata.samples]) == int(
+             ProductionStatus.NOT_PRODUCED
+         )
+         assert sum([int(sample.fields.get("attention_mask").production_status) for sample in metadata.samples]) == int(
+             ProductionStatus.NOT_PRODUCED
+         )
+         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
+         assert partition_index_range == set(range(gbs * num_n_samples))
+
+         print("✓ Initial get metadata correct")
+
+         # Test update production status
+         success = ray.get(
+             tq_controller.update_production_status.remote(
+                 partition_id=partition_id,
+                 global_indexes=metadata.global_indexes,
+                 field_names=metadata.field_names,
+             )
+         )
+         assert success
+         partition = ray.get(tq_controller.get_partition.remote(partition_id))
+         assert partition.production_status is not None
+         assert partition.production_status.size(0) == gbs * num_n_samples
+         assert partition.production_status.size(1) == TQ_INIT_FIELD_NUM
+         assert torch.equal(
+             sum(partition.production_status[:, : len(data_fields)]),
+             torch.Tensor([gbs * num_n_samples, gbs * num_n_samples]),
+         )
+         assert torch.equal(
+             sum(partition.production_status[:, len(data_fields) :]),
+             torch.zeros(1 * (TQ_INIT_FIELD_NUM - len(data_fields))),
+         )
+
+         print(f"✓ Updated production status for partition {partition_id}")
+
+         # Test get metadata in fetch mode
+         gen_meta = ray.get(
+             tq_controller.get_metadata.remote(
+                 data_fields=["prompt_ids"],
+                 batch_size=gbs * num_n_samples,
+                 partition_id=partition_id,
+                 mode="fetch",
+                 task_name="generate_sequences",
+             )
+         )
+         assert gen_meta.global_indexes == list(range(gbs * num_n_samples))
+         assert gen_meta.samples[0].partition_id == "train_0"
+         assert gen_meta.field_names == ["prompt_ids"]
+         partition = ray.get(tq_controller.get_partition.remote(partition_id))
+         assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(gbs * num_n_samples))
+         print("✓ Get metadata in fetch mode correct")
+
+         # Test get clear meta
+         clear_meta = ray.get(
+             tq_controller.get_metadata.remote(
+                 data_fields=[],
+                 partition_id=partition_id,
+                 mode="insert",
+             )
+         )
+         assert clear_meta.global_indexes == list(range(gbs * num_n_samples))
+         assert [sample.fields for sample in clear_meta.samples] == [{}] * (gbs * num_n_samples)
+         print("✓ Clear metadata correct")
+
+         # Test clear
+         ray.get(tq_controller.clear.remote(partition_id))
+         partition = ray.get(tq_controller.get_partition.remote(partition_id))
+         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
+         assert partition_index_range == set()
+         assert torch.all(partition.production_status == 0)
+         assert torch.all(partition.consumption_status["generate_sequences"] == 0)
+         print("✓ Clear correct")
+
+     def test_controller_with_multi_partitions(self, ray_setup):
+         gbs_1 = 8
+         num_n_samples_1 = 4
+         partition_id_1 = "train_0"
+
+         gbs_2 = 16
+         num_n_samples_2 = 1
+         partition_id_2 = "val_0"
+
+         gbs_3 = 32
+         num_n_samples_3 = 2
+         partition_id_3 = "train_1"
+
+         tq_controller = TransferQueueController.remote()
+
+         # Test get metadata in insert mode
+         data_fields = ["prompt_ids", "attention_mask"]
+         metadata = ray.get(
+             tq_controller.get_metadata.remote(
+                 data_fields=data_fields,
+                 batch_size=gbs_1 * num_n_samples_1,
+                 partition_id=partition_id_1,
+                 mode="insert",
+             )
+         )
+
+         # Test update production status
+         success = ray.get(
+             tq_controller.update_production_status.remote(
+                 partition_id=partition_id_1,
+                 global_indexes=metadata.global_indexes,
+                 field_names=metadata.field_names,
+             )
+         )
+         assert success
+
+         # Test get metadata in fetch mode
+         gen_meta = ray.get(
+             tq_controller.get_metadata.remote(
+                 data_fields=["prompt_ids"],
+                 batch_size=gbs_1 * num_n_samples_1,
+                 partition_id=partition_id_1,
+                 mode="fetch",
+                 task_name="generate_sequences",
+             )
+         )
+         assert gen_meta
+
+         # Test get clear meta
+         clear_meta = ray.get(
+             tq_controller.get_metadata.remote(
+                 data_fields=[],
+                 partition_id=partition_id_1,
+                 mode="insert",
+             )
+         )
+         assert clear_meta
+
+         # =========================partition 2=============================#
+         data_fields = ["prompt_ids", "attention_mask"]
+         val_metadata = ray.get(
+             tq_controller.get_metadata.remote(
+                 data_fields=data_fields,
+                 batch_size=gbs_2 * num_n_samples_2,
+                 partition_id=partition_id_2,
+                 mode="insert",
+             )
+         )
+
+         part1_index_range = gbs_1 * num_n_samples_1
+         part2_index_range = gbs_2 * num_n_samples_2
+         assert val_metadata.global_indexes == list(range(part1_index_range, part2_index_range + part1_index_range))
+         assert val_metadata.samples[0].partition_id == "val_0"
+         assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in val_metadata.samples]) == int(
+             ProductionStatus.NOT_PRODUCED
+         )
+         assert sum(
+             [int(sample.fields.get("attention_mask").production_status) for sample in val_metadata.samples]
+         ) == int(ProductionStatus.NOT_PRODUCED)
+         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
+         assert partition_index_range == set(range(part1_index_range, part2_index_range + part1_index_range))
+
+         # Update production status
+         success = ray.get(
+             tq_controller.update_production_status.remote(
+                 partition_id=partition_id_2,
+                 global_indexes=val_metadata.global_indexes,
+                 field_names=val_metadata.field_names,
+             )
+         )
+         assert success
+
+         # Clear partition 1
+         partition_index_range_1 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
+         assert partition_index_range_1
+         ray.get(tq_controller.clear.remote(partition_id_1))
+         partition_1_after_clear = ray.get(tq_controller.get_partition.remote(partition_id_1))
+         partition_index_range_1_after_clear = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
+
+         assert not partition_index_range_1_after_clear
+         assert torch.all(partition_1_after_clear.production_status[list(partition_index_range_1), :] == 0)
+         assert torch.all(partition_1_after_clear.consumption_status["generate_sequences"] == 0)
+
+         partition_2 = ray.get(tq_controller.get_partition.remote(partition_id_2))
+         partition_index_range_2 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
+         assert partition_index_range_2 == set([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
+         assert torch.all(
+             partition_2.production_status[list(partition_index_range_2), : len(val_metadata.field_names)] == 1
+         )
+         print("✓ Only clear partition 1 correct")
+
+         # =========================partition 3=============================#
+         metadata_2 = ray.get(
+             tq_controller.get_metadata.remote(
+                 data_fields=data_fields,
+                 batch_size=gbs_3 * num_n_samples_3,
+                 partition_id=partition_id_3,
+                 mode="insert",
+             )
+         )
+         assert metadata_2.global_indexes == list(range(32)) + list(range(48, 80))
+         assert metadata_2.samples[0].partition_id == "train_1"
+         assert sum([int(sample.fields.get("prompt_ids").production_status) for sample in metadata_2.samples]) == int(
+             ProductionStatus.NOT_PRODUCED
+         )
+         assert sum(
+             [int(sample.fields.get("attention_mask").production_status) for sample in metadata_2.samples]
+         ) == int(ProductionStatus.NOT_PRODUCED)
+         partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3))
+         assert partition_index_range == set(list(range(32)) + list(range(48, 80)))
+         print("✓ Correctly assign partition_3")