TransferQueue 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,515 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+ import time
18
+ from operator import itemgetter
19
+ from threading import Thread
20
+ from uuid import uuid4
21
+
22
+ import ray
23
+ import torch
24
+ import zmq
25
+ from ray.util import get_node_ip_address
26
+ from tensordict import NonTensorStack, TensorDict
27
+
28
+ from transfer_queue.utils.utils import TransferQueueRole
29
+ from transfer_queue.utils.zmq_utils import (
30
+ ZMQMessage,
31
+ ZMQRequestType,
32
+ ZMQServerInfo,
33
+ create_zmq_socket,
34
+ get_free_port,
35
+ )
36
+
37
logger = logging.getLogger(__name__)
# Logger.setLevel accepts either an int level or a level-name string, so an
# env-provided value like "DEBUG" works as well as the int default.
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Env-configurable timeouts. Environment variables are always strings, so every
# value is cast to int; previously TQ_STORAGE_POLLER_TIMEOUT was left uncast,
# which handed zmq.Poller.poll() a str whenever the variable was set.
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 1000))  # milliseconds per poll()
TQ_STORAGE_HANDSHAKE_TIMEOUT = int(os.environ.get("TQ_STORAGE_HANDSHAKE_TIMEOUT", 30))  # seconds
TQ_DATA_UPDATE_RESPONSE_TIMEOUT = int(os.environ.get("TQ_DATA_UPDATE_RESPONSE_TIMEOUT", 600))  # seconds
43
+
44
+
45
class StorageUnitData:
    """
    Class used for storing several elements, each element is composed of several fields and corresponding data, like:
    #####################################################
    #  local_index  |  field_name1  |  field_name2 | ... #
    #       0       |     item1     |     item2    | ... #
    #       1       |     item3     |     item4    | ... #
    #       2       |     item5     |     item6    | ... #
    #####################################################
    """

    def __init__(self, storage_size: int):
        # Dict containing field names and corresponding data in the field, e.g. {"field_name1": [data1, data2, ...]}
        self.field_data: dict[str, list] = {}

        # Maximum number of elements stored in storage unit
        self.storage_size = storage_size

    def get_data(self, fields: list[str], local_indexes: list[int]) -> "TensorDict":
        """
        Get data from storage unit according to given fields and local_indexes.

        param:
            fields: Field names used for getting data.
            local_indexes: Local indexes used for getting data.
        return:
            TensorDict with field names as keys, corresponding data list as values.
        """
        result: dict[str, list] = {}

        for field in fields:
            # Validate field name
            if field not in self.field_data:
                raise ValueError(
                    f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}"
                )

            if len(local_indexes) == 1:
                # Single item: add a leading batch dimension of size 1 so the
                # result stacks the same way as the multi-index branch.
                gathered_item = self.field_data[field][local_indexes[0]]
                if not isinstance(gathered_item, torch.Tensor):
                    result[field] = NonTensorStack(gathered_item)
                else:
                    result[field] = gathered_item.unsqueeze(0)
            else:
                gathered_items = list(itemgetter(*local_indexes)(self.field_data[field]))

                if gathered_items:
                    all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items)
                    if all_tensors:
                        # Nested tensor tolerates per-item shape differences.
                        result[field] = torch.nested.as_nested_tensor(gathered_items)
                    else:
                        result[field] = NonTensorStack(*gathered_items)

        return TensorDict(result)

    def put_data(self, field_data: "TensorDict", local_indexes: list[int]) -> None:
        """
        Put or update data into storage unit according to given field_data and local_indexes.

        param:
            field_data: Dict with field names as keys, corresponding data in the field as values.
            local_indexes: Local indexes used for putting data.
        raises:
            ValueError: if any local index is out of range; in that case no data
            is modified (validation happens before any write).
        """
        # Validate all target indexes up-front so an invalid index cannot leave
        # the storage unit partially updated (and so validation runs once
        # instead of once per field).
        for idx in local_indexes:
            if idx < 0 or idx >= self.storage_size:
                raise ValueError(
                    f"StorageUnitData put_data operation receive invalid local_index: {idx} beyond "
                    f"storage_size: {self.storage_size}"
                )

        extracted_data = dict(field_data)

        for f, values in extracted_data.items():
            if f not in self.field_data:
                # Lazily materialize a column of storage_size empty slots.
                self.field_data[f] = [None] * self.storage_size

            for i, idx in enumerate(local_indexes):
                self.field_data[f][idx] = values[i]

    def clear(self, local_indexes: list[int]) -> None:
        """
        Clear data at specified local_indexes by setting all related fields to None.

        param:
            local_indexes: local_indexes to clear.
        """
        # Validate local_indexes
        for idx in local_indexes:
            if idx < 0 or idx >= self.storage_size:
                raise ValueError(
                    f"StorageUnitData clear operation receive invalid local_index: {idx} beyond "
                    f"storage_size: {self.storage_size}"
                )

        # Clear data at specified local_indexes
        for f in self.field_data:
            for idx in local_indexes:
                self.field_data[f][idx] = None
