TransferQueue 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_client.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import sys
16
+ import time
17
+ from pathlib import Path
18
+ from threading import Thread
19
+
20
+ import pytest
21
+ import torch
22
+ import zmq
23
+ from tensordict import NonTensorStack, TensorDict
24
+
25
+ # Put the directory above tests/ on sys.path so the local transfer_queue package is importable
26
+ parent_dir = Path(__file__).resolve().parent.parent
27
+ sys.path.append(str(parent_dir))
28
+
29
+ from transfer_queue import TransferQueueClient # noqa: E402
30
+ from transfer_queue.metadata import ( # noqa: E402
31
+ BatchMeta,
32
+ FieldMeta,
33
+ SampleMeta,
34
+ )
35
+ from transfer_queue.utils.zmq_utils import ( # noqa: E402
36
+ ZMQMessage,
37
+ ZMQRequestType,
38
+ ZMQServerInfo,
39
+ )
40
+
41
# Shared three-sample batch used by the mock storage and the round-trip test:
#   - log_probs: three float tensors of equal length
#   - variable_length_sequences: a nested tensor whose rows have lengths 3, 4 and 2
#   - prompt_text: plain Python strings (non-tensor data)
TEST_DATA = TensorDict(
    {
        "log_probs": [
            torch.tensor(row) for row in ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0])
        ],
        "variable_length_sequences": torch.nested.as_nested_tensor(
            [
                torch.tensor(row)
                for row in ([-0.5, -1.2, -0.8], [-0.3, -1.5, -2.1, -0.9], [-1.1, -0.7])
            ]
        ),
        "prompt_text": ["Hello world!", "This is a longer sentence for testing", "Test case"],
    },
    batch_size=[3],
)
55
+
56
+
57
+ # Mock Controller for Client Unit Testing
58
class MockController:
    """In-process stand-in for a TransferQueue controller, used by the client tests.

    A ROUTER socket is bound to a random local TCP port and served from a daemon
    thread. GET_META and GET_CLEAR_META requests are answered with synthetic
    ``BatchMeta`` objects; CLEAR_META is answered with a fixed acknowledgement.
    """

    def __init__(self, controller_id="controller_0"):
        """Bind the request socket, publish its address and start the serving thread."""
        self.controller_id = controller_id
        self.context = zmq.Context()

        # Socket for data requests
        self.request_socket = self.context.socket(zmq.ROUTER)
        self.request_port = self._bind_to_random_port(self.request_socket)

        # Connection info handed to the client under test.
        self.zmq_server_info = ZMQServerInfo.create(
            role="TransferQueueController",
            id=controller_id,
            ip="127.0.0.1",
            ports={
                "request_handle_socket": self.request_port,
            },
        )

        self.running = True
        self.request_thread = Thread(target=self._handle_requests, daemon=True)
        self.request_thread.start()

    def _bind_to_random_port(self, socket):
        """Bind ``socket`` to a random port on 127.0.0.1 and return the port number."""
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _handle_requests(self):
        """Serve requests until ``self.running`` is cleared.

        Runs in the background thread. The poll uses a 100 ms timeout so the loop
        re-checks ``self.running`` regularly and ``stop()`` can take effect.
        """
        poller = zmq.Poller()
        poller.register(self.request_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.request_socket not in socks:
                    continue

                identity, serialized_msg = self.request_socket.recv_multipart()
                request_msg = ZMQMessage.deserialize(serialized_msg)
                request_type = request_msg.request_type

                # Determine response based on request type
                if request_type == ZMQRequestType.GET_META:
                    response_body = self._mock_batch_meta(request_msg.body)
                    response_type = ZMQRequestType.GET_META_RESPONSE
                elif request_type == ZMQRequestType.GET_CLEAR_META:
                    response_body = self._mock_batch_meta(request_msg.body)
                    response_type = ZMQRequestType.GET_CLEAR_META_RESPONSE
                elif request_type == ZMQRequestType.CLEAR_META:
                    response_body = {"message": "clear ok"}
                    response_type = ZMQRequestType.CLEAR_META_RESPONSE
                else:
                    # Without this branch an unknown request would either hit an
                    # unbound local or silently reuse the previous reply type.
                    raise ValueError(f"unsupported request type: {request_type}")

                # Send response
                response_msg = ZMQMessage.create(
                    request_type=response_type,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body=response_body,
                )
                self.request_socket.send_multipart([identity, response_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                # The flag is ``self.running``; the previous ``self.is_running`` did
                # not exist and turned every handled error into an AttributeError.
                if self.running:
                    print(f"MockController running exception: {e}")
                else:
                    print(f"MockController ERROR: {e}")
                    raise

    def _mock_batch_meta(self, request_body):
        """Build a synthetic ``BatchMeta`` for a metadata request.

        Reads ``batch_size`` (default 1) and ``data_fields`` (default empty) from
        ``request_body``. Sample ``i`` gets ``global_index == local_index == i``,
        ``storage_id="storage_0"`` and one ``FieldMeta`` per requested field.

        Returns:
            ``{"metadata": BatchMeta}``.
        """
        batch_size = request_body.get("batch_size", 1)
        data_fields = request_body.get("data_fields", [])

        samples = []
        for i in range(batch_size):
            fields = {
                field_name: FieldMeta(
                    name=field_name,
                    dtype=None,
                    shape=None,
                    production_status=0,
                )
                for field_name in data_fields
            }
            samples.append(
                SampleMeta(
                    global_step=0,
                    global_index=i,
                    storage_id="storage_0",
                    local_index=i,
                    fields=fields,
                )
            )
        metadata = BatchMeta(samples=samples)

        return {"metadata": metadata}

    def stop(self):
        """Stop the serving thread and release the socket and context."""
        self.running = False
        # Wait for the loop to observe the flag instead of sleeping a fixed time,
        # so the socket is not closed while the thread may still be using it.
        self.request_thread.join(timeout=2.0)
        # linger=0 drops unsent replies so context.term() cannot block on them.
        self.request_socket.close(linger=0)
        self.context.term()
155
+
156
+
157
+ # Mock Storage for Client Unit Testing
158
class MockStorage:
    """In-process stand-in for a TransferQueue storage unit, used by the client tests.

    A ROUTER socket is bound to a random local TCP port and served from a daemon
    thread. PUT_DATA and CLEAR_DATA are acknowledged without storing anything;
    GET_DATA is answered from the module-level ``TEST_DATA``.
    """

    def __init__(self, storage_id="storage_0"):
        """Bind the data socket, publish its address and start the serving thread."""
        self.storage_id = storage_id
        self.context = zmq.Context()

        # Socket for data operations
        self.data_socket = self.context.socket(zmq.ROUTER)
        self.data_port = self._bind_to_random_port(self.data_socket)

        # Connection info handed to the client under test.
        self.zmq_server_info = ZMQServerInfo.create(
            role="TransferQueueStorage",
            id=storage_id,
            ip="127.0.0.1",
            ports={
                "put_get_socket": self.data_port,
            },
        )

        self.running = True
        self.data_thread = Thread(target=self._handle_data_requests, daemon=True)
        self.data_thread.start()

    def _bind_to_random_port(self, socket):
        """Bind ``socket`` to a random port on 127.0.0.1 and return the port number."""
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _handle_data_requests(self):
        """Serve data requests until ``self.running`` is cleared.

        Runs in the background thread. The poll uses a 100 ms timeout so the loop
        re-checks ``self.running`` regularly and ``stop()`` can take effect.
        """
        poller = zmq.Poller()
        poller.register(self.data_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.data_socket not in socks:
                    continue

                identity, msg_bytes = self.data_socket.recv_multipart()
                msg = ZMQMessage.deserialize(msg_bytes)
                request_type = msg.request_type

                # Handle different request types
                if request_type == ZMQRequestType.PUT_DATA:
                    response_body = {"message": "Data stored successfully"}
                    response_type = ZMQRequestType.PUT_DATA_RESPONSE
                elif request_type == ZMQRequestType.GET_DATA:
                    response_body = self._handle_get_data(msg.body)
                    response_type = ZMQRequestType.GET_DATA_RESPONSE
                elif request_type == ZMQRequestType.CLEAR_DATA:
                    response_body = {"message": "Data cleared successfully"}
                    response_type = ZMQRequestType.CLEAR_DATA_RESPONSE
                else:
                    # Without this branch an unknown request would either hit an
                    # unbound local or silently reuse the previous reply type.
                    raise ValueError(f"unsupported request type: {request_type}")

                # Send response
                response_msg = ZMQMessage.create(
                    request_type=response_type,
                    sender_id=self.storage_id,
                    receiver_id=msg.sender_id,
                    body=response_body,
                )
                self.data_socket.send_multipart([identity, response_msg.serialize()])
            except zmq.Again:
                continue
            except Exception as e:
                # The flag is ``self.running``; the previous ``self.is_running`` did
                # not exist and turned every handled error into an AttributeError.
                if self.running:
                    print(f"MockStorage running exception: {e}")
                else:
                    print(f"MockStorage ERROR: {e}")
                    raise

    def _handle_get_data(self, request_body):
        """Answer a GET_DATA request from ``TEST_DATA``.

        Reads ``local_indexes`` and ``fields`` (both default to empty) from
        ``request_body``. For every field the selected rows are gathered: if all
        of them are tensors they are packed as a nested tensor, otherwise as a
        ``NonTensorStack``. A field with no selected rows is omitted.

        Returns:
            ``{"data": TensorDict}``.
        """
        local_indexes = request_body.get("local_indexes", [])
        fields = request_body.get("fields", [])

        result = {}
        for field in fields:
            gathered_items = [TEST_DATA[field][i] for i in local_indexes]
            if not gathered_items:
                continue

            if all(isinstance(x, torch.Tensor) for x in gathered_items):
                result[field] = torch.nested.as_nested_tensor(gathered_items)
            else:
                result[field] = NonTensorStack(*gathered_items)

        return {"data": TensorDict(result)}

    def stop(self):
        """Stop the serving thread and release the socket and context."""
        self.running = False
        # Wait for the loop to observe the flag instead of sleeping a fixed time,
        # so the socket is not closed while the thread may still be using it.
        self.data_thread.join(timeout=2.0)
        # linger=0 drops unsent replies so context.term() cannot block on them.
        self.data_socket.close(linger=0)
        self.context.term()
246
+
247
+
248
+ # Test Fixtures
249
@pytest.fixture
def mock_controller():
    """Provide a running MockController and stop it after the test."""
    instance = MockController()
    yield instance
    instance.stop()
254
+
255
+
256
@pytest.fixture
def mock_storage():
    """Provide a running MockStorage and stop it after the test."""
    instance = MockStorage()
    yield instance
    instance.stop()
261
+
262
+
263
@pytest.fixture
def client_setup(mock_controller, mock_storage):
    """Yield ``(client, mock_controller, mock_storage)`` with the client wired to both mocks."""
    controller_infos = {mock_controller.controller_id: mock_controller.zmq_server_info}
    storage_infos = {mock_storage.storage_id: mock_storage.zmq_server_info}

    client = TransferQueueClient(
        client_id="client_0",
        controller_infos=controller_infos,
        storage_infos=storage_infos,
    )

    # Short pause so the ZMQ connections can be established before the test runs.
    time.sleep(0.5)

    yield client, mock_controller, mock_storage
278
+
279
+
280
+ # Test basic functionality
281
def test_client_initialization(client_setup):
    """The client has an id and has registered the mock controller and storage."""
    client, controller, storage = client_setup

    assert client.client_id is not None
    assert controller.controller_id in client._controllers
    assert storage.storage_id in client._storages
288
+
289
+
290
def test_put_and_get_data(client_setup):
    """Put TEST_DATA, request metadata for two samples, and read those samples back."""
    client, _controller, _storage = client_setup

    client.put(data=TEST_DATA, global_step=0)

    # Metadata describing which samples to fetch.
    requested_fields = ["log_probs", "variable_length_sequences", "prompt_text"]
    batch_meta = client.get_meta(data_fields=requested_fields, batch_size=2, global_step=0)

    fetched = client.get_data(batch_meta)

    # Every requested field must be present in the result.
    for field_name in requested_fields:
        assert field_name in fetched

    # The first two rows of each tensor field must match TEST_DATA.
    expected_tensors = {
        "log_probs": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        "variable_length_sequences": [[-0.5, -1.2, -0.8], [-0.3, -1.5, -2.1, -0.9]],
    }
    for field_name, rows in expected_tensors.items():
        for row_index, row in enumerate(rows):
            torch.testing.assert_close(fetched[field_name][row_index], torch.tensor(row))

    expected_prompts = ["Hello world!", "This is a longer sentence for testing"]
    for row_index, prompt in enumerate(expected_prompts):
        assert fetched["prompt_text"][row_index] == prompt
316
+
317
+
318
def test_get_meta(client_setup):
    """get_meta returns an object with the expected attributes and one index per sample."""
    client, _controller, _storage = client_setup

    batch_meta = client.get_meta(data_fields=["tokens", "labels"], batch_size=10, global_step=0)

    # Attributes the returned metadata is expected to expose.
    for attribute in ("storage_meta_groups", "global_indexes", "field_names", "size"):
        assert hasattr(batch_meta, attribute)
    assert len(batch_meta.global_indexes) == 10
331
+
332
+
333
def test_clear_operation(client_setup):
    """Smoke test: clearing global step 0 against the mocks completes without raising."""
    client, _controller, _storage = client_setup

    client.clear(global_step=0)
339
+
340
+
341
+ # Test with multiple controllers and storage units
342
def test_multiple_servers():
    """The client connects to two mock controllers and three mock storages and can put data."""
    controllers = [MockController(f"controller_{i}") for i in range(2)]
    storages = [MockStorage(f"storage_{i}") for i in range(3)]

    try:
        client = TransferQueueClient(
            client_id="client_test_multiple_servers",
            controller_infos={ctrl.controller_id: ctrl.zmq_server_info for ctrl in controllers},
            storage_infos={unit.storage_id: unit.zmq_server_info for unit in storages},
        )

        # Give the connections time to come up.
        time.sleep(1.0)

        # Every server must be registered with the client.
        assert len(client._controllers) == 2
        assert len(client._storages) == 3

        # A basic put must succeed with several servers attached.
        payload = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)
        client.put(data=payload, global_step=0)

    finally:
        # Stop every mock server, controllers first.
        for server in controllers:
            server.stop()
        for server in storages:
            server.stop()
378
+
379
+
380
+ # Test error handling
381
def test_put_without_required_params(client_setup):
    """put() without ``global_step`` is expected to raise AssertionError."""
    client, _controller, _storage = client_setup

    payload = TensorDict({"tokens": torch.randint(0, 100, (5, 128))}, batch_size=5)

    # global_step is deliberately omitted.
    with pytest.raises(AssertionError):
        client.put(data=payload)
@@ -0,0 +1,268 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import math
17
+ import sys
18
+ from pathlib import Path
19
+
20
+ import numpy as np
21
+ import pytest
22
+ import ray
23
+ import torch
24
+
25
+ parent_dir = Path(__file__).resolve().parent.parent
26
+ sys.path.append(str(parent_dir))
27
+
28
+ from transfer_queue.controller import TQ_INIT_FIELD_NUM, TransferQueueController # noqa: E402
29
+ from transfer_queue.storage import TransferQueueStorageSimpleUnit # noqa: E402
30
+
31
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
@pytest.fixture(scope="function")
def ray_setup():
    """Start a fresh Ray runtime for one test and shut it down afterwards."""
    # Drop any runtime left over from an earlier test before initialising.
    if ray.is_initialized():
        ray.shutdown()

    ray_env = {"env_vars": {"RAY_DEBUG": "1", "RAY_DEDUP_LOGS": "0"}}
    ray.init(ignore_reinit_error=True, runtime_env=ray_env, log_to_driver=True)

    yield

    if ray.is_initialized():
        ray.shutdown()
        logger.info("Ray has been shut down completely after test")
48
+
49
+
50
@pytest.fixture(scope="function")
def setup_teardown_transfer_queue_controller(ray_setup):
    """Create a controller actor for one test and clear step 0 on teardown.

    Yields ``(tq_controller, global_batch_size, num_global_batch, num_n_samples)``.
    """
    controller_kwargs = {
        "num_storage_units": 2,
        "global_batch_size": 8,
        "num_global_batch": 2,
        "num_n_samples": 2,
    }
    tq_controller = TransferQueueController.remote(**controller_kwargs)

    yield (
        tq_controller,
        controller_kwargs["global_batch_size"],
        controller_kwargs["num_global_batch"],
        controller_kwargs["num_n_samples"],
    )

    ray.get(tq_controller.clear.remote(0))
66
+
67
+
68
@pytest.fixture(scope="function")
def setup_teardown_register_controller_info(setup_teardown_transfer_queue_controller):
    """Create two storage-unit actors and register the controller's address with each.

    Yields ``(tq_controller, global_batch_size, num_n_samples, data_system_storage_units)``,
    where the last item maps storage rank to its actor handle.
    """
    tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller
    num_data_storage_units = 2

    # The total capacity is split evenly across the units, rounding up.
    total_storage_size = global_batch_size * num_global_batch * num_n_samples
    per_unit_size = math.ceil(total_storage_size / num_data_storage_units)

    data_system_storage_units = {}
    for storage_unit_rank in range(num_data_storage_units):
        data_system_storage_units[storage_unit_rank] = TransferQueueStorageSimpleUnit.remote(
            storage_size=per_unit_size
        )
        logger.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")

    # Register controller info with every storage unit and wait for all of them.
    zmq_server_info = ray.get(tq_controller.get_zmq_server_info.remote())
    controller_infos = {zmq_server_info.id: zmq_server_info}
    pending = [
        unit.register_controller_info.remote(controller_infos)
        for unit in data_system_storage_units.values()
    ]
    ray.get(pending)

    yield tq_controller, global_batch_size, num_n_samples, data_system_storage_units
94
+
95
+
96
class TestTransferQueueController:
    """Tests for ``TransferQueueController`` running as a Ray actor."""

    @pytest.mark.parametrize("num_n_samples", [1, 2])
    @pytest.mark.parametrize("num_global_batch", [1, 2])
    def test_build_index_storage_mapping(self, num_n_samples, num_global_batch, ray_setup):
        """Check the global-index to (storage unit, local index) mapping for each combination."""
        global_batch_size = 8
        num_data_storage_units = 2

        self.tq_controller = TransferQueueController.remote(
            num_storage_units=num_data_storage_units,
            global_batch_size=global_batch_size,
            num_global_batch=num_global_batch,
            num_n_samples=num_n_samples,
        )

        global_index_storage_mapping, global_index_local_index_mapping = ray.get(
            self.tq_controller.get_global_index_mapping.remote()
        )

        # Expected (storage mapping, local-index mapping), keyed by
        # (num_global_batch, num_n_samples). The data of a single GBS is
        # distributed across the storage units; num_n_samples > 1 multiplies the
        # number of entries per global batch.
        expected_mappings = {
            (1, 1): ([0] * 4 + [1] * 4, list(range(4)) * 2),
            (2, 1): (([0] * 4 + [1] * 4) * 2, list(range(4)) * 2 + list(range(4, 8)) * 2),
            (1, 2): ([0] * 8 + [1] * 8, list(range(8)) * 2),
            (2, 2): (([0] * 8 + [1] * 8) * 2, list(range(8)) * 2 + list(range(8, 16)) * 2),
        }
        expected_storage, expected_local = expected_mappings[(num_global_batch, num_n_samples)]

        assert np.array_equal(global_index_storage_mapping, np.array(expected_storage))
        assert np.array_equal(global_index_local_index_mapping, np.array(expected_local))

    def test_update_production_status(self, setup_teardown_transfer_queue_controller):
        """Updating production status registers the field and marks the given rows as produced."""
        tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller

        total_storage_size = global_batch_size * num_global_batch * num_n_samples
        # Initial state: all-zero production status and an empty field_name_mapping
        init_update_production_status = torch.zeros(total_storage_size, TQ_INIT_FIELD_NUM, dtype=torch.int8)
        assert torch.equal(ray.get(tq_controller.get_data_production_status.remote()), init_update_production_status)
        assert ray.get(tq_controller.get_field_name_mapping.remote()) == {}

        columns_list = ["test_prompts"]
        global_indexes = list(range(global_batch_size * num_n_samples))

        # Update production status. Waiting on the call surfaces remote errors
        # here and guarantees it has finished before the state is read back,
        # whatever ordering guarantees the actor provides.
        ray.get(tq_controller._update_production_status.remote(global_indexes, columns_list))
        new_field_name_mapping = ray.get(tq_controller.get_field_name_mapping.remote())
        assert new_field_name_mapping["test_prompts"] == 0

        new_data_production_status = ray.get(tq_controller.get_data_production_status.remote())
        assert new_data_production_status[:, 0][: len(global_indexes)].sum() == len(global_indexes)

    def test_data_consumption_status(self, setup_teardown_transfer_queue_controller):
        """Querying a new task's consumption status creates an all-zero entry for it."""
        tq_controller, global_batch_size, num_global_batch, num_n_samples = setup_teardown_transfer_queue_controller
        total_storage_size = global_batch_size * num_global_batch * num_n_samples

        init_data_consumption_status = {}
        assert ray.get(tq_controller.get_data_consumption_status.remote()) == init_data_consumption_status

        task_name = "test_task1"
        ray.get(tq_controller._get_consumption_status.remote(task_name))
        new_data_consumption_status = ray.get(tq_controller.get_data_consumption_status.remote())
        assert torch.equal(new_data_consumption_status[task_name], torch.zeros(total_storage_size, dtype=torch.int8))

    def test_get_prompt_metadata(self, setup_teardown_register_controller_info):
        """Insert-mode metadata for step 5, reversed, yields the expected indexes."""
        tq_controller, global_batch_size, n_samples, _ = setup_teardown_register_controller_info

        data_fields = ["test_prompts"]
        global_step = 5

        metadata = ray.get(
            tq_controller._get_metadata.remote(
                data_fields=data_fields,
                batch_size=global_batch_size * n_samples,
                global_step=global_step,
                mode="insert",
            )
        )
        # Reverse the 16 samples, then check both index views follow the new order.
        metadata.reorder(list(range(15, -1, -1)))
        assert metadata.global_indexes == list(range(31, 15, -1))
        assert metadata.local_indexes == list(range(15, 7, -1)) * 2
        storage_ids = metadata.storage_ids
        # The first half of the reordered batch must come from a single storage unit.
        assert len(set(storage_ids[: len(storage_ids) // 2])) == 1

    # TODO: Test case where multiple clients concurrently read datameta from a single controller,
    # and each client receives the correct response