TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,460 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import itertools
|
|
15
|
+
import logging
|
|
16
|
+
import os
|
|
17
|
+
import time
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
from typing import Any
|
|
20
|
+
from uuid import uuid4
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
import zmq
|
|
24
|
+
from tensordict import TensorDict
|
|
25
|
+
from torch import Tensor
|
|
26
|
+
|
|
27
|
+
from transfer_queue.metadata import BatchMeta
|
|
28
|
+
from transfer_queue.storage.clients.factory import StorageClientFactory
|
|
29
|
+
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)


def _resolve_log_level(value: "str | int | None", default: int = logging.WARNING) -> int:
    """Translate a ``TQ_LOGGING_LEVEL`` value into a numeric logging level.

    Accepts level names in any case (``"debug"``, ``"INFO"``), numeric strings (``"10"``)
    and plain ints. ``None`` and unrecognized values fall back to ``default``. This avoids
    the ``ValueError`` that ``Logger.setLevel`` raises for such values, which would
    otherwise abort the import of this module.
    """
    if value is None:
        return default
    if isinstance(value, int):
        return value
    text = str(value).strip()
    if text.lstrip("-").isdigit():
        return int(text)
    # getLevelName maps a registered name to its int, and returns a "Level X" string otherwise.
    level = logging.getLevelName(text.upper())
    return level if isinstance(level, int) else default


logger.setLevel(_resolve_log_level(os.getenv("TQ_LOGGING_LEVEL")))


