TransferQueue 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,202 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import sys
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import pytest
20
+ import tensordict
21
+ import torch
22
+ from tensordict import NonTensorData, NonTensorStack, TensorDict
23
+
24
+ # Import your classes here
25
+ parent_dir = Path(__file__).resolve().parent.parent
26
+ sys.path.append(str(parent_dir))
27
+
28
+ from transfer_queue.utils.serial_utils import MsgpackDecoder, MsgpackEncoder # noqa: E402
29
+
30
+
31
def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict:
    """Build a TensorDict from per-sample fields plus optional non-tensor metadata.

    Args:
        tensor_dict: Mapping of field name to either a torch.Tensor whose first
            dimension is the batch, or a list of per-sample values. Lists become
            NonTensorStack entries, which do not support torch.Tensor elements.
        non_tensor_dict: Optional mapping of metadata values; each is wrapped in
            NonTensorData. Keys must not collide with tensor_dict.

    Returns:
        TensorDict with batch_size inferred from the common field length
        (empty batch_size when tensor_dict is empty).
    """
    if non_tensor_dict is None:
        non_tensor_dict = {}

    batch_size = None

    for key, val in tensor_dict.items():
        if isinstance(val, list):
            for v in val:
                assert not isinstance(v, torch.Tensor), (
                    "Passing a list makes the data NonTensorStack, "
                    "which doesn't support torch.Tensor. Please convert to numpy first"
                )

        assert isinstance(val, torch.Tensor | list)

        # All fields must agree on the leading (batch) dimension.
        if batch_size is None:
            batch_size = len(val)
        else:
            assert len(val) == batch_size

    batch_size = [] if batch_size is None else [batch_size]

    # Bug fix: merge metadata into a copy so the caller's dict is not mutated
    # as a side effect of constructing the TensorDict.
    source = dict(tensor_dict)
    for key, val in non_tensor_dict.items():
        assert key not in source
        source[key] = NonTensorData(val)

    return TensorDict(source=source, batch_size=batch_size)
62
+
63
+
64
@pytest.mark.parametrize(
    "dtype",
    [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ],
)
def test_tensor_serialization(dtype):
    """Round-trip dense and jagged (nested) tensors through the msgpack codec."""
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(torch.Tensor)

    # Dense tensor round-trip.
    tensor = torch.randn(100, 10, dtype=dtype)
    serialized = encoder.encode(tensor)
    deserialized = decoder.decode(serialized)
    assert torch.allclose(tensor, deserialized)

    # Jagged nested tensor round-trip with two variable-length rows.
    vocab_size = 128
    a = torch.randint(low=0, high=vocab_size, size=(11,))
    b = torch.randint(low=0, high=vocab_size, size=(13,))
    input_ids = [a, b]
    input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged, dtype=dtype)

    input_ids_serialized = encoder.encode(input_ids)
    input_ids_deserialized = decoder.decode(input_ids_serialized)
    # Bug fix: the original asserted index 0 on every iteration, leaving all
    # other components unchecked; compare each component i.
    for i in range(len(input_ids.unbind())):
        assert torch.allclose(input_ids[i], input_ids_deserialized[i])
91
+
92
+
93
def test_tensordict_serialization_with_nontensor():
    """TensorDicts mixing tensors and NonTensorStack fields survive a round trip."""
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(TensorDict)

    obs = torch.randn(100, 10)
    original = get_tensordict(
        tensor_dict={"obs": obs, "act": torch.randn(100, 3), "data_sources": ["gsm8k"] * 100}
    )

    restored = decoder.decode(encoder.encode(original))

    assert restored.keys() == original.keys()
    assert restored.batch_size[0] == 100
    # List-valued fields must come back as NonTensorStack.
    assert isinstance(restored.get("data_sources"), NonTensorStack)
    for field, value in original.items():
        if isinstance(value, torch.Tensor):
            assert torch.allclose(restored[field], value)
        elif isinstance(value, NonTensorStack):
            assert restored[field] == original[field]
