TransferQueue 0.0.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +307 -0
- recipe/simple_use_case/sync_demo.py +223 -0
- tests/test_client.py +390 -0
- tests/test_controller.py +268 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +479 -0
- transfer_queue/__init__.py +42 -0
- transfer_queue/client.py +663 -0
- transfer_queue/controller.py +772 -0
- transfer_queue/metadata.py +603 -0
- transfer_queue/storage.py +515 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +98 -0
- transfer_queue/utils/zmq_utils.py +175 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.0.1.dev0.dist-info/METADATA +15 -0
- transferqueue-0.0.1.dev0.dist-info/RECORD +21 -0
- transferqueue-0.0.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.0.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.0.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import sys
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import pytest
|
|
20
|
+
import tensordict
|
|
21
|
+
import torch
|
|
22
|
+
from tensordict import NonTensorData, NonTensorStack, TensorDict
|
|
23
|
+
|
|
24
|
+
# Import your classes here
|
|
25
|
+
parent_dir = Path(__file__).resolve().parent.parent
|
|
26
|
+
sys.path.append(str(parent_dir))
|
|
27
|
+
|
|
28
|
+
from transfer_queue.utils.serial_utils import MsgpackDecoder, MsgpackEncoder # noqa: E402
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict:
    """Build a TensorDict from tensor/list fields plus optional non-tensor metadata.

    Every field must share the same leading length, which becomes the batch
    size. List fields are stored as NonTensorStack entries and therefore must
    not contain torch.Tensor elements. Entries of *non_tensor_dict* are wrapped
    in NonTensorData and must not collide with keys of *tensor_dict*.
    """
    if non_tensor_dict is None:
        non_tensor_dict = {}

    batch_size = None

    for field_name, field_value in tensor_dict.items():
        if isinstance(field_value, list):
            for element in field_value:
                assert not isinstance(element, torch.Tensor), (
                    "Passing a list makes the data NonTensorStack, "
                    "which doesn't support torch.Tensor. Please convert to numpy first"
                )

        assert isinstance(field_value, torch.Tensor | list)

        if batch_size is None:
            batch_size = len(field_value)
        else:
            assert len(field_value) == batch_size

    # An empty tensor_dict yields an empty (scalar) batch shape.
    batch_size = [] if batch_size is None else [batch_size]

    for meta_key, meta_value in non_tensor_dict.items():
        assert meta_key not in tensor_dict
        tensor_dict[meta_key] = NonTensorData(meta_value)

    return TensorDict(source=tensor_dict, batch_size=batch_size)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.mark.parametrize(
    "dtype",
    [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ],
)
def test_tensor_serialization(dtype):
    """Round-trip a dense tensor and a jagged nested tensor through msgpack.

    Checks that both encode/decode paths preserve values for several float
    dtypes.
    """
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(torch.Tensor)

    # Dense tensor round trip.
    tensor = torch.randn(100, 10, dtype=dtype)
    serialized = encoder.encode(tensor)
    deserialized = decoder.decode(serialized)
    assert torch.allclose(tensor, deserialized)

    # Jagged nested-tensor round trip (two components of different lengths).
    vocab_size = 128
    a = torch.randint(low=0, high=vocab_size, size=(11,))
    b = torch.randint(low=0, high=vocab_size, size=(13,))
    input_ids = [a, b]
    input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged, dtype=dtype)

    input_ids_serialized = encoder.encode(input_ids)
    input_ids_deserialized = decoder.decode(input_ids_serialized)
    # BUGFIX: the original loop compared index 0 on every iteration
    # (`input_ids[0] vs input_ids_deserialized[0]`), so corruption of any
    # later component was never detected. Compare each component by its index.
    for i in range(len(input_ids.unbind())):
        assert torch.allclose(input_ids[i], input_ids_deserialized[i])
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_tensordict_serialization_with_nontensor():
    """Round-trip a TensorDict that mixes tensors with a NonTensorStack field."""
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(TensorDict)

    obs = torch.randn(100, 10)
    data1 = {"obs": obs, "act": torch.randn(100, 3), "data_sources": ["gsm8k"] * 100}
    data1 = get_tensordict(tensor_dict=data1)

    serialized = encoder.encode(data1)
    deserialized = decoder.decode(serialized)

    # Structure must survive: same keys, same batch size, list field stays a
    # NonTensorStack.
    assert deserialized.keys() == data1.keys()
    assert deserialized.batch_size[0] == 100
    assert isinstance(deserialized.get("data_sources"), NonTensorStack)

    # Values must survive, field by field.
    for field, original in data1.items():
        if isinstance(original, torch.Tensor):
            assert torch.allclose(deserialized[field], original)
        elif isinstance(original, NonTensorStack):
            assert deserialized[field] == data1[field]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def test_tensordict_serialization_with_images():
    # each sample contains a sequence with multiple images of different sizes
    vocab_size = 128
    a = torch.randint(low=0, high=vocab_size, size=(11,))
    b = torch.randint(low=0, high=vocab_size, size=(13,))
    input_ids = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)

    def random_image(edge):
        # CHW uint8 image as a numpy array (NonTensorStack cannot hold tensors).
        return torch.randint(low=0, high=255, size=(3, edge, edge), dtype=torch.uint8).numpy()

    a_images = [random_image(256), random_image(128)]
    b_images = [random_image(256), random_image(128), random_image(64)]
    images = [a_images, b_images]

    data = get_tensordict({"input_ids": input_ids, "images": images})

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(TensorDict)

    deserialized = decoder.decode(encoder.encode(data))

    # Spot-check sample 0: first image and token ids survive the round trip.
    assert np.all(np.equal(deserialized[0]["images"][0], a_images[0]))
    assert torch.all(torch.eq(deserialized[0]["input_ids"], a))
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
# Copied from https://github.com/volcengine/verl/blob/33edd95e13c72b9494585765b5fedc679fd73923/tests/test_protocol_v2_on_cpu.py#L119
|
|
147
|
+
def test_tensordict_with_packing():
    """Jagged input_ids keep their offsets and per-sample views after a round trip."""
    vocab_size = 128
    a = torch.randint(low=0, high=vocab_size, size=(11,))
    b = torch.randint(low=0, high=vocab_size, size=(13,))
    input_ids = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)

    data = get_tensordict({"input_ids": input_ids})
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(TensorDict)
    deserialized_data = decoder.decode(encoder.encode(data))

    # test cu_seqlens
    cu_seqlens = torch.tensor([0, 11, 24])
    assert torch.all(torch.eq(cu_seqlens, deserialized_data["input_ids"].offsets()))

    # test index: field-first and sample-first access must both recover each row
    for row, expected in enumerate((a, b)):
        assert torch.all(torch.eq(deserialized_data["input_ids"][row], expected))
        assert torch.all(torch.eq(deserialized_data[row]["input_ids"], expected))

    # Chunking splits the batch; each chunk holds one original sample.
    chunks = deserialized_data.chunk(2)
    assert torch.all(torch.eq(chunks[0]["input_ids"][0], a))
    assert torch.all(torch.eq(chunks[1]["input_ids"][0], b))
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def test_nested_tensordict_serialization():
    """Nested TensorDicts must survive a msgpack round trip structurally intact."""
    inner_one = tensordict.TensorDict({"a": torch.randn(2, 3), "b": torch.randn(2, 4)}, batch_size=[2])
    inner_two = tensordict.TensorDict({"c": torch.randn(2, 5), "d": torch.randn(2, 6)}, batch_size=[2])
    outer = tensordict.TensorDict({"part1": inner_one, "part2": inner_two, "e": torch.randn(2, 7)}, batch_size=[2])

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(TensorDict)
    restored = decoder.decode(encoder.encode(outer))

    # Shape of the structure: nested dicts remain TensorDicts with the same keys.
    assert isinstance(restored, tensordict.TensorDict)
    assert set(restored.keys()) == set(outer.keys())
    assert isinstance(restored["part1"], tensordict.TensorDict)
    assert isinstance(restored["part2"], tensordict.TensorDict)

    assert set(restored["part1"].keys()) == set(inner_one.keys())
    assert set(restored["part2"].keys()) == set(inner_two.keys())

    # Values: recurse one level into the nested dicts, compare leaves directly.
    for key in outer.keys():
        if isinstance(outer[key], tensordict.TensorDict):
            for inner_key in outer[key].keys():
                assert torch.allclose(restored[key][inner_key], outer[key][inner_key]), (
                    f"Values for key '{key}.{inner_key}' do not match"
                )
        else:
            assert torch.allclose(restored[key], outer[key]), f"Values for key '{key}' do not match"
