TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
transfer_queue/storage/managers/simple_backend_manager.py

@@ -0,0 +1,611 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
from collections.abc import Mapping
from functools import wraps
from operator import itemgetter
from typing import Any, Callable
from uuid import uuid4

import torch
import zmq
from tensordict import NonTensorStack, TensorDict

from transfer_queue.metadata import BatchMeta
from transfer_queue.storage.managers.base import TransferQueueStorageManager
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
from transfer_queue.storage.simple_backend import StorageMetaGroup
from transfer_queue.utils.utils import limit_pytorch_auto_parallel_threads
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

@TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager")
class AsyncSimpleStorageManager(TransferQueueStorageManager):
    """Asynchronous storage manager that handles multiple storage units.

    This manager provides async put/get/clear operations across multiple SimpleStorageUnit
    instances using ZMQ communication and dynamic socket management.
    """

    def __init__(self, config: dict[str, Any]):
        super().__init__(config)

        self.config = config
        server_infos = config.get("storage_unit_infos", None)  # type: ZMQServerInfo | dict[str, ZMQServerInfo]

        if server_infos is None:
            raise ValueError("AsyncSimpleStorageManager requires non-empty 'storage_unit_infos' in config.")

        self.storage_unit_infos = self._register_servers(server_infos)
        self._build_storage_mapping_functions()

    def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerInfo]"):
        """Register and validate server information.

        Args:
            server_infos (ZMQServerInfo | dict[Any, ZMQServerInfo]):
                ZMQServerInfo or dict of server infos to register.

        Returns:
            Dictionary with server IDs as keys and ZMQServerInfo objects as values.

        Raises:
            ValueError: If server_infos format is invalid.
        """
        server_infos_transform = {}

        if isinstance(server_infos, ZMQServerInfo):
            server_infos_transform[server_infos.id] = server_infos
        elif isinstance(server_infos, Mapping):
            for k, v in server_infos.items():
                if not isinstance(v, ZMQServerInfo):
                    raise ValueError(f"Invalid server info for key {k}: {v}")
                server_infos_transform[v.id] = v
        else:
            raise ValueError(f"Invalid server infos: {server_infos}")

        return server_infos_transform

    def _build_storage_mapping_functions(self):
        """Build mapping functions for global index to storage unit and local index.

        Creates round-robin mapping functions to distribute data across storage units.
        """
        self.global_index_storage_unit_mapping = lambda x: list(self.storage_unit_infos.keys())[
            x % len(self.storage_unit_infos)
        ]
        self.global_index_local_index_mapping = lambda x: x // len(self.storage_unit_infos)
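
    # Worked example of the round-robin mapping above (illustrative unit ids, not taken
    # from the package): with three registered storage units "storage_0", "storage_1",
    # "storage_2", a sample with global_index 7 is routed to "storage_1" (7 % 3 == 1)
    # and stored at local_index 2 within that unit (7 // 3 == 2).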

    # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
    @staticmethod
    def dynamic_storage_manager_socket(socket_name: str):
        """Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close).

        Args:
            socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port").

        Decorated Function Rules:
            1. Must be an async class method (needs `self`).
            2. `self` requires:
               - `storage_unit_infos`: storage unit infos (ZMQServerInfo | dict[Any, ZMQServerInfo]).
            3. Specify target server via:
               - `target_storage_unit` arg.
            4. Receives ZMQ socket via `socket` keyword arg (injected by decorator).
        """

        def decorator(func: Callable):
            @wraps(func)
            async def wrapper(self, *args, **kwargs):
                server_key = kwargs.get("target_storage_unit")
                if server_key is None:
                    for arg in args:
                        if isinstance(arg, str) and arg in self.storage_unit_infos.keys():
                            server_key = arg
                            break

                server_info = self.storage_unit_infos.get(server_key)

                if not server_info:
                    raise RuntimeError(f"Server {server_key} not found in registered servers")

                context = zmq.asyncio.Context()
                address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}"
                identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
                sock = create_zmq_socket(context, zmq.DEALER, identity=identity)

                try:
                    sock.connect(address)
                    # Timeouts to avoid indefinite await on recv/send
                    sock.setsockopt(zmq.RCVTIMEO, 10_000)  # 10s
                    sock.setsockopt(zmq.SNDTIMEO, 10_000)  # 10s
                    logger.info(
                        f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} "
                        f"with identity {identity.decode()}"
                    )

                    kwargs["socket"] = sock
                    return await func(self, *args, **kwargs)
                except Exception as e:
                    logger.error(
                        f"[{self.storage_manager_id}]: Error in socket operation with StorageUnit {server_info.id}: {e}"
                    )
                    raise
                finally:
                    try:
                        if not sock.closed:
                            sock.close(linger=-1)
                    except Exception as e:
                        logger.warning(
                            f"[{self.storage_manager_id}]: Error closing socket to StorageUnit {server_info.id}: {e}"
                        )

                    context.term()

            return wrapper

        return decorator
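
    # Usage sketch for the decorator above (hypothetical method, mirroring the decorated
    # methods below): the decorator resolves `target_storage_unit`, creates a DEALER
    # socket, injects it as the `socket` keyword argument, and closes it afterwards.
    #
    #     @dynamic_storage_manager_socket(socket_name="put_get_socket")
    #     async def _ping_storage_unit(self, target_storage_unit=None, socket=None):
    #         await socket.send(b"ping")
    #         return await socket.recv()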

    async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
        """
        Send data to remote StorageUnit based on metadata.

        Args:
            data: TensorDict containing the data to store.
            metadata: BatchMeta containing storage location information.
        """

        # group samples by storage unit
        storage_meta_groups = build_storage_meta_groups(
            metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
        )

        # send data to each storage unit
        tasks = [
            self._put_to_single_storage_unit(get_transfer_data(meta_group, data), target_storage_unit=storage_id)
            for storage_id, meta_group in storage_meta_groups.items()
        ]
        await asyncio.gather(*tasks)

        # Gather per-field dtype and shape information for each field
        # global_indexes, local_indexes, and field_data correspond one-to-one
        per_field_dtypes = {}
        per_field_shapes = {}

        # Initialize the data structure for each global index
        for global_idx in metadata.global_indexes:
            per_field_dtypes[global_idx] = {}
            per_field_shapes[global_idx] = {}

        # For each field, extract dtype and shape for each sample
        for field in data.keys():
            for i, data_item in enumerate(data[field]):
                global_idx = metadata.global_indexes[i]
                per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None
                per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None

        # Get current data partition id
        # Note: Currently we only support putting to & getting data from a single data partition simultaneously,
        # but in the future we may support putting to & getting data from multiple data partitions concurrently.
        partition_id = metadata.samples[0].partition_id

        # notify controller that new data is ready
        await self.notify_data_update(
            partition_id, list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes
        )
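
    # Caller-side sketch for put_data (illustrative field name and shapes; in practice
    # the BatchMeta is obtained from the TransferQueue controller/client rather than
    # constructed by hand):
    #
    #     data = TensorDict({"prompts": torch.randn(4, 8)}, batch_size=4)
    #     await storage_manager.put_data(data, metadata)  # metadata: BatchMeta covering these 4 samples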

    @dynamic_storage_manager_socket(socket_name="put_get_socket")
    async def _put_to_single_storage_unit(self, transfer_data: dict[str, Any], target_storage_unit=None, socket=None):
        """
        Send data to a specific storage unit.
        """
        local_indexes = transfer_data["local_indexes"]

        tensordict_data = TensorDict(
            {
                field: (
                    torch.nested.as_nested_tensor(transfer_data["field_data"][field])
                    if transfer_data["field_data"][field]
                    and all(isinstance(x, torch.Tensor) for x in transfer_data["field_data"][field])
                    else NonTensorStack(*transfer_data["field_data"][field])
                )
                for field in transfer_data["field_data"]
            }
        )

        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.PUT_DATA,
            sender_id=self.storage_manager_id,
            receiver_id=target_storage_unit,
            body={"local_indexes": local_indexes, "data": tensordict_data},
        )

        try:
            await socket.send(request_msg.serialize())
            serialized = await socket.recv()
            response_msg = ZMQMessage.deserialize(serialized)

            if response_msg.request_type != ZMQRequestType.PUT_DATA_RESPONSE:
                raise RuntimeError(
                    f"Failed to put data to storage unit {target_storage_unit}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"Error in put to storage unit {target_storage_unit}: {str(e)}") from e

    async def get_data(self, metadata: BatchMeta) -> TensorDict:
        """
        Retrieve data from remote StorageUnit based on metadata.

        Args:
            metadata: BatchMeta that contains metadata for data retrieval.

        Returns:
            TensorDict containing the retrieved data.
        """

        # group samples by storage unit
        storage_meta_groups = build_storage_meta_groups(
            metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
        )

        # retrieve data
        tasks = [
            self._get_from_single_storage_unit(meta_group.get_transfer_data(), target_storage_unit=storage_id)
            for storage_id, meta_group in storage_meta_groups.items()
        ]

        results = await asyncio.gather(*tasks)

        # post-process data segments to generate a batch of data
        merged_data: dict[int, dict[str, torch.Tensor]] = {}
        for global_indexes, fields, data_from_single_storage_unit in results:
            field_getter = itemgetter(*fields)
            field_values = field_getter(data_from_single_storage_unit)

            if len(fields) == 1:
                extracted_data = {fields[0]: field_values}
            else:
                extracted_data = dict(zip(fields, field_values, strict=False))

            for idx, global_idx in enumerate(global_indexes):
                if global_idx not in merged_data:
                    merged_data[global_idx] = {}
                merged_data[global_idx].update({field: extracted_data[field][idx] for field in fields})

        ordered_data: dict[str, list[torch.Tensor]] = {}
        for field in metadata.field_names:
            ordered_data[field] = [merged_data[global_idx][field] for global_idx in metadata.global_indexes]

        with limit_pytorch_auto_parallel_threads():
            tensor_data = {
                field: (
                    torch.stack(torch.nested.as_nested_tensor(v).unbind())
                    if v
                    and all(isinstance(item, torch.Tensor) for item in v)
                    and all(item.shape == v[0].shape for item in v)
                    else (
                        torch.nested.as_nested_tensor(v)
                        if v and all(isinstance(item, torch.Tensor) for item in v)
                        else NonTensorStack(*v)
                    )
                )
                for field, v in ordered_data.items()
            }

        return TensorDict(tensor_data, batch_size=len(metadata))
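
    # How the returned TensorDict is assembled (illustrative field name, not from the
    # original source): if ordered_data["input_ids"] holds four tensors that all have
    # shape (8,), the first branch stacks them into a regular (4, 8) tensor; if their
    # lengths differ, the second branch keeps them as a nested tensor; non-tensor items
    # (e.g. strings) fall through to NonTensorStack.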

    @dynamic_storage_manager_socket(socket_name="put_get_socket")
    async def _get_from_single_storage_unit(self, index_data, target_storage_unit=None, socket=None):
        global_indexes = index_data["global_indexes"]
        local_indexes = index_data["local_indexes"]
        fields = index_data["fields"]

        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_DATA,
            sender_id=self.storage_manager_id,
            receiver_id=target_storage_unit,
            body={"local_indexes": local_indexes, "fields": fields},
        )

        try:
            await socket.send(request_msg.serialize())
            serialized = await socket.recv()
            response_msg = ZMQMessage.deserialize(serialized)
            logger.info(
                f"[{self.storage_manager_id}]: get data response from storage unit "
                f"{target_storage_unit}: {response_msg}"
            )

            if response_msg.request_type == ZMQRequestType.GET_DATA_RESPONSE:
                # Return data and index information from this storage unit
                storage_unit_data = response_msg.body["data"]
                return global_indexes, fields, storage_unit_data
            else:
                raise RuntimeError(
                    f"Failed to get data from storage unit {target_storage_unit}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"Error getting data from storage unit {target_storage_unit}: {str(e)}") from e

    async def clear_data(self, metadata: BatchMeta) -> None:
        """Clear data in remote StorageUnit.

        Args:
            metadata: BatchMeta that contains metadata for data clearing.
        """

        # group samples by storage unit
        storage_meta_groups = build_storage_meta_groups(
            metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
        )

        # clear data
        tasks = [
            self._clear_single_storage_unit(
                meta_group.get_transfer_data()["local_indexes"], target_storage_unit=storage_id
            )
            for storage_id, meta_group in storage_meta_groups.items()
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        for i, result in enumerate(results):
            if isinstance(result, Exception):
                logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}")

    @dynamic_storage_manager_socket(socket_name="put_get_socket")
    async def _clear_single_storage_unit(self, local_indexes, target_storage_unit=None, socket=None):
        try:
            request_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA,
                sender_id=self.storage_manager_id,
                receiver_id=target_storage_unit,
                body={"local_indexes": local_indexes},
            )

            await socket.send(request_msg.serialize())
            serialized_msg = await socket.recv()
            response_msg = ZMQMessage.deserialize(serialized_msg)

            if response_msg.request_type != ZMQRequestType.CLEAR_DATA_RESPONSE:
                raise RuntimeError(
                    f"Failed to clear storage {target_storage_unit}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )

            logger.info(f"[{self.storage_manager_id}]: Successfully cleared storage unit {target_storage_unit}")
        except Exception as e:
            logger.error(f"[{self.storage_manager_id}]: Error clearing storage unit {target_storage_unit}: {str(e)}")
            raise

    def get_zmq_server_info(self) -> dict[str, ZMQServerInfo]:
        """Get ZMQ server information for all storage units.

        Returns:
            Dictionary mapping storage unit IDs to their ZMQServerInfo.
        """
        return self.storage_unit_infos

    def close(self) -> None:
        """Close all ZMQ sockets and context to prevent resource leaks."""
        super().close()
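
# Construction sketch (illustrative; ZMQServerInfo fields are inferred from the attribute
# access above, namely `id`, `ip`, and `ports`, and the exact constructor may differ):
#
#     unit_info = ZMQServerInfo(id="storage_0", ip="127.0.0.1", ports={"put_get_socket": 5555})
#     manager = AsyncSimpleStorageManager({"storage_unit_infos": {"storage_0": unit_info}})
#     # `storage_unit_infos` may also be a single ZMQServerInfo instead of a dict.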

def get_transfer_data(
    storage_meta_group: StorageMetaGroup,
    data: TensorDict,
) -> dict[str, Any]:
    """Convert StorageMetaGroup and TensorDict to transfer format for put operations.

    This function creates a bridge between the high-level metadata (StorageMetaGroup)
    and the raw data (TensorDict), producing a transfer_dict that contains both
    metadata structure and the actual field data needed for storage operations.

    Key Data Flow:
        1. storage_meta_group.get_transfer_data() creates metadata structure
        2. _add_field_data() extracts data using sample_meta.batch_index as key
        3. Final transfer_dict contains both metadata and correctly ordered data

    Args:
        storage_meta_group: StorageMetaGroup containing SampleMeta objects with:
            - sample_meta.batch_index: Position in original TensorDict (0-based)
            - sample_meta.global_index: Global unique identifier
            - sample_meta.local_index: Position in target storage unit
        data: Raw TensorDict with actual data values (as received from client):
            Format: {"field_name": [data_at_index_0, data_at_index_1, ...]}

    Returns:
        Complete transfer dictionary ready for storage operations:
        {
            "batch_indexes": [2, 0, 3],      # Original TensorDict positions
            "global_indexes": [10, 11, 12],  # Global identifiers
            "local_indexes": [4, 5, 6],      # Storage locations
            "fields": ["images", "labels"],
            "field_data": {
                "images": [img2, img0, img3],  # Extracted by batch_index
                "labels": [label2, label0, label3]
            }
        }

    Example:
        >>> # Client data: TensorDict with 5 samples (indices 0-4)
        >>> data = TensorDict({
        ...     "images": [img0, img1, img2, img3, img4],
        ...     "labels": [label0, label1, label2, label3, label4]
        ... })
        >>> # MetaGroup contains samples at positions 2, 0, 3 in original data
        >>> group = StorageMetaGroup("storage1")
        >>> group.add_sample_meta(SampleMeta(batch_index=2, global_index=10), 4)
        >>> group.add_sample_meta(SampleMeta(batch_index=0, global_index=11), 5)
        >>> group.add_sample_meta(SampleMeta(batch_index=3, global_index=12), 6)
        >>> transfer_dict = get_transfer_data(group, data)
        >>> transfer_dict["batch_indexes"]  # [2, 0, 3] - positions in original TensorDict
        >>> transfer_dict["field_data"]["images"]  # [img2, img0, img3] - extracted data

    Note:
        The critical insight is that sample_meta.batch_index is used to index into
        the original TensorDict to extract the correct data items. This ensures that
        even when samples are reordered or distributed across storage units,
        each sample's data is correctly mapped to its metadata.
    """

    result = storage_meta_group.get_transfer_data(field_names=list(data.keys()))
    result = _add_field_data(result, storage_meta_group, data)
    return result


def _add_field_data(
    transfer_dict: dict[str, Any], storage_meta_group: StorageMetaGroup, data: TensorDict
) -> dict[str, Any]:
    """Extract field data from TensorDict using sample_meta.batch_index as index.

    This function bridges the gap between raw TensorDict data and the transfer format
    needed for storage operations. The transfer_dict contains metadata and structure
    information, while the 'data' parameter contains the actual tensor values.

    Key Concept: sample_meta.batch_index represents the position of each sample's data
    in the original TensorDict (received from client). This function uses batch_index
    to extract the correct data items for each sample in the storage_meta_group.

    Args:
        transfer_dict: Dictionary containing transfer metadata with structure like:
            {
                "batch_indexes": [2, 0, 3],      # Positions in original TensorDict
                "global_indexes": [10, 11, 12],  # Global identifiers
                "local_indexes": [4, 5, 6],      # Storage locations
                "fields": ["field1", "field2"],
                "field_data": {}  # Will be populated by this function
            }
        storage_meta_group: StorageMetaGroup containing SampleMeta objects with:
            - sample_meta.batch_index: Position in original TensorDict
            - sample_meta.local_index: Position in storage unit
        data: Raw TensorDict with actual data (as received from client):
            TensorDict({"field1": [t0, t1, t2, t3, t4], "field2": [t5, t6, t7, t8, t9]})

    Returns:
        Updated transfer dictionary with field_data populated:
        {
            "batch_indexes": [2, 0, 3],
            "global_indexes": [10, 11, 12],
            "local_indexes": [4, 5, 6],
            "fields": ["field1", "field2"],
            "field_data": {
                "field1": [t2, t0, t3],  # Extracted by batch_index from original data
                "field2": [t7, t5, t8]
            }
        }

    Example:
        >>> # Raw data from client (TensorDict indices 0-4)
        >>> data = TensorDict({"images": [img0, img1, img2, img3, img4]})
        >>> # storage_meta_group contains samples with batch_index [2, 0, 3]
        >>> transfer_dict = {
        ...     "fields": ["images"],
        ...     "batch_indexes": [2, 0, 3],
        ...     "local_indexes": [4, 5, 6],
        ...     "field_data": {}
        ... }
        >>> meta_group = StorageMetaGroup("storage1")
        >>> meta_group.add_sample_meta(SampleMeta(batch_index=2), 4)  # Extract img2
        >>> meta_group.add_sample_meta(SampleMeta(batch_index=0), 5)  # Extract img0
        >>> meta_group.add_sample_meta(SampleMeta(batch_index=3), 6)  # Extract img3
        >>> result = _add_field_data(transfer_dict, meta_group, data)
        >>> result["field_data"]["images"]  # [img2, img0, img3] - extracted by batch_index
    """
    field_names = transfer_dict["fields"]
    for fname in field_names:
        if fname in data.keys():
            index = [sample_meta.batch_index for sample_meta in storage_meta_group.sample_metas]

            result = itemgetter(*index)(data[fname])
            if not isinstance(result, tuple):
                result = (result,)
            transfer_dict["field_data"][fname] = list(result)

    return transfer_dict


def build_storage_meta_groups(
    batch_meta: BatchMeta,
    global_index_storage_unit_mapping: Callable,
    global_index_local_index_mapping: Callable,
) -> dict[str, StorageMetaGroup]:
    """Build storage meta groups from batch metadata for distributed storage.

    This function is the starting point of the data distribution workflow. It analyzes
    BatchMeta containing SampleMeta objects (originating from client requests) and
    groups them by target storage unit based on their global_index.

    Key Data Flow:
        1. BatchMeta contains SampleMeta objects with batch_index (original TensorDict position)
        2. Each SampleMeta is assigned to a storage unit using global_index mapping
        3. Local storage positions are calculated for each sample
        4. Results in StorageMetaGroup objects ready for transfer operations

    Args:
        batch_meta: BatchMeta containing SampleMeta objects from client request.
            Each SampleMeta has:
            - batch_index: Position in original TensorDict (0-based)
            - global_index: Global unique identifier across all storage
        global_index_storage_unit_mapping: Function to map global_index to storage_unit_id.
            Example: lambda x: storage_unit_ids[x % num_storage_units] (round-robin distribution)
        global_index_local_index_mapping: Function to map global_index to local_index.
            Example: lambda x: x // num_storage_units (local position within storage unit)

    Returns:
        Dictionary mapping storage_unit_id to StorageMetaGroup, where each group contains:
        - storage_id: Target storage unit identifier
        - sample_metas: List of SampleMeta objects assigned to this unit
        - local_indexes: List of storage positions for each sample

    Example:
        >>> # Input: BatchMeta with samples at global_indexes [10, 11, 12]
        >>> # 3 storage units available: storage_0, storage_1, storage_2
        >>> batch_meta = BatchMeta(samples=[
        ...     SampleMeta(batch_index=0, global_index=10),  # Original position 0
        ...     SampleMeta(batch_index=1, global_index=11),  # Original position 1
        ...     SampleMeta(batch_index=2, global_index=12)   # Original position 2
        ... ])
        >>> groups = build_storage_meta_groups(
        ...     batch_meta,
        ...     lambda x: f"storage_{x % 3}",  # 10->storage_1, 11->storage_2, 12->storage_0
        ...     lambda x: x // 3  # 10->3, 11->3, 12->4
        ... )
        >>> groups["storage_1"].sample_metas[0].batch_index  # 0 - original TensorDict position
        >>> groups["storage_1"].sample_metas[0].local_index  # 3 - storage position

    Note:
        This function preserves the crucial batch_index information that links each
        SampleMeta back to its original position in the client's TensorDict.
        This batch_index is later used by _add_field_data() to extract
        the correct data items for storage.
    """
    storage_meta_groups: dict[str, StorageMetaGroup] = {}

    for sample in batch_meta.samples:
        storage_id = global_index_storage_unit_mapping(sample.global_index)
        local_index = global_index_local_index_mapping(sample.global_index)
        if storage_id not in storage_meta_groups:
            storage_meta_groups[storage_id] = StorageMetaGroup(storage_id=storage_id)

        # Use add_sample_meta to store SampleMeta references directly
        storage_meta_groups[storage_id].add_sample_meta(sample, local_index)

    return storage_meta_groups
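
Taken together, the put path in this module is: build_storage_meta_groups() splits a BatchMeta across storage units, get_transfer_data() attaches the raw field values for each group, and _put_to_single_storage_unit() ships each group over ZMQ. A minimal sketch of the first two steps, reusing the constructor forms assumed in the docstring examples above (the SampleMeta/BatchMeta arguments and the "prompts" field are illustrative, not verified against metadata.py):

    data = TensorDict({"prompts": torch.randn(3, 8)}, batch_size=3)
    batch_meta = BatchMeta(samples=[SampleMeta(batch_index=i, global_index=10 + i) for i in range(3)])
    groups = build_storage_meta_groups(
        batch_meta,
        lambda x: f"storage_{x % 2}",  # two storage units, round-robin
        lambda x: x // 2,
    )
    transfer_dicts = {sid: get_transfer_data(group, data) for sid, group in groups.items()}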

transfer_queue/storage/managers/yuanrong_manager.py

@@ -0,0 +1,18 @@
from typing import Any

from transfer_queue.storage.managers.base import KVStorageManager


class YuanrongStorageManager(KVStorageManager):
    def __init__(self, config: dict[str, Any]):
        host = config.get("host", None)
        port = config.get("port", None)
        device_id = config.get("device_id", None)
        if host is None or not isinstance(host, str):
            raise ValueError("Missing or invalid 'host' in config")
        if port is None or not isinstance(port, int):
            raise ValueError("Missing or invalid 'port' in config")
        # TODO: device_id may be a list[int]
        if device_id is None or not isinstance(device_id, int):
            raise ValueError("Missing or invalid 'device_id' in config")
        super().__init__(config)