112
+
113
+
114
def test_tensordict_serialization_with_images():
    """Round-trip a batch where each sample carries several differently-sized images."""
    # each sample contains a sequence with multiple images of different sizes
    vocab_size = 128
    a = torch.randint(low=0, high=vocab_size, size=(11,))
    b = torch.randint(low=0, high=vocab_size, size=(13,))
    input_ids = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)

    a_images = [
        torch.randint(low=0, high=255, size=(3, 256, 256), dtype=torch.uint8).numpy(),
        torch.randint(low=0, high=255, size=(3, 128, 128), dtype=torch.uint8).numpy(),
    ]
    b_images = [
        torch.randint(low=0, high=255, size=(3, 256, 256), dtype=torch.uint8).numpy(),
        torch.randint(low=0, high=255, size=(3, 128, 128), dtype=torch.uint8).numpy(),
        torch.randint(low=0, high=255, size=(3, 64, 64), dtype=torch.uint8).numpy(),
    ]

    images = [a_images, b_images]

    data = get_tensordict({"input_ids": input_ids, "images": images})

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(TensorDict)

    serialized = encoder.encode(data)
    deserialized = decoder.decode(serialized)

    # Bug fix: the original only checked sample 0 / image 0; verify every
    # sample's input_ids and every image in each sample.
    expected_ids = [a, b]
    for sample_idx, sample_images in enumerate(images):
        for img_idx, img in enumerate(sample_images):
            assert np.all(np.equal(deserialized[sample_idx]["images"][img_idx], img))
        assert torch.all(torch.eq(deserialized[sample_idx]["input_ids"], expected_ids[sample_idx]))
144
+
145
+
146
# Copied from https://github.com/volcengine/verl/blob/33edd95e13c72b9494585765b5fedc679fd73923/tests/test_protocol_v2_on_cpu.py#L119
def test_tensordict_with_packing():
    """Jagged input_ids keep offsets, per-row indexing, and chunking after a round trip."""
    vocab_size = 128
    row_a = torch.randint(low=0, high=vocab_size, size=(11,))
    row_b = torch.randint(low=0, high=vocab_size, size=(13,))
    packed = torch.nested.as_nested_tensor([row_a, row_b], layout=torch.jagged)

    data = get_tensordict({"input_ids": packed})
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(TensorDict)
    deserialized_data = decoder.decode(encoder.encode(data))

    # test cu_seqlens
    expected_offsets = torch.tensor([0, 11, 24])
    assert torch.all(torch.eq(expected_offsets, deserialized_data["input_ids"].offsets()))

    # test index — both field-first and sample-first access must agree.
    for idx, expected in ((0, row_a), (1, row_b)):
        assert torch.all(torch.eq(deserialized_data["input_ids"][idx], expected))
        assert torch.all(torch.eq(deserialized_data[idx]["input_ids"], expected))

    halves = deserialized_data.chunk(2)
    assert torch.all(torch.eq(halves[0]["input_ids"][0], row_a))
    assert torch.all(torch.eq(halves[1]["input_ids"][0], row_b))
174
+
175
+
176
def test_nested_tensordict_serialization():
    """Nested TensorDicts keep structure and values through the msgpack codec."""
    inner_a = tensordict.TensorDict({"a": torch.randn(2, 3), "b": torch.randn(2, 4)}, batch_size=[2])
    inner_b = tensordict.TensorDict({"c": torch.randn(2, 5), "d": torch.randn(2, 6)}, batch_size=[2])
    outer = tensordict.TensorDict({"part1": inner_a, "part2": inner_b, "e": torch.randn(2, 7)}, batch_size=[2])

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(TensorDict)
    restored = decoder.decode(encoder.encode(outer))

    # Structure: type, top-level keys, and nested sub-dict keys all preserved.
    assert isinstance(restored, tensordict.TensorDict)
    assert set(restored.keys()) == set(outer.keys())
    assert isinstance(restored["part1"], tensordict.TensorDict)
    assert isinstance(restored["part2"], tensordict.TensorDict)
    assert set(restored["part1"].keys()) == set(inner_a.keys())
    assert set(restored["part2"].keys()) == set(inner_b.keys())

    # Values: every leaf tensor, nested or not, matches the original.
    for key in outer.keys():
        value = outer[key]
        if isinstance(value, tensordict.TensorDict):
            for inner_key in value.keys():
                assert torch.allclose(restored[key][inner_key], value[inner_key]), (
                    f"Values for key '{key}.{inner_key}' do not match"
                )
        else:
            assert torch.allclose(restored[key], value), f"Values for key '{key}' do not match"