|
|
@@ -0,0 +1,479 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import sys
|
|
16
|
+
import time
|
|
17
|
+
import uuid
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from threading import Thread
|
|
20
|
+
from unittest.mock import MagicMock
|
|
21
|
+
|
|
22
|
+
import pytest
|
|
23
|
+
import ray
|
|
24
|
+
import tensordict
|
|
25
|
+
import torch
|
|
26
|
+
import zmq
|
|
27
|
+
from tensordict import TensorDict
|
|
28
|
+
|
|
29
|
+
# Import your classes here
|
|
30
|
+
parent_dir = Path(__file__).resolve().parent.parent
|
|
31
|
+
sys.path.append(str(parent_dir))
|
|
32
|
+
|
|
33
|
+
try:
|
|
34
|
+
from transfer_queue import TransferQueueStorageSimpleUnit
|
|
35
|
+
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo
|
|
36
|
+
except ImportError:
|
|
37
|
+
# For testing purposes if imports are not available
|
|
38
|
+
TransferQueueStorageSimpleUnit = MagicMock()
|
|
39
|
+
ZMQServerInfo = MagicMock()
|
|
40
|
+
ZMQRequestType = MagicMock()
|
|
41
|
+
ZMQMessage = MagicMock()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Mock ZMQ utilities if not available in test environment
|
|
45
|
+
def create_zmq_socket(context, socket_type, identity=None):
    """Create a socket on *context*, optionally assigning a ZMQ identity."""
    new_socket = context.socket(socket_type)
    if identity:
        new_socket.setsockopt(zmq.IDENTITY, identity)
    return new_socket
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Mock Controller to handle handshake and data updates
|
|
53
|
+
class MockController:
    """In-process stand-in for the TransferQueue controller.

    Binds two ROUTER sockets (handshake + data-status updates) on random TCP
    ports and acknowledges every incoming message from background daemon
    threads, so a storage unit under test can complete its protocol.
    """

    def __init__(self, controller_id="controller_001"):
        self.controller_id = controller_id
        self.context = zmq.Context()

        # Socket for handshake
        self.handshake_socket = self.context.socket(zmq.ROUTER)
        self.handshake_port = self._bind_to_random_port(self.handshake_socket)

        # Socket for data status updates
        self.data_update_socket = self.context.socket(zmq.ROUTER)
        self.data_update_port = self._bind_to_random_port(self.data_update_socket)

        # Endpoint descriptor handed to the storage unit so it can reach us.
        self.zmq_server_info = ZMQServerInfo.create(
            role="CONTROLLER",
            id=controller_id,
            ip="127.0.0.1",
            ports={"handshake_socket": self.handshake_port, "data_status_update_socket": self.data_update_port},
        )

        # Daemon threads so a forgotten stop() cannot hang test shutdown.
        self.running = True
        self.handshake_thread = Thread(target=self._handle_handshake, daemon=True)
        self.data_update_thread = Thread(target=self._handle_data_updates, daemon=True)
        self.handshake_thread.start()
        self.data_update_thread.start()

    def _bind_to_random_port(self, socket):
        # Let ZMQ pick a free port to avoid collisions between parallel tests.
        port = socket.bind_to_random_port("tcp://127.0.0.1")
        return port

    def _handle_handshake(self):
        # Poll loop: until stop(), ACK every handshake request that arrives.
        poller = zmq.Poller()
        poller.register(self.handshake_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.handshake_socket in socks:
                    identity, msg_bytes = self.handshake_socket.recv_multipart()
                    # Deserialized only to validate the frame; content ignored.
                    ZMQMessage.deserialize(msg_bytes)

                    # Send handshake ack
                    ack_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.HANDSHAKE_ACK,
                        sender_id=self.controller_id,
                        body={"message": "Handshake successful"},
                    )
                    self.handshake_socket.send_multipart([identity, ack_msg.serialize()])
            except zmq.Again:
                continue
            except Exception:
                # Best-effort mock: errors while running are deliberately
                # swallowed (the socket may already be closing in teardown).
                if self.running:
                    pass

    def _handle_data_updates(self):
        # Mirror of _handle_handshake, but for data-status update messages.
        poller = zmq.Poller()
        poller.register(self.data_update_socket, zmq.POLLIN)

        while self.running:
            try:
                socks = dict(poller.poll(100))  # 100ms timeout
                if self.data_update_socket in socks:
                    identity, msg_bytes = self.data_update_socket.recv_multipart()
                    ZMQMessage.deserialize(msg_bytes)

                    # Send data update ack
                    ack_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
                        sender_id=self.controller_id,
                        body={"message": "Data update received"},
                    )
                    self.data_update_socket.send_multipart([identity, ack_msg.serialize()])
            except zmq.Again:
                continue
            except Exception:
                if self.running:
                    pass

    def stop(self):
        # Signal the poll loops, give their 100ms poll a moment to observe the
        # flag, then close both sockets.
        self.running = False
        time.sleep(0.1)  # Give threads time to stop
        self.handshake_socket.close()
        self.data_update_socket.close()
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
# Mock client to send PUT/GET requests
|
|
139
|
+
class MockClient:
    """Thin DEALER-socket client that issues PUT/GET requests to a storage unit."""

    def __init__(self, storage_put_get_address):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
        self.socket.connect(storage_put_get_address)

    def _request(self, request_type, client_id, body):
        # Serialize, send, and block for the storage unit's reply.
        message = ZMQMessage.create(
            request_type=request_type,
            sender_id=f"mock_client_{client_id}",
            body=body,
        )
        self.socket.send(message.serialize())
        return ZMQMessage.deserialize(self.socket.recv())

    def send_put(self, client_id, global_indexes, local_indexes, field_data):
        return self._request(
            ZMQRequestType.PUT_DATA,
            client_id,
            {"global_indexes": global_indexes, "local_indexes": local_indexes, "field_data": field_data},
        )

    def send_get(self, client_id, local_indexes, fields):
        return self._request(
            ZMQRequestType.GET_DATA,
            client_id,
            {"local_indexes": local_indexes, "fields": fields},
        )

    def close(self):
        self.socket.close()
        self.context.term()
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@pytest.fixture(scope="session")
def ray_setup():
    # Session-wide Ray runtime. ignore_reinit_error allows reuse when a Ray
    # instance is already running in this test process; shut down after the
    # whole session so actors survive across tests.
    ray.init(ignore_reinit_error=True)
    yield
    ray.shutdown()
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@pytest.fixture
def storage_setup(ray_setup):
    """Spin up a MockController plus a TransferQueueStorageSimpleUnit Ray actor.

    Yields (storage_actor, put_get_address, mock_controller); stops the
    controller on teardown (the Ray actor is reclaimed with the session).
    The sleeps give the ZMQ sockets time to bind before clients connect.
    """
    storage_size = 10000
    # Make bare Python lists stack into NonTensorStack when building TensorDicts.
    tensordict.set_list_to_stack(True).set()

    # Start mock controller
    mock_controller = MockController(f"controller_{uuid.uuid4()}")
    time.sleep(0.5)  # Wait for controller sockets to be ready

    # Start Ray actor
    storage_actor = TransferQueueStorageSimpleUnit.options(max_concurrency=50, num_cpus=1).remote(storage_size)

    # Register controller info
    controller_infos = {mock_controller.controller_id: mock_controller.zmq_server_info}
    ray.get(storage_actor.register_controller_info.remote(controller_infos))

    # Get ZMQ address to connect client
    zmq_info = ray.get(storage_actor.get_zmq_server_info.remote())
    put_get_address = zmq_info.to_addr("put_get_socket")
    time.sleep(1)  # Wait for socket to be ready

    yield storage_actor, put_get_address, mock_controller

    # Cleanup
    mock_controller.stop()
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def test_put_get_single_client(storage_setup):
    """Test basic put and get operations with a single client using TensorDict and torch tensors."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    # Write three rows, then read the first two back.
    row_indexes = [0, 1, 2]
    field_data = TensorDict(
        {
            "log_probs": [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]), torch.tensor([7.0, 8.0, 9.0])],
            "rewards": [torch.tensor([10.0]), torch.tensor([20.0]), torch.tensor([30.0])],
        },
        batch_size=[],
    )

    put_reply = client.send_put(0, row_indexes, row_indexes, field_data)
    assert put_reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    get_reply = client.send_get(0, [0, 1], ["log_probs", "rewards"])
    assert get_reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    retrieved = get_reply.body["data"]
    for field in ("log_probs", "rewards"):
        assert field in retrieved
        assert retrieved[field].size(0) == 2

    # Verify data correctness
    torch.testing.assert_close(retrieved["log_probs"][0], torch.tensor([1.0, 2.0, 3.0]))
    torch.testing.assert_close(retrieved["log_probs"][1], torch.tensor([4.0, 5.0, 6.0]))
    torch.testing.assert_close(retrieved["rewards"][0], torch.tensor([10.0]))
    torch.testing.assert_close(retrieved["rewards"][1], torch.tensor([20.0]))

    client.close()
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def test_put_get_multiple_clients(storage_setup):
    """Test put and get operations with multiple clients including overlapping local indexes"""
    _, put_get_address, _ = storage_setup

    num_clients = 5
    clients = [MockClient(put_get_address) for _ in range(num_clients)]

    # Each client puts unique data using different local_indexes
    for client_idx, client in enumerate(clients):
        row_indexes = [client_idx * 10 + offset for offset in range(3)]
        payload = TensorDict(
            {
                "log_probs": [
                    torch.tensor([client_idx, client_idx + 1, client_idx + 2]),
                    torch.tensor([client_idx + 3, client_idx + 4, client_idx + 5]),
                    torch.tensor([client_idx + 6, client_idx + 7, client_idx + 8]),
                ],
                "rewards": [
                    torch.tensor([client_idx * 10]),
                    torch.tensor([client_idx * 10 + 10]),
                    torch.tensor([client_idx * 10 + 20]),
                ],
            }
        )

        reply = client.send_put(client_idx, row_indexes, row_indexes, payload)
        assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Now simulate a third client that writes to overlapping local_indexes (e.g., index 0)
    overlapping_client = MockClient(put_get_address)
    overlap_payload = TensorDict({"log_probs": [torch.tensor([999, 999, 999])], "rewards": [torch.tensor([999])]})
    reply = overlapping_client.send_put(
        client_id=99, global_indexes=[0], local_indexes=[0], field_data=overlap_payload
    )
    assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Each original client gets its own data (except for index 0 which was overwritten)
    for client_idx, client in enumerate(clients):
        reply = client.send_get(client_idx, [client_idx * 10 + 0, client_idx * 10 + 1], ["log_probs", "rewards"])
        assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

        retrieved = reply.body["data"]
        assert retrieved["log_probs"].size(0) == 2
        assert retrieved["rewards"].size(0) == 2

        if client_idx == 0:
            # Index 0 was overwritten by the overlapping client; index 1 kept
            # the first client's original values.
            torch.testing.assert_close(retrieved["log_probs"][0], torch.tensor([999, 999, 999]))
            torch.testing.assert_close(retrieved["rewards"][0], torch.tensor([999]))
            torch.testing.assert_close(retrieved["log_probs"][1], torch.tensor([3, 4, 5]))
            torch.testing.assert_close(retrieved["rewards"][1], torch.tensor([10]))
        else:
            # All data remains original
            torch.testing.assert_close(
                retrieved["log_probs"][0], torch.tensor([client_idx, client_idx + 1, client_idx + 2])
            )
            torch.testing.assert_close(
                retrieved["log_probs"][1], torch.tensor([client_idx + 3, client_idx + 4, client_idx + 5])
            )
            torch.testing.assert_close(retrieved["rewards"][0], torch.tensor([client_idx * 10]))
            torch.testing.assert_close(retrieved["rewards"][1], torch.tensor([client_idx * 10 + 10]))

    # Cleanup
    for client in clients:
        client.close()
    overlapping_client.close()
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def test_performance_basic(storage_setup):
    """Basic performance test with larger data volume and proper index handling"""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    num_puts = 50
    num_gets = 50
    batch_size = 128
    sample_len = 32768  # larger per-sample tensors raise the transfer volume

    # PUT performance test
    put_latencies = []
    for batch_idx in range(num_puts):
        started = time.time()

        row_indexes = list(range(batch_idx * batch_size, (batch_idx + 1) * batch_size))

        log_probs_data = [torch.randn(sample_len) for _ in range(batch_size)]
        rewards_data = [torch.randn(sample_len) for _ in range(batch_size)]
        payload = TensorDict({"log_probs": log_probs_data, "rewards": rewards_data}, batch_size=[batch_size])

        reply = client.send_put(0, row_indexes, row_indexes, payload)
        put_latencies.append(time.time() - started)
        assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # GET performance test
    get_latencies = []
    for batch_idx in range(num_gets):
        started = time.time()
        # Retrieve batch_size indices of data each time
        row_indexes = list(range(batch_idx * batch_size, (batch_idx + 1) * batch_size))
        reply = client.send_get(0, row_indexes, ["log_probs", "rewards"])
        get_latencies.append(time.time() - started)
        assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    avg_put_latency = sum(put_latencies) / len(put_latencies) * 1000  # ms
    avg_get_latency = sum(get_latencies) / len(get_latencies) * 1000  # ms

    # Adjust performance thresholds to accommodate larger data volume
    assert avg_put_latency < 5000, f"Avg PUT latency {avg_put_latency}ms exceeds threshold"
    assert avg_get_latency < 5000, f"Avg GET latency {avg_get_latency}ms exceeds threshold"

    client.close()
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def test_put_get_nested_tensor_single_client(storage_setup):
    """Test basic put and get operations with a single client using TensorDict and nested tensors."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    # Variable-length rows exercise the nested-tensor path.
    sequences = [
        torch.tensor([-0.5, -1.2, -0.8]),
        torch.tensor([-0.3, -1.5, -2.1, -0.9]),
        torch.tensor([-1.1, -0.7]),
    ]
    masks = [torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1, 1]), torch.tensor([1, 1])]
    payload = TensorDict(
        {"variable_length_sequences": sequences, "attention_mask": masks},
        batch_size=[],
    )

    reply = client.send_put(0, [0, 1, 2], [0, 1, 2], payload)
    assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    # Fetch rows 0 and 2, in that order.
    reply = client.send_get(0, [0, 2], ["variable_length_sequences", "attention_mask"])
    assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    retrieved = reply.body["data"]
    for field in ("variable_length_sequences", "attention_mask"):
        assert field in retrieved
        assert retrieved[field].size(0) == 2

    # Verify data correctness
    torch.testing.assert_close(retrieved["variable_length_sequences"][0], torch.tensor([-0.5, -1.2, -0.8]))
    torch.testing.assert_close(retrieved["variable_length_sequences"][1], torch.tensor([-1.1, -0.7]))
    torch.testing.assert_close(retrieved["attention_mask"][0], torch.tensor([1, 1, 1]))
    torch.testing.assert_close(retrieved["attention_mask"][1], torch.tensor([1, 1]))

    client.close()
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def test_put_get_nested_nontensor_single_client(storage_setup):
    """Test basic put and get operations with a single client using non-tensor data (strings)."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    prompts = ["Hello world!", "This is a longer sentence for testing", "Test case"]
    responses = ["Hi there!", "This is the response to the longer sentence", "Test response"]
    payload = TensorDict(
        {"prompt_text": prompts, "response_text": responses},
        batch_size=[],
    )

    reply = client.send_put(0, [0, 1, 2], [0, 1, 2], payload)
    assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    reply = client.send_get(0, [0, 1, 2], ["prompt_text", "response_text"])
    assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    retrieved = reply.body["data"]
    assert "prompt_text" in retrieved
    assert "response_text" in retrieved

    # Strings must come back as str, unchanged and in order.
    assert isinstance(retrieved["prompt_text"][0], str)
    assert isinstance(retrieved["response_text"][0], str)
    for row, (prompt, response) in enumerate(zip(prompts, responses)):
        assert retrieved["prompt_text"][row] == prompt
        assert retrieved["response_text"][row] == response

    client.close()
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def test_put_get_single_item_single_client(storage_setup):
    """Test put and get operations for a single item with a single client."""
    _, put_get_address, _ = storage_setup

    client = MockClient(put_get_address)

    # A one-row batch mixing a string field with a tensor field.
    payload = TensorDict(
        {
            "prompt_text": ["Hello world!"],
            "attention_mask": [torch.tensor([1, 1, 1])],
        },
        batch_size=[],
    )

    reply = client.send_put(0, [0], [0], payload)
    assert reply.request_type == ZMQRequestType.PUT_DATA_RESPONSE

    reply = client.send_get(0, [0], ["prompt_text", "attention_mask"])
    assert reply.request_type == ZMQRequestType.GET_DATA_RESPONSE

    retrieved = reply.body["data"]
    assert "prompt_text" in retrieved
    assert "attention_mask" in retrieved

    assert retrieved["prompt_text"][0] == "Hello world!"
    assert retrieved["attention_mask"].shape == (1, 3)
    torch.testing.assert_close(retrieved["attention_mask"][0], torch.tensor([1, 1, 1]))

    client.close()
|