# ZMQ timeouts (in seconds) and retry configurations
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5))
TQ_STORAGE_HANDSHAKE_TIMEOUT = int(os.environ.get("TQ_STORAGE_HANDSHAKE_TIMEOUT", 30))
TQ_STORAGE_HANDSHAKE_RETRY_INTERVAL = int(os.environ.get("TQ_STORAGE_HANDSHAKE_RETRY_INTERVAL", 1))
TQ_STORAGE_HANDSHAKE_MAX_RETRIES = int(os.environ.get("TQ_STORAGE_HANDSHAKE_MAX_RETRIES", 3))
TQ_DATA_UPDATE_RESPONSE_TIMEOUT = int(os.environ.get("TQ_DATA_UPDATE_RESPONSE_TIMEOUT", 30))
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class TransferQueueStorageManager(ABC):
    """Base class for storage layer. It defines the interface for data operations and
    generally provides handshake & notification capabilities.

    On construction it creates a ZMQ context and two DEALER sockets towards the controller
    described by ``config["controller_info"]``: one for the handshake and one for data
    status updates. It then performs a blocking handshake. Subclasses implement
    ``put_data`` / ``get_data`` / ``clear_data``.
    """

    def __init__(self, config: dict[str, Any]):
        """Create the storage manager and connect it to the controller.

        Args:
            config: Manager configuration; must contain ``"controller_info"`` (a ZMQServerInfo).

        Raises:
            ValueError: If ``controller_info`` is missing or is not a ZMQServerInfo.
            TimeoutError: If the controller does not acknowledge the handshake.
        """
        self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}"
        self.config = config
        self.controller_info = config.get("controller_info", None)  # type: ZMQServerInfo

        self.data_status_update_socket = None
        self.controller_handshake_socket = None

        self.zmq_context = None
        self._connect_to_controller()

    def _connect_to_controller(self) -> None:
        """Initialize ZMQ sockets between storage manager and controller and do the handshake.

        If anything fails, the partially created sockets and context are released before the
        exception is re-raised, so a failed construction does not leak ZMQ resources.
        """
        if not isinstance(self.controller_info, ZMQServerInfo):
            raise ValueError(f"controller_info should be ZMQServerInfo, but got {type(self.controller_info)}")

        try:
            # create zmq context
            self.zmq_context = zmq.Context()

            # create zmq sockets for handshake and data status update
            self.controller_handshake_socket = create_zmq_socket(
                self.zmq_context,
                zmq.DEALER,
                identity=f"{self.storage_manager_id}-controller_handshake_socket-{uuid4().hex[:8]}".encode(),
            )
            self.data_status_update_socket = create_zmq_socket(
                self.zmq_context,
                zmq.DEALER,
                identity=f"{self.storage_manager_id}-data_status_update_socket-{uuid4().hex[:8]}".encode(),
            )
            self.data_status_update_socket.connect(self.controller_info.to_addr("data_status_update_socket"))

            # do handshake with controller
            self._do_handshake_with_controller()

        except Exception as e:
            logger.error(f"Failed to connect to controller: {e}")
            # Use the private helper rather than close(): a subclass override of close() may
            # touch attributes that the subclass __init__ has not set yet.
            self._release_zmq_resources()
            raise

    def _do_handshake_with_controller(self) -> None:
        """Handshake with controller to establish connection with retransmission mechanism.

        The request is sent once. It is then retransmitted every
        ``TQ_STORAGE_HANDSHAKE_RETRY_INTERVAL`` seconds, at most
        ``TQ_STORAGE_HANDSHAKE_MAX_RETRIES`` times, until a ``HANDSHAKE_ACK`` arrives.

        Raises:
            TimeoutError: If no ACK arrives within one retry interval after the last
                retransmission, or within ``TQ_STORAGE_HANDSHAKE_TIMEOUT`` seconds overall.
        """
        is_connected: bool = False
        handshake_retries: int = 0

        # Create zmq poller for handshake confirmation between controller and storage manager
        poller = zmq.Poller()

        self.controller_handshake_socket.connect(self.controller_info.to_addr("handshake_socket"))
        logger.debug(
            f"[{self.storage_manager_id}]: Handshake connection from storage manager id #{self.storage_manager_id} "
            f"to controller id #{self.controller_info.id} establish successfully."
        )
        poller.register(self.controller_handshake_socket, zmq.POLLIN)

        # Initial handshake request send
        self._send_handshake_requests()

        start_time = time.time()
        last_retry_time = time.time()

        # Use shorter poll timeout for more responsive retry timing
        # while maintaining overall handshake timeout behavior
        poll_timeout = min(TQ_STORAGE_POLLER_TIMEOUT * 1000, 500)  # Max 500ms

        # Only one controller to connect to
        while not is_connected and time.time() - start_time < TQ_STORAGE_HANDSHAKE_TIMEOUT:
            current_time = time.time()
            if current_time - last_retry_time >= TQ_STORAGE_HANDSHAKE_RETRY_INTERVAL:
                if handshake_retries < TQ_STORAGE_HANDSHAKE_MAX_RETRIES:
                    logger.warning(
                        f"[{self.storage_manager_id}]: Retransmitting handshake "
                        f"to controller {self.controller_info.id}, "
                        f"attempt {handshake_retries + 1}/{TQ_STORAGE_HANDSHAKE_MAX_RETRIES}"
                    )
                    self._send_handshake_requests()
                    last_retry_time = current_time
                    handshake_retries += 1
                else:
                    # The final (re)transmission has had a full retry interval to be answered.
                    raise TimeoutError(
                        f"[{self.storage_manager_id}]: Handshake with controller {self.controller_info.id} "
                        f"({self.controller_info.ip}) failed after "
                        f"{TQ_STORAGE_HANDSHAKE_MAX_RETRIES} attempts."
                    )

            socks = dict(poller.poll(poll_timeout))

            if socks.get(self.controller_handshake_socket, 0) & zmq.POLLIN:
                try:
                    response_msg = ZMQMessage.deserialize(self.controller_handshake_socket.recv())

                    if response_msg.request_type == ZMQRequestType.HANDSHAKE_ACK:
                        is_connected = True
                        logger.debug(
                            f"[{self.storage_manager_id}]: Get handshake ACK response from "
                            f"controller id #{str(response_msg.sender_id)} to storage manager id "
                            f"#{self.storage_manager_id} successfully."
                        )
                except Exception as e:
                    logger.warning(
                        f"[{self.storage_manager_id}]: Error receiving handshake "
                        f"response from {self.controller_info.id}: {e}"
                    )

        # The overall timeout can expire before the retries are exhausted; never return unconnected.
        if not is_connected:
            raise TimeoutError(
                f"[{self.storage_manager_id}]: Handshake with controller {self.controller_info.id} "
                f"({self.controller_info.ip}) timed out after {TQ_STORAGE_HANDSHAKE_TIMEOUT} seconds."
            )

    def _send_handshake_requests(self) -> None:
        """Send one HANDSHAKE request (manager id and concrete class name) to the controller."""
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.HANDSHAKE,
            sender_id=self.storage_manager_id,
            body={
                "storage_manager_id": self.storage_manager_id,
                "storage_manager_type": self.__class__.__name__,
            },
        ).serialize()

        self.controller_handshake_socket.send(request_msg)
        logger.debug(
            f"[{self.storage_manager_id}]: Send handshake request from storage manager id "
            f"{self.storage_manager_id} to controller id #{self.controller_info.id} successfully."
        )

    async def notify_data_update(
        self,
        partition_id: str,
        fields: list[str],
        global_indexes: list[int],
        dtypes: dict[int, dict[str, Any]],
        shapes: dict[int, dict[str, Any]],
    ) -> None:
        """
        Notify controller that new data is ready.

        Sends a NOTIFY_DATA_UPDATE message on ``data_status_update_socket``. If building or
        sending it fails, a NOTIFY_DATA_UPDATE_ERROR message is sent instead. The method then
        waits up to ``TQ_DATA_UPDATE_RESPONSE_TIMEOUT`` seconds for a NOTIFY_DATA_UPDATE_ACK.
        A missing ACK is logged as an error, not raised.

        NOTE: although declared ``async``, the send and the ACK wait use a blocking ZMQ poller,
        so the running event loop is blocked while waiting.

        Args:
            partition_id: Current data partition id.
            fields: Data update related fields.
            global_indexes: Data update related global_indexes.
            dtypes: Per-field dtypes for each field, in {global_index: {field: dtype}} format.
            shapes: Per-field shapes for each field, in {global_index: {field: shape}} format.
        """
        if not self.controller_info:
            logger.warning(f"No controller connected for storage manager {self.storage_manager_id}")
            return

        # Create zmq poller for the controller's ACK.
        # Note: data_status_update_socket is already connected during initialization
        poller = zmq.Poller()

        try:
            poller.register(self.data_status_update_socket, zmq.POLLIN)

            request_msg = ZMQMessage.create(
                request_type=ZMQRequestType.NOTIFY_DATA_UPDATE,
                sender_id=self.storage_manager_id,
                body={
                    "partition_id": partition_id,
                    "fields": fields,
                    "global_indexes": global_indexes,
                    "dtypes": dtypes,
                    "shapes": shapes,
                },
            ).serialize()

            self.data_status_update_socket.send(request_msg)
            logger.debug(
                f"[{self.storage_manager_id}]: Send data status update request "
                f"from storage manager id #{self.storage_manager_id} "
                f"to controller id #{self.controller_info.id} successfully."
            )
        except Exception as e:
            logger.error(f"[{self.storage_manager_id}]: Failed to send data status update: {e}")
            request_msg = ZMQMessage.create(
                request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR,
                sender_id=self.storage_manager_id,
                body={
                    "message": f"Failed to notify data status update information from "
                    f"storage manager id #{self.storage_manager_id}, "
                    f"detail error message: {str(e)}"
                },
            ).serialize()

            self.data_status_update_socket.send(request_msg)

        # Make sure controller successfully receives data status update information.
        response_received: bool = False
        start_time = time.time()

        # Only one controller to get response from
        while not response_received and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT:
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))

            if self.data_status_update_socket in socks:
                response_msg = ZMQMessage.deserialize(self.data_status_update_socket.recv())

                if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK:
                    response_received = True
                    logger.debug(
                        f"[{self.storage_manager_id}]: Get data status update ACK response "
                        f"from controller id #{response_msg.sender_id} "
                        f"to storage manager id #{self.storage_manager_id} successfully."
                    )

        if not response_received:
            logger.error(
                f"[{self.storage_manager_id}]: Storage manager id #{self.storage_manager_id} "
                f"did not receive data status update ACK response from controller."
            )

    @abstractmethod
    async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
        """Store ``data`` for the samples described by ``metadata``."""
        raise NotImplementedError("Subclasses must implement put_data")

    @abstractmethod
    async def get_data(self, metadata: BatchMeta) -> TensorDict:
        """Return the data for the samples and fields described by ``metadata``."""
        raise NotImplementedError("Subclasses must implement get_data")

    @abstractmethod
    async def clear_data(self, metadata: BatchMeta) -> None:
        """Remove the data for the samples and fields described by ``metadata``."""
        raise NotImplementedError("Subclasses must implement clear_data")

    def _release_zmq_resources(self) -> None:
        """Close the controller sockets (linger=0) and terminate the ZMQ context.

        Attributes are read with ``getattr`` so this also works on an object whose
        ``__init__`` failed early. Errors are logged, not raised.
        """
        manager_id = getattr(self, "storage_manager_id", "?")
        for sock in (
            getattr(self, "controller_handshake_socket", None),
            getattr(self, "data_status_update_socket", None),
        ):
            try:
                if sock and not sock.closed:
                    sock.close(linger=0)
            except Exception as e:
                logger.error(f"[{manager_id}]: Error closing socket {sock}: {str(e)}")

        try:
            zmq_context = getattr(self, "zmq_context", None)
            if zmq_context:
                zmq_context.term()
        except Exception as e:
            logger.error(f"[{manager_id}]: Error terminating zmq_context: {str(e)}")

    def close(self) -> None:
        """Close all ZMQ sockets and context to prevent resource leaks."""
        self._release_zmq_resources()

    def __del__(self):
        """Destructor to ensure resources are cleaned up."""
        try:
            self.close()
        except Exception as e:
            logger.error(f"[{getattr(self, 'storage_manager_id', '?')}]: Exception during __del__: {str(e)}")
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
class KVStorageManager(TransferQueueStorageManager):
    """
    A storage manager that uses a key-value (KV) backend (e.g., YuanRong) to store and retrieve tensor data.
    It maps structured metadata (BatchMeta) to flat lists of keys and values for efficient KV operations.

    Keys have the form ``"{global_index}@{field_name}"`` and are laid out field-major: sorted
    field names in the outer loop, global indexes in the inner loop. Values use the same order.
    """

    def __init__(self, config: dict[str, Any]):
        """
        Initialize the KVStorageManager with configuration.

        Args:
            config: Manager configuration. Besides the base-class keys, ``client_name``
                (default ``"Yuanrong"``) selects the storage client via ``StorageClientFactory``.
                The whole config is forwarded to that client.
        """
        super().__init__(config)
        client_name = config.get("client_name", "Yuanrong")
        self.storage_client = StorageClientFactory.create(client_name, config)

    @staticmethod
    def _generate_keys(metadata: BatchMeta, field_names: "list[str] | None" = None) -> list[str]:
        """
        Generate KV keys in the format 'global_index@field_name' for all sample-field pairs.
        Keys are generated in sorted order by field name first, then by global index,
        ensuring consistent ordering for batched operations.

        Args:
            metadata (BatchMeta): Metadata containing global indexes and field names.
            field_names (list[str] | None): Fields to generate keys for. Defaults to
                ``metadata.field_names``.
        Returns:
            list[str]: List of keys, e.g., ['0@field_a', '1@field_a', '0@field_b', ...]
        """
        if field_names is None:
            field_names = metadata.field_names
        return [f"{index}@{field}" for field, index in itertools.product(sorted(field_names), metadata.global_indexes)]

    @staticmethod
    def _generate_values(data: TensorDict) -> list[Tensor]:
        """
        Extract and flatten tensor values from a TensorDict in field-major order.
        Values are ordered by sorted field names, then by row (sample) order within each field.
        This matches the key order generated by `_generate_keys`.

        Args:
            data (TensorDict): Input data where keys are field names and values are tensors.
        Returns:
            list[Tensor]: Flattened list of tensors, e.g.,
                [data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...]
        Raises:
            TypeError: If any value in ``data`` is not a torch.Tensor.
        """
        # TODO: We will support more complex data types ( NonTensorStack/ NonTensorData/ NestedTensor)
        for v in data.values():
            if not torch.is_tensor(v):
                raise TypeError(f"TensorDict values must be torch.Tensor, but got {type(v)}")

        return [row_data for field in sorted(data.keys()) for row_data in data[field]]

    @staticmethod
    def _merge_tensors_to_tensordict(metadata: BatchMeta, values: list[Tensor]) -> TensorDict:
        """
        Reconstruct a TensorDict from a list of values using metadata.
        The values list is assumed to be in the same order as keys generated by `_generate_keys`.
        From the field names and global indexes in metadata, this method determines which
        field and which row each tensor belongs to, and reshapes the flat tensor list back
        into a structured TensorDict.

        Args:
            metadata (BatchMeta): Metadata containing global indexes and field names.
            values (list[Tensor]): List of tensors in field-major order.
        Returns:
            TensorDict: Reconstructed tensor dictionary with batch size equal to number of samples.
        Raises:
            ValueError: If ``len(values)`` differs from ``#samples * #fields``.
        """
        global_indexes = metadata.global_indexes
        field_names = sorted(metadata.field_names)
        num_samples = len(global_indexes)
        expected_length = num_samples * len(field_names)
        if len(values) != expected_length:
            raise ValueError(f"Length of values ({len(values)}) does not match expected ({expected_length})")

        if len(values) == 0:
            return TensorDict({}, batch_size=num_samples)

        tensor_data = {}
        for offset, field in enumerate(field_names):
            # Field-major layout: this field's rows form one contiguous slice of `values`.
            tensor_list = list(values[offset * num_samples : (offset + 1) * num_samples])
            try:
                tensor_data[field] = torch.stack(tensor_list)
            except RuntimeError:
                # Fallback to nested tensor if shapes are irregular
                tensor_data[field] = torch.nested.as_nested_tensor(tensor_list)

        return TensorDict(tensor_data, batch_size=num_samples)

    @staticmethod
    def _get_shape_type_list(metadata: BatchMeta) -> tuple[list[Any], list[Any]]:
        """
        Extract the expected shape and dtype for each field-sample pair in metadata.
        The order matches the key/value order: sorted by field name, then by sample order.

        Args:
            metadata (BatchMeta): Metadata containing sample and field information.
        Returns:
            tuple[list, list]: The recorded shape and dtype of every tensor to be retrieved.
        """
        shapes = []
        dtypes = []
        for field_name in sorted(metadata.field_names):
            for index in range(len(metadata)):
                field = metadata.samples[index].get_field_by_name(field_name)
                shapes.append(field.shape)
                dtypes.append(field.dtype)
        return shapes, dtypes

    # TODO: Test put_data/get_data/clear_data with YuanrongStorageClient
    async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
        """
        Store tensor data in the backend storage and notify the controller.

        Flattens the input tensors into key/value lists and stores them with the storage
        client. It then collects per-sample dtype and shape information and notifies the
        controller that new data is available. Keys are built from the fields actually
        present in ``data``, so keys and values always come from the same field list.

        Raises:
            TypeError: If a value in ``data`` is not a torch.Tensor.
            ValueError: If the number of rows in ``data`` does not match the samples in ``metadata``.
        """
        field_names = sorted(data.keys())
        keys = self._generate_keys(metadata, field_names)
        values = self._generate_values(data)
        if len(keys) != len(values):
            raise ValueError(
                f"Number of values ({len(values)}) does not match number of keys ({len(keys)}): "
                f"data must contain exactly one row per global index in metadata "
                f"({len(metadata.global_indexes)})."
            )
        if not keys:
            logger.debug(f"[{self.storage_manager_id}]: put_data called with no data; nothing to store.")
            return

        # NOTE(review): assumes all samples in one batch share a partition — confirm with callers.
        # Read before storing so a bad metadata object fails before any write happens.
        partition_id = metadata.samples[0].partition_id

        self.storage_client.put(keys=keys, values=values)

        # {global_index: {field: dtype/shape}} for every sample in the batch.
        per_field_dtypes: dict[int, dict[str, Any]] = {idx: {} for idx in metadata.global_indexes}
        per_field_shapes: dict[int, dict[str, Any]] = {idx: {} for idx in metadata.global_indexes}
        for field in data.keys():
            for global_idx, data_item in zip(metadata.global_indexes, data[field]):
                per_field_dtypes[global_idx][field] = getattr(data_item, "dtype", None)
                per_field_shapes[global_idx][field] = getattr(data_item, "shape", None)

        # notify controller that new data is ready
        await self.notify_data_update(
            partition_id, list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes
        )

    async def get_data(self, metadata: BatchMeta) -> TensorDict:
        """
        Retrieve tensor data from the backend storage.

        Fetches tensors for every (field, sample) in ``metadata`` with their recorded shapes
        and dtypes, then merges them into a TensorDict. The storage client call is synchronous.
        """
        keys = self._generate_keys(metadata)
        shapes, dtypes = self._get_shape_type_list(metadata)
        values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes)
        return self._merge_tensors_to_tensordict(metadata, values)

    async def clear_data(self, metadata: BatchMeta) -> None:
        """Remove stored data for every (field, sample) pair in ``metadata``."""
        keys = self._generate_keys(metadata)
        self.storage_client.clear(keys=keys)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from transfer_queue.storage.managers.base import TransferQueueStorageManager
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TransferQueueStorageManagerFactory:
    """Registry-backed factory that builds TransferQueueStorageManager instances by name.

    Concrete managers register themselves with the ``register`` class-method decorator.
    ``create`` then looks the name up and instantiates the class with a config dict.
    """

    _registry: dict[str, type[TransferQueueStorageManager]] = {}

    @classmethod
    def register(cls, manager_type: str):
        """Return a class decorator that records the decorated class under ``manager_type``.

        Registering a second class under the same name replaces the first. The decorator
        returns the class unchanged.

        Raises:
            TypeError: (raised by the decorator) if the decorated class is not a
                TransferQueueStorageManager subclass.
        """

        def _register_manager(manager_cls: type[TransferQueueStorageManager]):
            if issubclass(manager_cls, TransferQueueStorageManager):
                cls._registry[manager_type] = manager_cls
                return manager_cls
            display_name = getattr(manager_cls, "__name__", repr(manager_cls))
            raise TypeError(
                f"manager_cls {display_name} must be "
                "a subclass of TransferQueueStorageManager"
            )

        return _register_manager

    @classmethod
    def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager:
        """Instantiate the manager registered under ``manager_type``, passing it ``config``.

        Raises:
            ValueError: If no manager has been registered under ``manager_type``.
        """
        manager_cls = cls._registry.get(manager_type)
        if manager_cls is None:
            supported = list(cls._registry.keys())
            raise ValueError(f"Unknown manager_type: {manager_type}. Supported managers include: {supported}")
        return manager_cls(config)
|