@@ -0,0 +1,479 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import sys
16
+ import time
17
+ import uuid
18
+ from pathlib import Path
19
+ from threading import Thread
20
+ from unittest.mock import MagicMock
21
+
22
+ import pytest
23
+ import ray
24
+ import tensordict
25
+ import torch
26
+ import zmq
27
+ from tensordict import TensorDict
28
+
29
+ # Import your classes here
30
+ parent_dir = Path(__file__).resolve().parent.parent
31
+ sys.path.append(str(parent_dir))
32
+
33
+ try:
34
+ from transfer_queue import TransferQueueStorageSimpleUnit
35
+ from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo
36
+ except ImportError:
37
+ # For testing purposes if imports are not available
38
+ TransferQueueStorageSimpleUnit = MagicMock()
39
+ ZMQServerInfo = MagicMock()
40
+ ZMQRequestType = MagicMock()
41
+ ZMQMessage = MagicMock()
42
+
43
+
44
# Mock ZMQ utilities if not available in test environment
def create_zmq_socket(context, socket_type, identity=None):
    """Create a socket on *context*, optionally stamping a ZMQ identity on it."""
    new_socket = context.socket(socket_type)
    if not identity:
        return new_socket
    new_socket.setsockopt(zmq.IDENTITY, identity)
    return new_socket
50
+
51
+
52
# Mock Controller to handle handshake and data updates
class MockController:
    """Stand-in for the TransferQueue controller.

    Binds two ROUTER sockets on loopback (one for handshakes, one for data
    status updates) and runs a daemon polling thread per socket that replies
    to every request with the matching ACK message.
    """

    def __init__(self, controller_id="controller_001"):
        self.controller_id = controller_id
        self.context = zmq.Context()

        # Socket for handshake
        self.handshake_socket = self.context.socket(zmq.ROUTER)
        self.handshake_port = self._bind_to_random_port(self.handshake_socket)

        # Socket for data status updates
        self.data_update_socket = self.context.socket(zmq.ROUTER)
        self.data_update_port = self._bind_to_random_port(self.data_update_socket)

        # Advertised endpoint info; the storage unit connects using these ports.
        self.zmq_server_info = ZMQServerInfo.create(
            role="CONTROLLER",
            id=controller_id,
            ip="127.0.0.1",
            ports={"handshake_socket": self.handshake_port, "data_status_update_socket": self.data_update_port},
        )

        # `running` is the shared stop flag; threads are started only after
        # both sockets are bound so no request can arrive before bind.
        self.running = True
        self.handshake_thread = Thread(target=self._handle_handshake, daemon=True)
        self.data_update_thread = Thread(target=self._handle_data_updates, daemon=True)
        self.handshake_thread.start()
        self.data_update_thread.start()

    def _bind_to_random_port(self, socket):
        # Let ZMQ pick a free TCP port on loopback; return the chosen port.
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _handle_handshake(self):
        """Poll the handshake socket and answer each request with HANDSHAKE_ACK."""
        poller = zmq.Poller()
        poller.register(self.handshake_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.handshake_socket in socks:
                    # ROUTER sockets prepend the peer identity frame.
                    identity, msg_bytes = self.handshake_socket.recv_multipart()
                    # Decoded only to validate the message; content is unused.
                    ZMQMessage.deserialize(msg_bytes)

                    # Send handshake ack
                    ack_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.HANDSHAKE_ACK,
                        sender_id=self.controller_id,
                        body={"message": "Handshake successful"},
                    )
                    self.handshake_socket.send_multipart([identity, ack_msg.serialize()])
            except zmq.Again:
                continue
            except Exception:
                # Best-effort mock: swallow errors (e.g. shutdown races) while running.
                if self.running:
                    pass

    def _handle_data_updates(self):
        """Poll the data-update socket and answer each request with NOTIFY_DATA_UPDATE_ACK."""
        poller = zmq.Poller()
        poller.register(self.data_update_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.data_update_socket in socks:
                    identity, msg_bytes = self.data_update_socket.recv_multipart()
                    # Decoded only to validate the message; content is unused.
                    ZMQMessage.deserialize(msg_bytes)

                    # Send data update ack
                    ack_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
                        sender_id=self.controller_id,
                        body={"message": "Data update received"},
                    )
                    self.data_update_socket.send_multipart([identity, ack_msg.serialize()])
            except zmq.Again:
                continue
            except Exception:
                # Best-effort mock: swallow errors (e.g. shutdown races) while running.
                if self.running:
                    pass

    def stop(self):
        """Signal both poll loops to exit, then close the sockets."""
        self.running = False
        time.sleep(0.1)  # Give threads time to stop
        self.handshake_socket.close()
        self.data_update_socket.close()