143
+
144
+
145
@ray.remote(num_cpus=1)
class TransferQueueStorageSimpleUnit:
    """
    Ray actor that owns one StorageUnitData and exposes it over ZMQ.

    A ROUTER ``put_get_socket`` (bound at construction) serves put/get/clear
    requests from clients on a daemon thread. Per-controller DEALER sockets
    are used for the initial handshake and, after every successful put, for
    broadcasting per-field dtype/shape status updates to all controllers.
    """

    def __init__(self, storage_size: int):
        """
        param:
            storage_size: Maximum number of elements held by this storage unit.
        """
        super().__init__()
        # Unique id so controllers can distinguish multiple storage units.
        self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4()}"
        self.storage_size = storage_size
        # Populated later by register_controller_info().
        self.controller_infos: dict[str, ZMQServerInfo] = {}

        self.experience_data = StorageUnitData(self.storage_size)

        # Advertise role/ip/port so clients can reach put_get_socket.
        self.zmq_server_info = ZMQServerInfo.create(
            role=TransferQueueRole.STORAGE,
            id=str(self.storage_unit_id),
            ip=get_node_ip_address(),
            ports={"put_get_socket": get_free_port()},
        )
        self._init_zmq_socket()

    def _init_zmq_socket(self) -> None:
        """
        Initialize ZMQ socket connections between storage unit and controllers/clients:
        - controller_handshake_sockets:
            Handshake between storage unit and controllers.
        - data_status_update_sockets:
            Broadcast data update status from storage unit to controllers when handling put operation.
        - put_get_socket:
            Handle put/get requests from clients.
        """
        self.zmq_context = zmq.Context()

        # Per-controller DEALER sockets; created once controller infos are known.
        self.controller_handshake_sockets: dict[str, zmq.Socket] = {}
        self.data_status_update_sockets: dict[str, zmq.Socket] = {}

        # ROUTER so each client request carries an identity frame for the reply.
        self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER)
        self.put_get_socket.bind(self.zmq_server_info.to_addr("put_get_socket"))

    def register_controller_info(self, controller_infos: dict[str, ZMQServerInfo]) -> None:
        """
        Build connections between storage unit and controllers, start put/get process.

        param:
            controller_infos: Dict with controller infos, keyed by controller id.
        """
        self.controller_infos = controller_infos

        # Order matters: sockets must exist before connecting, and the serving
        # thread is only started after the handshake attempt completes.
        self._init_zmq_sockets_with_controller_infos()
        self._connect_to_controller()
        self._start_process_put_get()

    def _init_zmq_sockets_with_controller_infos(self) -> None:
        """Initialize ZMQ sockets between storage unit and controllers for handshake."""
        for controller_id in self.controller_infos.keys():
            # Explicit unique identities make the DEALER endpoints traceable in logs.
            self.controller_handshake_sockets[controller_id] = create_zmq_socket(
                self.zmq_context,
                zmq.DEALER,
                identity=f"{self.storage_unit_id}-controller_handshake_sockets-{uuid4()}".encode(),
            )
            self.data_status_update_sockets[controller_id] = create_zmq_socket(
                self.zmq_context,
                zmq.DEALER,
                identity=f"{self.storage_unit_id}-data_status_update_sockets-{uuid4()}".encode(),
            )

    def _connect_to_controller(self) -> None:
        """
        Connect storage unit to all controllers and wait (bounded by
        TQ_STORAGE_HANDSHAKE_TIMEOUT seconds) for a HANDSHAKE_ACK from each.
        Missing ACKs are only logged as a warning; no exception is raised.
        """
        connected_controllers: set[str] = set()

        # Create zmq poller for handshake confirmation between controller and storage unit
        poller = zmq.Poller()

        for controller_id, controller_info in self.controller_infos.items():
            self.controller_handshake_sockets[controller_id].connect(controller_info.to_addr("handshake_socket"))
            logger.debug(
                f"[{self.zmq_server_info.id}]: Handshake connection from storage unit id #{self.zmq_server_info.id} "
                f"to controller id #{controller_id} establish successfully."
            )

            # Send handshake request to controllers
            request_msg = ZMQMessage.create(
                request_type=ZMQRequestType.HANDSHAKE,
                sender_id=self.zmq_server_info.id,
                body={
                    "storage_unit_id": self.storage_unit_id,
                    "storage_size": self.storage_size,
                },
            ).serialize()

            self.controller_handshake_sockets[controller_id].send(request_msg)
            logger.debug(
                f"[{self.zmq_server_info.id}]: Send handshake request from storage unit id #{self.zmq_server_info.id} "
                f"to controller id #{controller_id} successfully."
            )

            poller.register(self.controller_handshake_sockets[controller_id], zmq.POLLIN)

        # Poll in TQ_STORAGE_POLLER_TIMEOUT(ms) slices until every controller
        # ACKed or the overall handshake deadline expires.
        start_time = time.time()
        while (
            len(connected_controllers) < len(self.controller_infos)
            and time.time() - start_time < TQ_STORAGE_HANDSHAKE_TIMEOUT
        ):
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT))

            for controller_handshake_socket in self.controller_handshake_sockets.values():
                if controller_handshake_socket in socks:
                    response_msg = ZMQMessage.deserialize(controller_handshake_socket.recv())

                    if response_msg.request_type == ZMQRequestType.HANDSHAKE_ACK:
                        connected_controllers.add(response_msg.sender_id)
                        logger.debug(
                            f"[{self.zmq_server_info.id}]: Get handshake ACK response from "
                            f"controller id #{str(response_msg.sender_id)} to storage unit id "
                            f"#{self.zmq_server_info.id} successfully."
                        )

        if len(connected_controllers) < len(self.controller_infos):
            logger.warning(
                f"[{self.zmq_server_info.id}]: Only get {len(connected_controllers)} / {len(self.controller_infos)} "
                f"successful handshake connections to controllers from storage unit id #{self.zmq_server_info.id}"
            )

    def _start_process_put_get(self) -> None:
        """Create a daemon thread and start put/get process."""
        # Daemon thread: the serving loop never exits on its own, so the thread
        # must not block actor/process shutdown.
        self.process_put_get_thread = Thread(
            target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.zmq_server_info.id}", daemon=True
        )
        self.process_put_get_thread.start()

    def _process_put_get(self) -> None:
        """
        Process put_get_socket request.

        Runs forever on the daemon thread: polls the ROUTER socket, dispatches
        each request to the matching handler, and always replies to the
        requesting client (an error-typed ZMQMessage on any failure).
        """
        poller = zmq.Poller()
        poller.register(self.put_get_socket, zmq.POLLIN)

        while True:
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT))

            if self.put_get_socket in socks:
                # ROUTER prepends the client identity frame; keep it to route the reply.
                identity, serialized_msg = self.put_get_socket.recv_multipart()

                try:
                    request_msg = ZMQMessage.deserialize(serialized_msg)
                    operation = request_msg.request_type
                    logger.debug(f"[{self.zmq_server_info.id}]: receive operation: {operation}, message: {request_msg}")

                    if operation == ZMQRequestType.PUT_DATA:
                        response_msg = self._handle_put(request_msg)
                    elif operation == ZMQRequestType.GET_DATA:
                        response_msg = self._handle_get(request_msg)
                    elif operation == ZMQRequestType.CLEAR_DATA:
                        response_msg = self._handle_clear(request_msg)
                    else:
                        response_msg = ZMQMessage.create(
                            request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR,
                            sender_id=self.zmq_server_info.id,
                            body={
                                "message": f"Storage unit id #{self.zmq_server_info.id} "
                                f"receive invalid operation: {operation}."
                            },
                        )
                except Exception as e:
                    # Deserialization (or handler) failure: report back instead of crashing the loop.
                    response_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.PUT_GET_ERROR,
                        sender_id=self.zmq_server_info.id,
                        body={
                            "message": f"Storage unit id #{self.zmq_server_info.id} occur error in processing "
                            f"put/get/clear request, detail error message: {str(e)}."
                        },
                    )

                self.put_get_socket.send_multipart([identity, response_msg.serialize()])

    def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle put request, add or update data into storage unit.

        param:
            data_parts: ZMQMessage from client. Its body must contain
                "global_indexes", "local_indexes" and "field_data"
                (the three are assumed to correspond one-to-one by position).
        return:
            Put data success response ZMQMessage, or a PUT_ERROR message on failure.
        """
        try:
            global_indexes = data_parts.body["global_indexes"]
            local_indexes = data_parts.body["local_indexes"]
            field_data = data_parts.body["field_data"]  # field_data should be in {field_name: [real data]} format.

            self.experience_data.put_data(field_data, local_indexes)

            # After put operation finish, send a message to the client
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.zmq_server_info.id, body={}
            )

            # Gather per-field dtype and shape information for each field
            # global_indexes, local_indexes, and field_data correspond one-to-one
            per_field_dtypes = {}
            per_field_shapes = {}

            # Initialize the data structure for each global index
            for global_idx in global_indexes:
                per_field_dtypes[global_idx] = {}
                per_field_shapes[global_idx] = {}

            # For each field, extract dtype and shape for each sample
            # (None for items with no dtype/shape attribute, i.e. non-tensors).
            for field in field_data.keys():
                for i, data_item in enumerate(field_data[field]):
                    global_idx = global_indexes[i]
                    per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None
                    per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None

            # Broadcast data update message to all controllers with per-field dtype/shape information
            self._notify_data_update(list(field_data.keys()), global_indexes, per_field_dtypes, per_field_shapes)
            return response_msg
        except Exception as e:
            return ZMQMessage.create(
                request_type=ZMQRequestType.PUT_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to put data into storage unit id "
                    f"#{self.zmq_server_info.id}, detail error message: {str(e)}"
                },
            )

    def _notify_data_update(self, fields, global_indexes, dtypes, shapes) -> None:
        """
        Broadcast data status update to all controllers.

        param:
            fields: data update related fields.
            global_indexes: data update related global_indexes.
            dtypes: per-field dtypes for each field, in {global_index: {field: dtype}} format.
            shapes: per-field shapes for each field, in {global_index: {field: shape}} format.
        """
        # Create zmq poller for notifying data update information
        poller = zmq.Poller()

        # Connect data status update socket to all controllers.
        # NOTE(review): connect()/register() are re-invoked on every put that
        # reaches this method — presumably ZMQ tolerates repeated connects to
        # the same endpoint; confirm, or hoist connection setup to handshake time.
        for controller_id, controller_info in self.controller_infos.items():
            data_status_update_socket = self.data_status_update_sockets[controller_id]
            data_status_update_socket.connect(controller_info.to_addr("data_status_update_socket"))
            logger.debug(
                f"[{self.zmq_server_info.id}]: Data status update connection from "
                f"storage unit id #{self.zmq_server_info.id} to "
                f"controller id #{controller_id} establish successfully."
            )

            try:
                poller.register(data_status_update_socket, zmq.POLLIN)

                request_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.NOTIFY_DATA_UPDATE,
                    sender_id=self.zmq_server_info.id,
                    body={
                        "fields": fields,
                        "global_indexes": global_indexes,
                        "dtypes": dtypes,
                        "shapes": shapes,
                    },
                ).serialize()

                data_status_update_socket.send(request_msg)
                logger.debug(
                    f"[{self.zmq_server_info.id}]: Send data status update request "
                    f"from storage unit id #{self.zmq_server_info.id} "
                    f"to controller id #{controller_id} successfully."
                )
            except Exception as e:
                # Best-effort: on failure, send an error-typed notification instead.
                request_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR,
                    sender_id=self.zmq_server_info.id,
                    body={
                        "message": f"Failed to notify data status update information from "
                        f"storage unit id #{self.zmq_server_info.id}, "
                        f"detail error message: {str(e)}"
                    },
                ).serialize()

                data_status_update_socket.send(request_msg)

        # Make sure all controllers successfully receive data status update information.
        # Wait (bounded by TQ_DATA_UPDATE_RESPONSE_TIMEOUT seconds) for ACKs;
        # shortfalls are only logged, not raised.
        response_controllers: set[str] = set()
        start_time = time.time()

        while (
            len(response_controllers) < len(self.controller_infos)
            and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT
        ):
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT))

            for data_status_update_socket in self.data_status_update_sockets.values():
                if data_status_update_socket in socks:
                    response_msg = ZMQMessage.deserialize(data_status_update_socket.recv())

                    if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK:
                        response_controllers.add(response_msg.sender_id)
                        logger.debug(
                            f"[{self.zmq_server_info.id}]: Get data status update ACK response "
                            f"from controller id #{response_msg.sender_id} "
                            f"to storage unit id #{self.zmq_server_info.id} successfully."
                        )

        if len(response_controllers) < len(self.controller_infos):
            logger.warning(
                f"[{self.zmq_server_info.id}]: Storage unit id #{self.zmq_server_info.id} "
                f"only get {len(response_controllers)} / {len(self.controller_infos)} "
                f"data status update ACK responses from controllers."
            )

    def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle get request, return data from storage unit.

        param:
            data_parts: ZMQMessage from client; body must contain "fields" and
                "local_indexes".
        return:
            Get data success response ZMQMessage, containing target data,
            or a GET_ERROR message on failure.
        """
        try:
            fields = data_parts.body["fields"]
            local_indexes = data_parts.body["local_indexes"]

            result_data = self.experience_data.get_data(fields, local_indexes)

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_DATA_RESPONSE,
                sender_id=self.zmq_server_info.id,
                body={
                    "data": result_data,
                },
            )
        except Exception as e:
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to get data from storage unit id #{self.zmq_server_info.id}, "
                    f"detail error message: {str(e)}"
                },
            )
        return response_msg

    def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle clear request, clear data in storage unit according to given local_indexes.

        param:
            data_parts: ZMQMessage from client, including target local_indexes.
        return:
            Clear data success response ZMQMessage, or a CLEAR_DATA_ERROR
            message on failure.
        """
        try:
            local_indexes = data_parts.body["local_indexes"]

            self.experience_data.clear(local_indexes)

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA_RESPONSE,
                sender_id=self.zmq_server_info.id,
                body={"message": f"Clear data in storage unit id #{self.zmq_server_info.id} successfully."},
            )
        except Exception as e:
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to clear data in storage unit id #{self.zmq_server_info.id}, "
                    f"detail error message: {str(e)}"
                },
            )
        return response_msg

    def get_zmq_server_info(self) -> ZMQServerInfo:
        """Return this storage unit's ZMQ connection info (role, id, ip, ports)."""
        return self.zmq_server_info
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.