TransferQueue 0.0.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +307 -0
- recipe/simple_use_case/sync_demo.py +223 -0
- tests/test_client.py +390 -0
- tests/test_controller.py +268 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +479 -0
- transfer_queue/__init__.py +42 -0
- transfer_queue/client.py +663 -0
- transfer_queue/controller.py +772 -0
- transfer_queue/metadata.py +603 -0
- transfer_queue/storage.py +515 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +98 -0
- transfer_queue/utils/zmq_utils.py +175 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.0.1.dev0.dist-info/METADATA +15 -0
- transferqueue-0.0.1.dev0.dist-info/RECORD +21 -0
- transferqueue-0.0.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.0.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.0.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,515 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import os
|
|
17
|
+
import time
|
|
18
|
+
from operator import itemgetter
|
|
19
|
+
from threading import Thread
|
|
20
|
+
from uuid import uuid4
|
|
21
|
+
|
|
22
|
+
import ray
|
|
23
|
+
import torch
|
|
24
|
+
import zmq
|
|
25
|
+
from ray.util import get_node_ip_address
|
|
26
|
+
from tensordict import NonTensorStack, TensorDict
|
|
27
|
+
|
|
28
|
+
from transfer_queue.utils.utils import TransferQueueRole
|
|
29
|
+
from transfer_queue.utils.zmq_utils import (
|
|
30
|
+
ZMQMessage,
|
|
31
|
+
ZMQRequestType,
|
|
32
|
+
ZMQServerInfo,
|
|
33
|
+
create_zmq_socket,
|
|
34
|
+
get_free_port,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Poll interval for ZMQ pollers, in milliseconds (zmq.Poller.poll takes ms).
# BUGFIX: wrap in int() — without it, setting the env var yields a *str*,
# which is then passed to Poller.poll(); siblings below already coerce.
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 1000))
# Handshake wait budget, in seconds (compared against time.time() deltas).
TQ_STORAGE_HANDSHAKE_TIMEOUT = int(os.environ.get("TQ_STORAGE_HANDSHAKE_TIMEOUT", 30))
# Wait budget for controller data-update ACKs, in seconds.
TQ_DATA_UPDATE_RESPONSE_TIMEOUT = int(os.environ.get("TQ_DATA_UPDATE_RESPONSE_TIMEOUT", 600))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class StorageUnitData:
    """
    Row-oriented in-memory store for a fixed number of elements.

    Each element occupies one local index (a row) and may carry data for
    several named fields (columns), like:
    #####################################################
    #   local_index  |  field_name1  |  field_name2 ... #
    #       0        |     item1     |     item2    ... #
    #       1        |     item3     |     item4    ... #
    #####################################################
    """

    def __init__(self, storage_size: int):
        # Per-field columns: {"field_name": [data0, data1, ...]}.
        # A column is lazily allocated (length == storage_size) on first put.
        self.field_data: dict[str, list] = {}

        # Maximum number of elements this storage unit can hold.
        self.storage_size = storage_size

    def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[str, list]:
        """
        Gather the requested fields at the given rows.

        param:
            fields: Field names used for getting data.
            local_indexes: Local indexes used for getting data.
        return:
            TensorDict keyed by field name; tensors are stacked (nested for
            ragged shapes), non-tensor items are wrapped in a NonTensorStack.
        """
        collected: dict[str, list] = {}

        for field in fields:
            if field not in self.field_data:
                raise ValueError(
                    f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}"
                )
            column = self.field_data[field]

            if len(local_indexes) == 1:
                single = column[local_indexes[0]]
                # Preserve a leading batch dimension of 1 for a single row.
                if isinstance(single, torch.Tensor):
                    collected[field] = single.unsqueeze(0)
                else:
                    collected[field] = NonTensorStack(single)
            else:
                rows = list(itemgetter(*local_indexes)(column))

                if rows:
                    if all(isinstance(r, torch.Tensor) for r in rows):
                        # Rows may have ragged shapes; nested tensors allow that.
                        collected[field] = torch.nested.as_nested_tensor(rows)
                    else:
                        collected[field] = NonTensorStack(*rows)

        return TensorDict(collected)

    def put_data(self, field_data: TensorDict[str, list], local_indexes: list[int]) -> None:
        """
        Insert or overwrite rows for the given fields.

        param:
            field_data: Mapping of field name -> list of values, one value
                per entry in local_indexes (positionally matched).
            local_indexes: Local indexes used for putting data.
        """
        for field, values in dict(field_data).items():
            # Lazily allocate the column on first write to this field.
            if field not in self.field_data:
                self.field_data[field] = [None] * self.storage_size
            column = self.field_data[field]

            for i, idx in enumerate(local_indexes):
                if not (0 <= idx < self.storage_size):
                    raise ValueError(
                        f"StorageUnitData put_data operation receive invalid local_index: {idx} beyond "
                        f"storage_size: {self.storage_size}"
                    )

                column[idx] = values[i]

    def clear(self, local_indexes: list[int]) -> None:
        """
        Clear data at the given rows by resetting every known field to None.

        param:
            local_indexes: local_indexes to clear.
        """
        # Validate all indexes up front so a bad request mutates nothing.
        for idx in local_indexes:
            if not (0 <= idx < self.storage_size):
                raise ValueError(
                    f"StorageUnitData clear operation receive invalid local_index: {idx} beyond "
                    f"storage_size: {self.storage_size}"
                )

        for column in self.field_data.values():
            for idx in local_indexes:
                column[idx] = None
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@ray.remote(num_cpus=1)
class TransferQueueStorageSimpleUnit:
    """
    Ray actor that owns one in-memory StorageUnitData shard and serves it over ZMQ.

    Clients send PUT/GET/CLEAR requests to the ROUTER ``put_get_socket``; after a
    successful put, the unit broadcasts a data-status update to every registered
    controller and waits (bounded) for their ACKs.
    """

    def __init__(self, storage_size: int):
        # storage_size: maximum number of elements (rows) this unit can hold.
        super().__init__()
        self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4()}"
        self.storage_size = storage_size
        # controller_id -> ZMQServerInfo; filled in later by register_controller_info().
        self.controller_infos: dict[str, ZMQServerInfo] = {}

        # The actual row/column data store.
        self.experience_data = StorageUnitData(self.storage_size)

        # Advertise this unit's endpoint (a single put/get port on this node).
        self.zmq_server_info = ZMQServerInfo.create(
            role=TransferQueueRole.STORAGE,
            id=str(self.storage_unit_id),
            ip=get_node_ip_address(),
            ports={"put_get_socket": get_free_port()},
        )
        self._init_zmq_socket()

    def _init_zmq_socket(self) -> None:
        """
        Initialize ZMQ socket connections between storage unit and controllers/clients:
        - controller_handshake_sockets:
            Handshake between storage unit and controllers.
        - data_status_update_sockets:
            Broadcast data update status from storage unit to controllers when handling put operation.
        - put_get_socket:
            Handle put/get requests from clients.
        """
        self.zmq_context = zmq.Context()

        # Per-controller DEALER sockets, created once controller infos are known.
        self.controller_handshake_sockets: dict[str, zmq.Socket] = {}
        self.data_status_update_sockets: dict[str, zmq.Socket] = {}

        # ROUTER so each client request carries an identity frame for the reply.
        self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER)
        self.put_get_socket.bind(self.zmq_server_info.to_addr("put_get_socket"))

    def register_controller_info(self, controller_infos: dict[str, ZMQServerInfo]) -> None:
        """
        Build connections between storage unit and controllers, start put/get process.

        param:
            controller_infos: Dict with controller infos (controller_id -> ZMQServerInfo).
        """
        self.controller_infos = controller_infos

        # Order matters: create sockets, handshake, then start serving requests.
        self._init_zmq_sockets_with_controller_infos()
        self._connect_to_controller()
        self._start_process_put_get()

    def _init_zmq_sockets_with_controller_infos(self) -> None:
        """Initialize ZMQ sockets between storage unit and controllers for handshake."""
        for controller_id in self.controller_infos.keys():
            # DEALER identities embed the storage unit id plus a fresh uuid so
            # controllers can distinguish peers.
            self.controller_handshake_sockets[controller_id] = create_zmq_socket(
                self.zmq_context,
                zmq.DEALER,
                identity=f"{self.storage_unit_id}-controller_handshake_sockets-{uuid4()}".encode(),
            )
            self.data_status_update_sockets[controller_id] = create_zmq_socket(
                self.zmq_context,
                zmq.DEALER,
                identity=f"{self.storage_unit_id}-data_status_update_sockets-{uuid4()}".encode(),
            )

    def _connect_to_controller(self) -> None:
        """Connect storage unit to all controllers and wait (bounded) for handshake ACKs."""
        connected_controllers: set[str] = set()

        # Create zmq poller for handshake confirmation between controller and storage unit
        poller = zmq.Poller()

        for controller_id, controller_info in self.controller_infos.items():
            self.controller_handshake_sockets[controller_id].connect(controller_info.to_addr("handshake_socket"))
            logger.debug(
                f"[{self.zmq_server_info.id}]: Handshake connection from storage unit id #{self.zmq_server_info.id} "
                f"to controller id #{controller_id} establish successfully."
            )

            # Send handshake request to controllers
            request_msg = ZMQMessage.create(
                request_type=ZMQRequestType.HANDSHAKE,
                sender_id=self.zmq_server_info.id,
                body={
                    "storage_unit_id": self.storage_unit_id,
                    "storage_size": self.storage_size,
                },
            ).serialize()

            self.controller_handshake_sockets[controller_id].send(request_msg)
            logger.debug(
                f"[{self.zmq_server_info.id}]: Send handshake request from storage unit id #{self.zmq_server_info.id} "
                f"to controller id #{controller_id} successfully."
            )

            poller.register(self.controller_handshake_sockets[controller_id], zmq.POLLIN)

        # Collect ACKs until every controller answered or the handshake
        # timeout (seconds) elapses; poll() itself times out in ms.
        start_time = time.time()
        while (
            len(connected_controllers) < len(self.controller_infos)
            and time.time() - start_time < TQ_STORAGE_HANDSHAKE_TIMEOUT
        ):
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT))

            for controller_handshake_socket in self.controller_handshake_sockets.values():
                if controller_handshake_socket in socks:
                    response_msg = ZMQMessage.deserialize(controller_handshake_socket.recv())

                    if response_msg.request_type == ZMQRequestType.HANDSHAKE_ACK:
                        connected_controllers.add(response_msg.sender_id)
                        logger.debug(
                            f"[{self.zmq_server_info.id}]: Get handshake ACK response from "
                            f"controller id #{str(response_msg.sender_id)} to storage unit id "
                            f"#{self.zmq_server_info.id} successfully."
                        )

        # Best-effort: missing ACKs are only logged, not raised.
        if len(connected_controllers) < len(self.controller_infos):
            logger.warning(
                f"[{self.zmq_server_info.id}]: Only get {len(connected_controllers)} / {len(self.controller_infos)} "
                f"successful handshake connections to controllers from storage unit id #{self.zmq_server_info.id}"
            )

    def _start_process_put_get(self) -> None:
        """Create a daemon thread and start put/get process."""
        self.process_put_get_thread = Thread(
            target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.zmq_server_info.id}", daemon=True
        )
        self.process_put_get_thread.start()

    def _process_put_get(self) -> None:
        """Process put_get_socket request.

        Runs forever on the daemon thread: poll, dispatch by request type,
        always reply to the requesting client (success or error payload).
        """
        poller = zmq.Poller()
        poller.register(self.put_get_socket, zmq.POLLIN)

        while True:
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT))

            if self.put_get_socket in socks:
                # ROUTER delivers [identity, payload]; identity routes the reply back.
                identity, serialized_msg = self.put_get_socket.recv_multipart()

                try:
                    request_msg = ZMQMessage.deserialize(serialized_msg)
                    operation = request_msg.request_type
                    logger.debug(f"[{self.zmq_server_info.id}]: receive operation: {operation}, message: {request_msg}")

                    if operation == ZMQRequestType.PUT_DATA:
                        response_msg = self._handle_put(request_msg)
                    elif operation == ZMQRequestType.GET_DATA:
                        response_msg = self._handle_get(request_msg)
                    elif operation == ZMQRequestType.CLEAR_DATA:
                        response_msg = self._handle_clear(request_msg)
                    else:
                        # Unknown request type: answer with an operation error.
                        response_msg = ZMQMessage.create(
                            request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR,
                            sender_id=self.zmq_server_info.id,
                            body={
                                "message": f"Storage unit id #{self.zmq_server_info.id} "
                                f"receive invalid operation: {operation}."
                            },
                        )
                except Exception as e:
                    # Deserialization/dispatch failure: report instead of crashing the loop.
                    response_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.PUT_GET_ERROR,
                        sender_id=self.zmq_server_info.id,
                        body={
                            "message": f"Storage unit id #{self.zmq_server_info.id} occur error in processing "
                            f"put/get/clear request, detail error message: {str(e)}."
                        },
                    )

                self.put_get_socket.send_multipart([identity, response_msg.serialize()])

    def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle put request, add or update data into storage unit.

        Also broadcasts per-field dtype/shape metadata to all controllers
        (blocking until ACKs or timeout) before returning the client response.

        param:
            data_parts: ZMQMessage from client; body carries "global_indexes",
                "local_indexes" and "field_data" (positionally aligned).
        return:
            Put data success response ZMQMessage, or PUT_ERROR on failure.
        """
        try:
            global_indexes = data_parts.body["global_indexes"]
            local_indexes = data_parts.body["local_indexes"]
            field_data = data_parts.body["field_data"]  # field_data should be in {field_name: [real data]} format.

            self.experience_data.put_data(field_data, local_indexes)

            # After put operation finish, send a message to the client
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.zmq_server_info.id, body={}
            )

            # Gather per-field dtype and shape information for each field
            # global_indexes, local_indexes, and field_data correspond one-to-one
            per_field_dtypes = {}
            per_field_shapes = {}

            # Initialize the data structure for each global index
            for global_idx in global_indexes:
                per_field_dtypes[global_idx] = {}
                per_field_shapes[global_idx] = {}

            # For each field, extract dtype and shape for each sample
            # (None for items without dtype/shape attributes, e.g. non-tensors).
            for field in field_data.keys():
                for i, data_item in enumerate(field_data[field]):
                    global_idx = global_indexes[i]
                    per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None
                    per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None

            # Broadcast data update message to all controllers with per-field dtype/shape information
            self._notify_data_update(list(field_data.keys()), global_indexes, per_field_dtypes, per_field_shapes)
            return response_msg
        except Exception as e:
            return ZMQMessage.create(
                request_type=ZMQRequestType.PUT_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to put data into storage unit id "
                    f"#{self.zmq_server_info.id}, detail error message: {str(e)}"
                },
            )

    def _notify_data_update(self, fields, global_indexes, dtypes, shapes) -> None:
        """
        Broadcast data status update to all controllers.

        param:
            fields: data update related fields.
            global_indexes: data update related global_indexes.
            dtypes: per-field dtypes for each field, in {global_index: {field: dtype}} format.
            shapes: per-field shapes for each field, in {global_index: {field: shape}} format.
        """
        # Create zmq poller for notifying data update information
        poller = zmq.Poller()

        # Connect data status update socket to all controllers
        # NOTE(review): connect() is re-invoked on every put; presumably ZMQ
        # treats repeat connects to the same endpoint as a no-op — confirm.
        for controller_id, controller_info in self.controller_infos.items():
            data_status_update_socket = self.data_status_update_sockets[controller_id]
            data_status_update_socket.connect(controller_info.to_addr("data_status_update_socket"))
            logger.debug(
                f"[{self.zmq_server_info.id}]: Data status update connection from "
                f"storage unit id #{self.zmq_server_info.id} to "
                f"controller id #{controller_id} establish successfully."
            )

            try:
                poller.register(data_status_update_socket, zmq.POLLIN)

                request_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.NOTIFY_DATA_UPDATE,
                    sender_id=self.zmq_server_info.id,
                    body={
                        "fields": fields,
                        "global_indexes": global_indexes,
                        "dtypes": dtypes,
                        "shapes": shapes,
                    },
                ).serialize()

                data_status_update_socket.send(request_msg)
                logger.debug(
                    f"[{self.zmq_server_info.id}]: Send data status update request "
                    f"from storage unit id #{self.zmq_server_info.id} "
                    f"to controller id #{controller_id} successfully."
                )
            except Exception as e:
                # On failure, still inform this controller with an error message.
                request_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR,
                    sender_id=self.zmq_server_info.id,
                    body={
                        "message": f"Failed to notify data status update information from "
                        f"storage unit id #{self.zmq_server_info.id}, "
                        f"detail error message: {str(e)}"
                    },
                ).serialize()

                data_status_update_socket.send(request_msg)

        # Make sure all controllers successfully receive data status update information.
        response_controllers: set[str] = set()
        start_time = time.time()

        # Bounded ACK wait: outer budget in seconds, each poll() in ms.
        while (
            len(response_controllers) < len(self.controller_infos)
            and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT
        ):
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT))

            for data_status_update_socket in self.data_status_update_sockets.values():
                if data_status_update_socket in socks:
                    response_msg = ZMQMessage.deserialize(data_status_update_socket.recv())

                    if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK:
                        response_controllers.add(response_msg.sender_id)
                        logger.debug(
                            f"[{self.zmq_server_info.id}]: Get data status update ACK response "
                            f"from controller id #{response_msg.sender_id} "
                            f"to storage unit id #{self.zmq_server_info.id} successfully."
                        )

        # Best-effort: missing ACKs are only logged, not raised.
        if len(response_controllers) < len(self.controller_infos):
            logger.warning(
                f"[{self.zmq_server_info.id}]: Storage unit id #{self.zmq_server_info.id} "
                f"only get {len(response_controllers)} / {len(self.controller_infos)} "
                f"data status update ACK responses from controllers."
            )

    def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle get request, return data from storage unit.

        param:
            data_parts: ZMQMessage from client; body carries "fields" and "local_indexes".
        return:
            Get data success response ZMQMessage, containing target data
            (under body["data"]), or GET_ERROR on failure.
        """
        try:
            fields = data_parts.body["fields"]
            local_indexes = data_parts.body["local_indexes"]

            result_data = self.experience_data.get_data(fields, local_indexes)

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_DATA_RESPONSE,
                sender_id=self.zmq_server_info.id,
                body={
                    "data": result_data,
                },
            )
        except Exception as e:
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to get data from storage unit id #{self.zmq_server_info.id}, "
                    f"detail error message: {str(e)}"
                },
            )
        return response_msg

    def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle clear request, clear data in storage unit according to given local_indexes.

        param:
            data_parts: ZMQMessage from client, including target local_indexes.
        return:
            Clear data success response ZMQMessage, or CLEAR_DATA_ERROR on failure.
        """
        try:
            local_indexes = data_parts.body["local_indexes"]

            self.experience_data.clear(local_indexes)

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA_RESPONSE,
                sender_id=self.zmq_server_info.id,
                body={"message": f"Clear data in storage unit id #{self.zmq_server_info.id} successfully."},
            )
        except Exception as e:
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to clear data in storage unit id #{self.zmq_server_info.id}, "
                    f"detail error message: {str(e)}"
                },
            )
        return response_msg

    def get_zmq_server_info(self) -> ZMQServerInfo:
        # Expose this unit's endpoint info (remote-callable via Ray).
        return self.zmq_server_info
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|