136
+
137
+
138
# Mock client to send PUT/GET requests
class MockClient:
    """Thin DEALER-socket client for exercising a storage unit's PUT/GET path."""

    def __init__(self, storage_put_get_address):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
        self.socket.connect(storage_put_get_address)

    def _roundtrip(self, message):
        # Serialize and send, then block (up to RCVTIMEO) for the reply.
        self.socket.send(message.serialize())
        return ZMQMessage.deserialize(self.socket.recv())

    def send_put(self, client_id, global_indexes, local_indexes, field_data):
        """Issue a PUT_DATA request and return the decoded response."""
        request = ZMQMessage.create(
            request_type=ZMQRequestType.PUT_DATA,
            sender_id=f"mock_client_{client_id}",
            body={"global_indexes": global_indexes, "local_indexes": local_indexes, "field_data": field_data},
        )
        return self._roundtrip(request)

    def send_get(self, client_id, local_indexes, fields):
        """Issue a GET_DATA request and return the decoded response."""
        request = ZMQMessage.create(
            request_type=ZMQRequestType.GET_DATA,
            sender_id=f"mock_client_{client_id}",
            body={"local_indexes": local_indexes, "fields": fields},
        )
        return self._roundtrip(request)

    def close(self):
        """Release the socket and its context."""
        self.socket.close()
        self.context.term()
167
+
168
+
169
@pytest.fixture(scope="session")
def ray_setup():
    """Session-wide Ray runtime shared by all storage tests."""
    # ignore_reinit_error lets this coexist with an already-running Ray instance.
    ray.init(ignore_reinit_error=True)
    yield
    ray.shutdown()
174
+
175
+
176
@pytest.fixture
def storage_setup(ray_setup):
    """Bring up a storage actor wired to a mock controller; yields
    (storage_actor, put_get_address, mock_controller)."""
    # Capacity (rows) of the storage unit under test.
    storage_size = 10000
    # Make TensorDict treat plain Python lists as stacked (non-tensor) values.
    tensordict.set_list_to_stack(True).set()

    # Start mock controller
    mock_controller = MockController(f"controller_{uuid.uuid4()}")
    time.sleep(0.5)  # Wait for controller sockets to be ready

    # Start Ray actor
    storage_actor = TransferQueueStorageSimpleUnit.options(max_concurrency=50, num_cpus=1).remote(storage_size)

    # Register controller info
    controller_infos = {mock_controller.controller_id: mock_controller.zmq_server_info}
    ray.get(storage_actor.register_controller_info.remote(controller_infos))

    # Get ZMQ address to connect client
    zmq_info = ray.get(storage_actor.get_zmq_server_info.remote())
    put_get_address = zmq_info.to_addr("put_get_socket")
    time.sleep(1)  # Wait for socket to be ready

    yield storage_actor, put_get_address, mock_controller

    # Cleanup
    mock_controller.stop()
