TransferQueue 0.0.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +307 -0
- recipe/simple_use_case/sync_demo.py +223 -0
- tests/test_client.py +390 -0
- tests/test_controller.py +268 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +479 -0
- transfer_queue/__init__.py +42 -0
- transfer_queue/client.py +663 -0
- transfer_queue/controller.py +772 -0
- transfer_queue/metadata.py +603 -0
- transfer_queue/storage.py +515 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +98 -0
- transfer_queue/utils/zmq_utils.py +175 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.0.1.dev0.dist-info/METADATA +15 -0
- transferqueue-0.0.1.dev0.dist-info/RECORD +21 -0
- transferqueue-0.0.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.0.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.0.1.dev0.dist-info/top_level.txt +4 -0
transfer_queue/client.py
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import logging
|
|
17
|
+
import os
|
|
18
|
+
from functools import wraps
|
|
19
|
+
from typing import Any, Callable, Optional, Union
|
|
20
|
+
from uuid import uuid4
|
|
21
|
+
|
|
22
|
+
import ray
|
|
23
|
+
import torch
|
|
24
|
+
import zmq
|
|
25
|
+
import zmq.asyncio
|
|
26
|
+
from tensordict import NonTensorStack, TensorDict
|
|
27
|
+
|
|
28
|
+
from transfer_queue.controller import TransferQueueController
|
|
29
|
+
from transfer_queue.metadata import (
|
|
30
|
+
BatchMeta,
|
|
31
|
+
StorageMetaGroup,
|
|
32
|
+
)
|
|
33
|
+
from transfer_queue.storage import TransferQueueStorageSimpleUnit
|
|
34
|
+
from transfer_queue.utils.utils import (
|
|
35
|
+
TransferQueueRole,
|
|
36
|
+
)
|
|
37
|
+
from transfer_queue.utils.zmq_utils import (
|
|
38
|
+
ZMQMessage,
|
|
39
|
+
ZMQRequestType,
|
|
40
|
+
ZMQServerInfo,
|
|
41
|
+
create_zmq_socket,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
logger = logging.getLogger(__name__)
|
|
45
|
+
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class AsyncTransferQueueClient:
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
client_id: str,
|
|
52
|
+
controller_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
|
|
53
|
+
storage_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
|
|
54
|
+
):
|
|
55
|
+
self.client_id = client_id
|
|
56
|
+
|
|
57
|
+
self._controllers: dict[str, ZMQServerInfo] = {}
|
|
58
|
+
self._storages: dict[str, ZMQServerInfo] = {}
|
|
59
|
+
self._register_servers(TransferQueueRole.CONTROLLER, controller_infos)
|
|
60
|
+
self._register_servers(TransferQueueRole.STORAGE, storage_infos)
|
|
61
|
+
|
|
62
|
+
def _register_servers(
|
|
63
|
+
self,
|
|
64
|
+
role: TransferQueueRole,
|
|
65
|
+
server_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
|
|
66
|
+
):
|
|
67
|
+
mapping = self._controllers if role == TransferQueueRole.CONTROLLER else self._storages
|
|
68
|
+
|
|
69
|
+
if not isinstance(server_infos, dict):
|
|
70
|
+
server_infos = {server_infos.id: server_infos}
|
|
71
|
+
|
|
72
|
+
for info in server_infos.values():
|
|
73
|
+
if not isinstance(info, ZMQServerInfo):
|
|
74
|
+
raise ValueError(f"Invalid server info for {role} {info.id}")
|
|
75
|
+
|
|
76
|
+
if info.id not in mapping:
|
|
77
|
+
mapping[info.id] = info
|
|
78
|
+
logger.info(f"[{self.client_id}]: Registered {role} server {info.id} at {info.ip}")
|
|
79
|
+
else:
|
|
80
|
+
logger.warning(f"[{self.client_id}]: Server {info.id} already registered, skipping")
|
|
81
|
+
|
|
82
|
+
@staticmethod
|
|
83
|
+
def dynamic_socket(target_role: TransferQueueRole, socket_name: str):
|
|
84
|
+
"""Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close).
|
|
85
|
+
|
|
86
|
+
Args:
|
|
87
|
+
target_role (TransferQueueRole): Server type to connect to. Must be one of:
|
|
88
|
+
- `TransferQueueRole.CONTROLLER`
|
|
89
|
+
- `TransferQueueRole.STORAGE`
|
|
90
|
+
socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port").
|
|
91
|
+
|
|
92
|
+
Decorated Function Rules:
|
|
93
|
+
1. Must be an async class method (needs `self`).
|
|
94
|
+
2. `self` requires:
|
|
95
|
+
- `_controllers`/`_storages`: Server registries (match `target_role`).
|
|
96
|
+
- `client_id`: Unique client ID (for socket identity).
|
|
97
|
+
3. Specify target server via:
|
|
98
|
+
- `target_controller` (for Controller) or `target_storage` (for Storage) arg.
|
|
99
|
+
- Controller role: Uses first registered server if no ID is given.
|
|
100
|
+
4. Receives ZMQ socket via `socket` keyword arg (injected by decorator).
|
|
101
|
+
"""
|
|
102
|
+
|
|
103
|
+
def decorator(func: Callable):
|
|
104
|
+
@wraps(func)
|
|
105
|
+
async def wrapper(self, *args, **kwargs):
|
|
106
|
+
if target_role == TransferQueueRole.CONTROLLER:
|
|
107
|
+
servers = self._controllers
|
|
108
|
+
target = "target_controller"
|
|
109
|
+
elif target_role == TransferQueueRole.STORAGE:
|
|
110
|
+
servers = self._storages
|
|
111
|
+
target = "target_storage"
|
|
112
|
+
else:
|
|
113
|
+
raise ValueError("Invalid target_role, must be CONTROLLER or STORAGE")
|
|
114
|
+
|
|
115
|
+
server_key = kwargs.get(target)
|
|
116
|
+
if server_key is None:
|
|
117
|
+
for arg in args:
|
|
118
|
+
if isinstance(arg, str) and arg in servers.keys():
|
|
119
|
+
server_key = arg
|
|
120
|
+
break
|
|
121
|
+
if server_key is None and target == "target_controller":
|
|
122
|
+
server_key = next(iter(servers.keys()))
|
|
123
|
+
|
|
124
|
+
server_info = servers.get(server_key)
|
|
125
|
+
if not server_info:
|
|
126
|
+
raise RuntimeError(f"Server {server_key} not found in registered {target_role} servers")
|
|
127
|
+
|
|
128
|
+
context = zmq.asyncio.Context()
|
|
129
|
+
address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}"
|
|
130
|
+
identity = f"{self.client_id}_to_{server_info.id}_{uuid4()}".encode()
|
|
131
|
+
sock = create_zmq_socket(context, zmq.DEALER, identity=identity)
|
|
132
|
+
|
|
133
|
+
try:
|
|
134
|
+
sock.connect(address)
|
|
135
|
+
logger.info(
|
|
136
|
+
f"[{self.client_id}]: Connected to {target_role} {server_info.id} at {address} "
|
|
137
|
+
f"with identity {identity.decode()}"
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
kwargs["socket"] = sock
|
|
141
|
+
return await func(self, *args, **kwargs)
|
|
142
|
+
except Exception as e:
|
|
143
|
+
logger.error(
|
|
144
|
+
f"[{self.client_id}]: Error in socket operation with {target_role} {server_info.id}: {e}"
|
|
145
|
+
)
|
|
146
|
+
raise
|
|
147
|
+
finally:
|
|
148
|
+
try:
|
|
149
|
+
if not sock.closed:
|
|
150
|
+
sock.setsockopt(zmq.LINGER, -1)
|
|
151
|
+
sock.close()
|
|
152
|
+
sock.close(linger=0)
|
|
153
|
+
except Exception as e:
|
|
154
|
+
logger.warning(
|
|
155
|
+
f"[{self.client_id}]: Error closing socket to {target_role} {server_info.id}: {e}"
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
context.term()
|
|
159
|
+
|
|
160
|
+
return wrapper
|
|
161
|
+
|
|
162
|
+
return decorator
|
|
163
|
+
|
|
164
|
+
@dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
|
|
165
|
+
async def async_get_meta(
|
|
166
|
+
self,
|
|
167
|
+
data_fields: list[str],
|
|
168
|
+
batch_size: int,
|
|
169
|
+
global_step: int,
|
|
170
|
+
mode: str = "fetch",
|
|
171
|
+
get_n_samples: bool = False,
|
|
172
|
+
task_name: Optional[str] = None,
|
|
173
|
+
target_controller: Optional[str] = None,
|
|
174
|
+
socket: Optional[zmq.asyncio.Socket] = None,
|
|
175
|
+
) -> BatchMeta:
|
|
176
|
+
"""Asynchronously fetches data metadata via ZMQ from the target controller.
|
|
177
|
+
|
|
178
|
+
Args:
|
|
179
|
+
data_fields (list[str]): List of fields to retrieve metadata for
|
|
180
|
+
batch_size (int): Processing batch size
|
|
181
|
+
global_step (int): Current training/processing step
|
|
182
|
+
mode (str): Data fetch mode. 'fetch' to get ready data, 'force_fetch' to get data regardless of readiness.
|
|
183
|
+
'insert' IS AN INTERNAL USAGE THAT SHOULD NOT BE USED BY USERS.
|
|
184
|
+
get_n_samples (bool): If True, we arrange the samples of the same prompt in contiguous order. In 'fetch'
|
|
185
|
+
mode, only the samples of the same prompt that are all ready will be returned.
|
|
186
|
+
task_name (str): Optional task name associated with the request
|
|
187
|
+
target_controller (str): ID of the target controller to send the request to
|
|
188
|
+
socket (zmq.asyncio.Socket): ZMQ async socket for message transmission
|
|
189
|
+
|
|
190
|
+
Example:
|
|
191
|
+
>>> batch_size = 4
|
|
192
|
+
>>> current_step = 0
|
|
193
|
+
>>> # Example 1: "fetch" a batch of metadata that has been produced
|
|
194
|
+
>>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"],
|
|
195
|
+
>>> batch_size=batch_size,
|
|
196
|
+
>>> global_step=current_step,
|
|
197
|
+
>>> mode="fetch",
|
|
198
|
+
>>> get_n_samples=False,
|
|
199
|
+
>>> task_name="generate_sequences",
|
|
200
|
+
>>> ))
|
|
201
|
+
>>> print(batch_meta.is_ready) # you should get a batch_meta with is_ready=True
|
|
202
|
+
>>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, True, True, True]
|
|
203
|
+
>>>
|
|
204
|
+
>>> # Example 2: "force_fetch" a batch of metadata, ignoring their production status (but we still make
|
|
205
|
+
>>> # sure the corresponding data has not been consumed)
|
|
206
|
+
>>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"],
|
|
207
|
+
>>> batch_size=batch_size,
|
|
208
|
+
>>> global_step=current_step,
|
|
209
|
+
>>> mode="force_fetch",
|
|
210
|
+
>>> get_n_samples=False,
|
|
211
|
+
>>> task_name="generate_sequences",
|
|
212
|
+
>>> ))
|
|
213
|
+
>>> print(batch_meta.is_ready) # you may get a batch_meta with is_ready=False
|
|
214
|
+
>>> print([sample_meta.is_ready for sample_meta in batch_meta.samples]) # [True, False, False, True]
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
BatchMeta: Metadata object containing data structure, sample info, etc.
|
|
218
|
+
"""
|
|
219
|
+
assert socket is not None
|
|
220
|
+
request_msg = ZMQMessage.create(
|
|
221
|
+
request_type=ZMQRequestType.GET_META,
|
|
222
|
+
sender_id=self.client_id,
|
|
223
|
+
receiver_id=target_controller,
|
|
224
|
+
body={
|
|
225
|
+
"data_fields": data_fields,
|
|
226
|
+
"batch_size": batch_size,
|
|
227
|
+
"global_step": global_step,
|
|
228
|
+
"mode": mode,
|
|
229
|
+
"get_n_samples": get_n_samples,
|
|
230
|
+
"task_name": task_name,
|
|
231
|
+
},
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
try:
|
|
235
|
+
await socket.send(request_msg.serialize())
|
|
236
|
+
response = await socket.recv()
|
|
237
|
+
response_msg = ZMQMessage.deserialize(response)
|
|
238
|
+
logger.debug(
|
|
239
|
+
f"[{self.client_id}]: Client get datameta response: {response_msg} from controller {target_controller}"
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
if response_msg.request_type == ZMQRequestType.GET_META_RESPONSE:
|
|
243
|
+
metadata = response_msg.body["metadata"]
|
|
244
|
+
return metadata
|
|
245
|
+
else:
|
|
246
|
+
raise RuntimeError(
|
|
247
|
+
f"[{self.client_id}]: Failed to get metadata from controller {target_controller}: "
|
|
248
|
+
f"{response_msg.body.get('message', 'Unknown error')}"
|
|
249
|
+
)
|
|
250
|
+
except Exception as e:
|
|
251
|
+
raise RuntimeError(f"[{self.client_id}]: Error in get_meta: {str(e)}") from e
|
|
252
|
+
|
|
253
|
+
async def async_put(
|
|
254
|
+
self,
|
|
255
|
+
data: TensorDict,
|
|
256
|
+
metadata: Optional[BatchMeta] = None,
|
|
257
|
+
global_step: Optional[int] = None,
|
|
258
|
+
):
|
|
259
|
+
"""Asynchronously writes data to appropriate Storage Units based on metadata.
|
|
260
|
+
|
|
261
|
+
If metadata isn't provided, it will be created automatically using the insert mode
|
|
262
|
+
with the provided data_columns and global_step.
|
|
263
|
+
|
|
264
|
+
Args:
|
|
265
|
+
data (torch.Tensor | tensordict.TensorDict): Data to write, either a Tensor or TensorDict
|
|
266
|
+
metadata (BatchMeta, optional): Optional metadata containing index and storage unit information
|
|
267
|
+
global_step (int, optional): Current step (required if no metadata is provided)
|
|
268
|
+
|
|
269
|
+
Example:
|
|
270
|
+
>>> batch_size = 4
|
|
271
|
+
>>> seq_len = 16
|
|
272
|
+
>>> current_step = 0
|
|
273
|
+
>>> # Example 1: normal usage
|
|
274
|
+
>>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"],
|
|
275
|
+
>>> batch_size=batch_size,
|
|
276
|
+
>>> global_step=current_step,
|
|
277
|
+
>>> mode="fetch",
|
|
278
|
+
>>> get_n_samples=False,
|
|
279
|
+
>>> task_name="generate_sequences",
|
|
280
|
+
>>> ))
|
|
281
|
+
>>> batch = asyncio.run(client.async_get_data(batch_meta))
|
|
282
|
+
>>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
|
|
283
|
+
>>> asyncio.run(client.async_put(data=output, metadata=batch_meta))
|
|
284
|
+
>>>
|
|
285
|
+
>>> # Example 2: put the initial data into the system without pre-existing metadata
|
|
286
|
+
>>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given global_step!
|
|
287
|
+
>>> # Please make sure the corresponding global_step is empty before calling the async_put()
|
|
288
|
+
>>> # without metadata.
|
|
289
|
+
>>> # Now we only support put all the data of the corresponding global step in once. You should repeat with
|
|
290
|
+
>>> # interleave the initial data if n_sample > 1 before calling the async_put().
|
|
291
|
+
>>> original_prompts = torch.randn(batch_size, seq_len)
|
|
292
|
+
>>> n_samples = 4
|
|
293
|
+
>>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
|
|
294
|
+
>>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
|
|
295
|
+
>>> # This will create metadata in "insert" mode internally.
|
|
296
|
+
>>> asyncio.run(client.async_put(data=prompts_repeated_batch, global_step=current_step))
|
|
297
|
+
|
|
298
|
+
"""
|
|
299
|
+
if metadata is None:
|
|
300
|
+
assert global_step is not None, "global_steps must be provided if metadata is not given"
|
|
301
|
+
|
|
302
|
+
metadata = await self.async_get_meta(
|
|
303
|
+
data_fields=list(data.keys()),
|
|
304
|
+
batch_size=data.batch_size[0],
|
|
305
|
+
global_step=global_step,
|
|
306
|
+
get_n_samples=True,
|
|
307
|
+
mode="insert",
|
|
308
|
+
)
|
|
309
|
+
|
|
310
|
+
if not metadata or metadata.size == 0:
|
|
311
|
+
raise ValueError("metadata cannot be none or empty")
|
|
312
|
+
logger.debug(f"[{self.client_id}]: Put data with data: {data}")
|
|
313
|
+
tasks = [
|
|
314
|
+
self._put_to_storage(get_transfer_info(meta_group, data), target_storage=storage_id)
|
|
315
|
+
for storage_id, meta_group in metadata.storage_meta_groups.items()
|
|
316
|
+
]
|
|
317
|
+
await asyncio.gather(*tasks)
|
|
318
|
+
|
|
319
|
+
logger.info(
|
|
320
|
+
f"[{self.client_id}]: step {global_step} put {metadata.size} samples to storage units successfully."
|
|
321
|
+
)
|
|
322
|
+
|
|
323
|
+
@dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket")
|
|
324
|
+
async def _put_to_storage(self, storage_unit_data, target_storage=None, socket=None):
|
|
325
|
+
"""
|
|
326
|
+
Send data to a specific storage unit.
|
|
327
|
+
"""
|
|
328
|
+
global_indexes = storage_unit_data["global_indexes"]
|
|
329
|
+
local_indexes = storage_unit_data["local_indexes"]
|
|
330
|
+
field_data = TensorDict(
|
|
331
|
+
{
|
|
332
|
+
field: (
|
|
333
|
+
torch.nested.as_nested_tensor(storage_unit_data["field_data"][field])
|
|
334
|
+
if storage_unit_data["field_data"][field]
|
|
335
|
+
and all(isinstance(x, torch.Tensor) for x in storage_unit_data["field_data"][field])
|
|
336
|
+
else NonTensorStack(*storage_unit_data["field_data"][field])
|
|
337
|
+
)
|
|
338
|
+
for field in storage_unit_data["field_data"]
|
|
339
|
+
}
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
request_msg = ZMQMessage.create(
|
|
343
|
+
request_type=ZMQRequestType.PUT_DATA,
|
|
344
|
+
sender_id=self.client_id,
|
|
345
|
+
receiver_id=target_storage,
|
|
346
|
+
body={"global_indexes": global_indexes, "local_indexes": local_indexes, "field_data": field_data},
|
|
347
|
+
)
|
|
348
|
+
try:
|
|
349
|
+
await socket.send(request_msg.serialize())
|
|
350
|
+
serialized = await socket.recv()
|
|
351
|
+
response_msg = ZMQMessage.deserialize(serialized)
|
|
352
|
+
|
|
353
|
+
if response_msg.request_type != ZMQRequestType.PUT_DATA_RESPONSE:
|
|
354
|
+
raise RuntimeError(
|
|
355
|
+
f"Failed to put data to storage unit {target_storage}: "
|
|
356
|
+
f"{response_msg.body.get('message', 'Unknown error')}"
|
|
357
|
+
)
|
|
358
|
+
except Exception as e:
|
|
359
|
+
raise RuntimeError(f"Error in put to storage unit {target_storage}: {str(e)}") from e
|
|
360
|
+
|
|
361
|
+
@dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket")
|
|
362
|
+
async def _get_from_storage(self, index_data, target_storage=None, socket=None):
|
|
363
|
+
global_indexes = index_data["global_indexes"]
|
|
364
|
+
local_indexes = index_data["local_indexes"]
|
|
365
|
+
fields = index_data["fields"]
|
|
366
|
+
|
|
367
|
+
request_msg = ZMQMessage.create(
|
|
368
|
+
request_type=ZMQRequestType.GET_DATA,
|
|
369
|
+
sender_id=self.client_id,
|
|
370
|
+
receiver_id=target_storage,
|
|
371
|
+
body={"local_indexes": local_indexes, "fields": fields},
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
try:
|
|
375
|
+
await socket.send(request_msg.serialize())
|
|
376
|
+
serialized = await socket.recv()
|
|
377
|
+
response_msg = ZMQMessage.deserialize(serialized)
|
|
378
|
+
logger.info(f"[{self.client_id}]: get data response from storage unit {target_storage}: {response_msg}")
|
|
379
|
+
|
|
380
|
+
if response_msg.request_type == ZMQRequestType.GET_DATA_RESPONSE:
|
|
381
|
+
# Return data and index information from this storage unit
|
|
382
|
+
storage_unit_data = response_msg.body["data"]
|
|
383
|
+
return global_indexes, fields, storage_unit_data
|
|
384
|
+
else:
|
|
385
|
+
raise RuntimeError(
|
|
386
|
+
f"Failed to get data from storage unit {target_storage}: "
|
|
387
|
+
f"{response_msg.body.get('message', 'Unknown error')}"
|
|
388
|
+
)
|
|
389
|
+
except Exception as e:
|
|
390
|
+
raise RuntimeError(f"Error getting data from storage unit {target_storage}: {str(e)}") from e
|
|
391
|
+
|
|
392
|
+
    async def async_get_data(self, metadata: BatchMeta) -> TensorDict:
        """Asynchronously fetches data via Storage Units and organizes it into a TensorDict.

        Args:
            metadata (BatchMeta): Object containing:
                - Data location info (which Storage Units hold the data)
                - `global_indexes` to determine the ordering of merged results

        Returns:
            tensordict.TensorDict with:
                - Requested data fields (e.g., "prompt_token_ids", "response_token_ids").
                - "global_indexes" key: Maps each sample to its original global index.

        Note:
            Why track `global_indexes`?
            - Batches may be rearranged during task processing. `global_indexes` retains the original
              mapping to Storage Units, enabling correct data writing back to Storage Units later.
        """
        # Empty/absent metadata yields an empty batch rather than an error.
        if not metadata or metadata.size == 0:
            return TensorDict({}, batch_size=0)

        # Use optimized retrieval with direct storage group access:
        # one concurrent fetch per storage unit that holds part of the batch.
        tasks = [
            self._get_from_storage(meta_group.get_transfer_info(), target_storage=storage_id)
            for storage_id, meta_group in metadata.storage_meta_groups.items()
        ]

        results = await asyncio.gather(*tasks)

        # Re-key the per-storage-unit payloads by sample:
        # global_index: {field1: value, field2: value, ...}
        storage_data: dict[int, dict[str, torch.Tensor]] = {}
        for global_indexes, fields, storage_unit_data in results:
            extracted_data = {field: storage_unit_data[field] for field in fields}

            for idx, global_idx in enumerate(global_indexes):
                if global_idx not in storage_data:
                    storage_data[global_idx] = {}
                for field in fields:
                    storage_data[global_idx][field] = extracted_data[field][idx]

        # Re-order samples to match metadata.global_indexes exactly.
        ordered_data: dict[str, torch.Tensor] = {field: [] for field in metadata.field_names}
        for global_idx in metadata.global_indexes:
            for field in metadata.field_names:
                ordered_data[field].append(storage_data[global_idx][field])

        # Per-field conversion, three cases:
        #   1. all tensors with identical shapes -> regular stacked tensor
        #   2. all tensors, ragged shapes        -> nested tensor
        #   3. anything else (or empty list)     -> NonTensorStack
        tensor_data = {
            field: (
                torch.stack(torch.nested.as_nested_tensor(v).unbind())
                if v
                and all(isinstance(item, torch.Tensor) for item in v)
                and all(item.shape == v[0].shape for item in v)
                else (
                    torch.nested.as_nested_tensor(v)
                    if v and all(isinstance(item, torch.Tensor) for item in v)
                    else NonTensorStack(*v)
                )
            )
            for field, v in ordered_data.items()
        }
        tensor_data["global_indexes"] = torch.tensor(metadata.global_indexes)

        # NOTE(review): batch_size is the number of *unique* global indexes; if
        # metadata.global_indexes ever contained duplicates this would disagree
        # with the stacked field length — confirm indexes are unique upstream.
        return TensorDict(tensor_data, batch_size=len(storage_data))
|
|
471
|
+
|
|
472
|
+
async def async_clear(self, global_step: int):
|
|
473
|
+
"""Asynchronously clears data from all storage units and controller metadata.
|
|
474
|
+
|
|
475
|
+
Args:
|
|
476
|
+
global_step (int): The training step associated with the clear operation
|
|
477
|
+
|
|
478
|
+
"""
|
|
479
|
+
try:
|
|
480
|
+
target_controller = next(iter(self._controllers.keys()))
|
|
481
|
+
metadata = await self._get_clear_meta(global_step, target_controller)
|
|
482
|
+
|
|
483
|
+
tasks = []
|
|
484
|
+
|
|
485
|
+
for target_controller in self._controllers.keys():
|
|
486
|
+
tasks.append(self._clear_controller(global_step, target_controller))
|
|
487
|
+
|
|
488
|
+
# Group samples by storage unit for clearing
|
|
489
|
+
for target_storage, group in metadata.storage_meta_groups.items():
|
|
490
|
+
group_info = group.get_transfer_info()
|
|
491
|
+
if target_storage not in self._storages:
|
|
492
|
+
logger.warning(
|
|
493
|
+
f"[{self.client_id}]: Storage unit {target_storage} not registered, skipping clear operation."
|
|
494
|
+
)
|
|
495
|
+
continue
|
|
496
|
+
tasks.append(
|
|
497
|
+
self._clear_storage_unit(
|
|
498
|
+
group_info["local_indexes"],
|
|
499
|
+
target_storage,
|
|
500
|
+
)
|
|
501
|
+
)
|
|
502
|
+
|
|
503
|
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
504
|
+
|
|
505
|
+
for i, result in enumerate(results):
|
|
506
|
+
if isinstance(result, Exception):
|
|
507
|
+
logger.error(f"[{self.client_id}]: Error in clear operation task {i}: {result}")
|
|
508
|
+
|
|
509
|
+
logger.info(f"[{self.client_id}]: Clear operation for global_step {global_step} completed.")
|
|
510
|
+
except Exception as e:
|
|
511
|
+
raise RuntimeError(f"Error in clear operation: {str(e)}") from e
|
|
512
|
+
|
|
513
|
+
@dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
|
|
514
|
+
async def _get_clear_meta(self, global_step: int, target_controller=None, socket=None):
|
|
515
|
+
request_msg = ZMQMessage.create(
|
|
516
|
+
request_type=ZMQRequestType.GET_CLEAR_META,
|
|
517
|
+
sender_id=self.client_id,
|
|
518
|
+
receiver_id=target_controller,
|
|
519
|
+
body={"global_step": global_step},
|
|
520
|
+
)
|
|
521
|
+
|
|
522
|
+
await socket.send(request_msg.serialize())
|
|
523
|
+
serialized = await socket.recv()
|
|
524
|
+
response_msg = ZMQMessage.deserialize(serialized)
|
|
525
|
+
|
|
526
|
+
if response_msg.request_type != ZMQRequestType.GET_CLEAR_META_RESPONSE:
|
|
527
|
+
raise RuntimeError(
|
|
528
|
+
f"Failed to get metadata for clear operation: {response_msg.body.get('message', 'Unknown error')}"
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
return response_msg.body["metadata"]
|
|
532
|
+
|
|
533
|
+
@dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
|
|
534
|
+
async def _clear_controller(self, global_step, target_controller=None, socket=None):
|
|
535
|
+
try:
|
|
536
|
+
request_msg = ZMQMessage.create(
|
|
537
|
+
request_type=ZMQRequestType.CLEAR_META,
|
|
538
|
+
sender_id=self.client_id,
|
|
539
|
+
receiver_id=target_controller,
|
|
540
|
+
body={"global_step": global_step},
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
await socket.send(request_msg.serialize())
|
|
544
|
+
serialized_msg = await socket.recv()
|
|
545
|
+
response_msg = ZMQMessage.deserialize(serialized_msg)
|
|
546
|
+
|
|
547
|
+
if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE:
|
|
548
|
+
raise RuntimeError(
|
|
549
|
+
f"Failed to clear controller {target_controller}: "
|
|
550
|
+
f"{response_msg.body.get('message', 'Unknown error')}"
|
|
551
|
+
)
|
|
552
|
+
|
|
553
|
+
logger.info(
|
|
554
|
+
f"[{self.client_id}]: Successfully clear controller {target_controller} for global_step {global_step}"
|
|
555
|
+
)
|
|
556
|
+
except Exception as e:
|
|
557
|
+
logger.error(f"[{self.client_id}]: Error clearing controller {target_controller}: {str(e)}")
|
|
558
|
+
raise
|
|
559
|
+
|
|
560
|
+
@dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket")
|
|
561
|
+
async def _clear_storage_unit(self, local_indexes, target_storage=None, socket=None):
|
|
562
|
+
try:
|
|
563
|
+
request_msg = ZMQMessage.create(
|
|
564
|
+
request_type=ZMQRequestType.CLEAR_DATA,
|
|
565
|
+
sender_id=self.client_id,
|
|
566
|
+
receiver_id=target_storage,
|
|
567
|
+
body={"local_indexes": local_indexes},
|
|
568
|
+
)
|
|
569
|
+
|
|
570
|
+
await socket.send(request_msg.serialize())
|
|
571
|
+
serialized_msg = await socket.recv()
|
|
572
|
+
response_msg = ZMQMessage.deserialize(serialized_msg)
|
|
573
|
+
|
|
574
|
+
if response_msg.request_type != ZMQRequestType.CLEAR_DATA_RESPONSE:
|
|
575
|
+
raise RuntimeError(
|
|
576
|
+
f"Failed to clear storage {target_storage}: {response_msg.body.get('message', 'Unknown error')}"
|
|
577
|
+
)
|
|
578
|
+
|
|
579
|
+
logger.info(f"[{self.client_id}]: Successfully clear storage unit {target_storage}")
|
|
580
|
+
except Exception as e:
|
|
581
|
+
logger.error(f"[{self.client_id}]: Error clearing storage unit {target_storage}: {str(e)}")
|
|
582
|
+
raise
|
|
583
|
+
|
|
584
|
+
@dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
|
|
585
|
+
def check_current_step_consumption(self, task_name: str, global_step: int):
|
|
586
|
+
# TODO: Implement this method to check if all samples for the current step has been consumed
|
|
587
|
+
pass
|
|
588
|
+
|
|
589
|
+
@dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
|
|
590
|
+
def check_current_step_production(self, data_fields: list[str], global_step: int):
|
|
591
|
+
# TODO: Implement this method to check if all samples for the current step is ready for consumption
|
|
592
|
+
pass
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
class TransferQueueClient(AsyncTransferQueueClient):
    """Blocking facade over :class:`AsyncTransferQueueClient`.

    Each method simply runs its async counterpart to completion with
    ``asyncio.run``, so these calls must not be made from inside a running
    event loop.
    """

    def __init__(
        self,
        client_id: str,
        controller_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
        storage_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
    ):
        super().__init__(client_id, controller_infos, storage_infos)

    def put(self, data: TensorDict, metadata: Optional[BatchMeta] = None, global_step: Optional[int] = None):
        """Blocking wrapper around :meth:`async_put`."""
        return asyncio.run(self.async_put(data, metadata, global_step))

    def get_meta(
        self,
        data_fields: list[str],
        batch_size: int,
        global_step: int,
        get_n_samples: bool = False,
        task_name: Optional[str] = None,
    ) -> BatchMeta:
        """Blocking wrapper around :meth:`async_get_meta`."""
        coro = self.async_get_meta(
            data_fields=data_fields,
            batch_size=batch_size,
            global_step=global_step,
            get_n_samples=get_n_samples,
            task_name=task_name,
        )
        return asyncio.run(coro)

    def get_data(self, metadata: BatchMeta) -> TensorDict:
        """Blocking wrapper around :meth:`async_get_data`."""
        return asyncio.run(self.async_get_data(metadata))

    def clear(self, global_step: int):
        """Blocking wrapper around :meth:`async_clear`."""
        return asyncio.run(self.async_clear(global_step))
|
|
634
|
+
|
|
635
|
+
|
|
636
|
+
def _add_field_data(
    transfer_dict: dict[str, Any], storage_meta_group: StorageMetaGroup, data: TensorDict
) -> dict[str, Any]:
    """Fill ``transfer_dict["field_data"]`` with per-sample values taken from *data*.

    Only fields listed in ``transfer_dict["fields"]`` that are actually present
    in *data* are copied; samples are selected by each sample meta's
    ``batch_index``. Returns the (mutated) *transfer_dict* for chaining.
    """
    batch_indexes = [sample_meta.batch_index for sample_meta in storage_meta_group.sample_metas]
    for fname in transfer_dict["fields"]:
        if fname in data.keys():
            column = data[fname]
            transfer_dict["field_data"][fname] = [column[i] for i in batch_indexes]
    return transfer_dict
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def get_transfer_info(
    storage_meta_group: StorageMetaGroup,
    data: TensorDict,
) -> dict[str, Any]:
    """Build a storage unit's transfer dict (indexes + field data) for a put operation."""
    transfer = storage_meta_group.get_transfer_info(field_names=data.keys())
    return _add_field_data(transfer, storage_meta_group, data)
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
def process_zmq_server_info(handlers: dict[Any, Union[TransferQueueController, TransferQueueStorageSimpleUnit]]):  # noqa: UP007
    """Fetch ZMQ server info from each Ray actor handle, keyed by handler name.

    NOTE: each ``ray.get`` blocks until that actor responds.
    """
    return {
        name: ray.get(handler.get_zmq_server_info.remote())  # type: ignore[attr-defined]
        for name, handler in handlers.items()
    }
|