201
+
202
+
203
def test_put_get_single_client(storage_setup):
    """Test basic put and get operations with a single client using TensorDict and torch tensors."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    # PUT three rows (global and local indexes coincide here).
    indexes = [0, 1, 2]
    payload = TensorDict(
        {
            "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])],
            "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])],
        },
        batch_size=[],
    )

    put_response = client.send_put(0, indexes, indexes, payload)
    assert put_response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET the first two rows back.
    get_response = client.send_get(0, [0, 1], ["log_probs", "rewards"])
    assert get_response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    fetched = get_response.body["data"]
    for field in ("log_probs", "rewards"):
        assert field in fetched
        assert fetched[field].size(0) == 2

    # Verify data correctness
    torch.testing.assert_close(fetched["log_probs"][0], torch.tensor([1.0, 2.0, 3.0]))
    torch.testing.assert_close(fetched["log_probs"][1], torch.tensor([4.0, 5.0, 6.0]))
    torch.testing.assert_close(fetched["rewards"][0], torch.tensor([10.0]))
    torch.testing.assert_close(fetched["rewards"][1], torch.tensor([20.0]))

    client.close()
240
+
241
+
242
def test_put_get_multiple_clients(storage_setup):
    """Test put and get operations with multiple clients including overlapping local indexes"""
    _, put_get_address, _ = storage_setup

    num_clients = 5
    clients = [MockClient(put_get_address) for _ in range(num_clients)]

    # Each client puts unique data using different local_indexes
    for i, client in enumerate(clients):
        # Client i owns rows 10*i .. 10*i+2; values are derived from i so
        # ownership can be verified on read-back below.
        global_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2]
        local_indexes = [i * 10 + 0, i * 10 + 1, i * 10 + 2]
        field_data = TensorDict(
            {
                "log_probs": [
                    torch.tensor([i, i + 1, i + 2]),
                    torch.tensor([i + 3, i + 4, i + 5]),
                    torch.tensor([i + 6, i + 7, i + 8]),
                ],
                "rewards": [torch.tensor([i * 10]), torch.tensor([i * 10 + 10]), torch.tensor([i * 10 + 20])],
            }
        )

        response = client.send_put(i, global_indexes, local_indexes, field_data)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Now simulate a third client that writes to overlapping local_indexes (e.g., index 0)
    overlapping_client = MockClient(put_get_address)
    overlap_local_indexes = [0]  # Overlaps with first client's index 0
    overlap_field_data = TensorDict({"log_probs": [torch.tensor([999, 999, 999])], "rewards": [torch.tensor([999])]})
    response = overlapping_client.send_put(
        client_id=99, global_indexes=[0], local_indexes=overlap_local_indexes, field_data=overlap_field_data
    )
    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Each original client gets its own data (except for index 0 which was overwritten)
    for i, client in enumerate(clients):
        response = client.send_get(i, [i * 10 + 0, i * 10 + 1], ["log_probs", "rewards"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved_data = response.body["data"]
        assert retrieved_data["log_probs"].size(0) == 2
        assert retrieved_data["rewards"].size(0) == 2

        # For index 0, expect data from overlapping_client; others from original client
        if i == 0:
            # Index 0 was overwritten (last write wins)
            torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([999, 999, 999]))
            torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([999]))
            # Index 1 remains original (client 0's second row: [3, 4, 5] / [10])
            torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([3, 4, 5]))
            torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([10]))
        else:
            # All data remains original
            torch.testing.assert_close(retrieved_data["log_probs"][0], torch.tensor([i, i + 1, i + 2]))
            torch.testing.assert_close(retrieved_data["log_probs"][1], torch.tensor([i + 3, i + 4, i + 5]))
            torch.testing.assert_close(retrieved_data["rewards"][0], torch.tensor([i * 10]))
            torch.testing.assert_close(retrieved_data["rewards"][1], torch.tensor([i * 10 + 10]))

    # Cleanup
    for client in clients:
        client.close()
    overlapping_client.close()
304
+
305
+
306
def test_performance_basic(storage_setup):
    """Basic performance test with larger data volume and proper index handling.

    Measures average PUT/GET latency over 50 round trips of 128-row batches and
    asserts both stay under a generous 5s threshold.
    """
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    # PUT performance test
    put_latencies = []
    num_puts = 50
    batch_size = 128

    for i in range(num_puts):
        # Bug fix: use perf_counter (monotonic, high resolution) rather than
        # time.time, which can jump with wall-clock adjustments.
        start = time.perf_counter()

        # Use larger batch size and more complex index mapping
        global_indexes = list(range(i * batch_size, (i + 1) * batch_size))
        local_indexes = list(range(i * batch_size, (i + 1) * batch_size))

        # Each sample contains larger tensors to increase data transfer volume
        log_probs_data = [torch.randn(32768) for _ in range(batch_size)]
        rewards_data = [torch.randn(32768) for _ in range(batch_size)]

        field_data = TensorDict({"log_probs": log_probs_data, "rewards": rewards_data}, batch_size=[batch_size])

        response = client.send_put(0, global_indexes, local_indexes, field_data)
        put_latencies.append(time.perf_counter() - start)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET performance test
    get_latencies = []
    num_gets = 50

    for i in range(num_gets):
        start = time.perf_counter()
        # Retrieve batch_size indices of data each time
        indices = list(range(i * batch_size, (i + 1) * batch_size))
        response = client.send_get(0, indices, ["log_probs", "rewards"])
        get_latencies.append(time.perf_counter() - start)
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    avg_put_latency = sum(put_latencies) / len(put_latencies) * 1000  # ms
    avg_get_latency = sum(get_latencies) / len(get_latencies) * 1000  # ms

    # Adjust performance thresholds to accommodate larger data volume
    assert avg_put_latency < 5000, f"Avg PUT latency {avg_put_latency}ms exceeds threshold"
    assert avg_get_latency < 5000, f"Avg GET latency {avg_get_latency}ms exceeds threshold"

    client.close()
363
+
364
+
365
def test_put_get_nested_tensor_single_client(storage_setup):
    """Test basic put and get operations with a single client using TensorDict and nested tensors."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    # PUT three variable-length rows.
    indexes = [0, 1, 2]
    payload = TensorDict(
        {
            "variable_length_sequences": [
                torch.tensor([-0.5, -1.2, -0.8]),
                torch.tensor([-0.3, -1.5, -2.1, -0.9]),
                torch.tensor([-1.1, -0.7]),
            ],
            "attention_mask": [torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1, 1]), torch.tensor([1, 1])],
        },
        batch_size=[],
    )

    put_response = client.send_put(0, indexes, indexes, payload)
    assert put_response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET rows 0 and 2 back.
    get_response = client.send_get(0, [0, 2], ["variable_length_sequences", "attention_mask"])
    assert get_response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    fetched = get_response.body["data"]
    for field in ("variable_length_sequences", "attention_mask"):
        assert field in fetched
        assert fetched[field].size(0) == 2

    # Verify data correctness
    torch.testing.assert_close(fetched["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8]))
    torch.testing.assert_close(fetched["variable_length_sequences"][1], torch.tensor([-1.1, -0.7]))
    torch.testing.assert_close(fetched["attention_mask"][0], torch.tensor([1, 1, 1]))
    torch.testing.assert_close(fetched["attention_mask"][1], torch.tensor([1, 1]))

    client.close()
407
+
408
+
409
def test_put_get_nested_nontensor_single_client(storage_setup):
    """Test basic put and get operations with a single client using non-tensor data (strings)."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    prompts = ["Hello world!", "This is a longer sentence for testing", "Test case"]
    replies = ["Hi there!", "This is the response to the longer sentence", "Test response"]

    # PUT data
    indexes = [0, 1, 2]
    payload = TensorDict(
        {"prompt_text": prompts, "response_text": replies},
        batch_size=[],
    )

    put_response = client.send_put(0, indexes, indexes, payload)
    assert put_response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET data
    get_response = client.send_get(0, [0, 1, 2], ["prompt_text", "response_text"])
    assert get_response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    fetched = get_response.body["data"]
    assert "prompt_text" in fetched
    assert "response_text" in fetched

    # Strings must come back as plain str, value for value.
    assert isinstance(fetched["prompt_text"][0], str)
    assert isinstance(fetched["response_text"][0], str)
    for row, (prompt, reply) in enumerate(zip(prompts, replies)):
        assert fetched["prompt_text"][row] == prompt
        assert fetched["response_text"][row] == reply

    client.close()
449
+
450
+
451
def test_put_get_single_item_single_client(storage_setup):
    """Test put and get operations for a single item with a single client."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    # PUT a one-row batch mixing a string field and a tensor field.
    payload = TensorDict(
        {
            "prompt_text": ["Hello world!"],
            "attention_mask": [torch.tensor([1, 1, 1])],
        },
        batch_size=[],
    )

    put_response = client.send_put(0, [0], [0], payload)
    assert put_response.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET it back.
    get_response = client.send_get(0, [0], ["prompt_text", "attention_mask"])
    assert get_response.request_type == ZMQRequestType.GET_DATA_RESPONSE

    fetched = get_response.body["data"]
    assert "prompt_text" in fetched
    assert "attention_mask" in fetched

    assert fetched["prompt_text"][0] == "Hello world!"
    assert fetched["attention_mask"].shape == (1, 3)
    torch.testing.assert_close(fetched["attention_mask"][0], torch.tensor([1, 1, 1]))