TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
transfer_queue/client.py
ADDED
@@ -0,0 +1,611 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import os
from functools import wraps
from typing import Any, Callable, Optional, Union
from uuid import uuid4

import ray
import zmq
import zmq.asyncio
from tensordict import TensorDict

from transfer_queue.controller import TransferQueueController
from transfer_queue.metadata import (
    BatchMeta,
)
from transfer_queue.storage import (
    SimpleStorageUnit,
    TransferQueueStorageManager,
    TransferQueueStorageManagerFactory,
)
from transfer_queue.utils.zmq_utils import (
    ZMQMessage,
    ZMQRequestType,
    ZMQServerInfo,
    create_zmq_socket,
)

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))


class AsyncTransferQueueClient:
    """Asynchronous client for interacting with the TransferQueue controller and storage systems.

    This client provides async methods for data transfer operations, including fetching metadata,
    reading data from storage, writing data to storage, and clearing data.
    """

    def __init__(
        self,
        client_id: str,
        controller_info: ZMQServerInfo,
    ):
        """Initialize the asynchronous TransferQueue client.

        Args:
            client_id: Unique identifier for this client instance
            controller_info: Single controller ZMQ server information
        """
        if controller_info is None:
            raise ValueError("controller_info cannot be None")
        if not isinstance(controller_info, ZMQServerInfo):
            raise TypeError(f"controller_info must be ZMQServerInfo, got {type(controller_info)}")
        self.client_id = client_id
        self._controller: ZMQServerInfo = controller_info
        logger.info(f"[{self.client_id}]: Registered Controller server {controller_info.id} at {controller_info.ip}")

    def initialize_storage_manager(
        self,
        manager_type: str,
        config: dict[str, Any],
    ):
        """Initialize the storage manager.

        Args:
            manager_type: Type of storage manager to create. Supported types include
                AsyncSimpleStorageManager, KVStorageManager (under development), etc.
            config: Configuration dictionary for the storage manager (see the sketch after
                this method). For AsyncSimpleStorageManager, it must contain the following
                required keys:
                - controller_info: ZMQ server information about the controller
                - storage_unit_infos: ZMQ server information about the storage units
        """
        self.storage_manager = TransferQueueStorageManagerFactory.create(manager_type, config)
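
    # A minimal sketch of the call shape described above (an assumption for illustration:
    # `controller_info` and `storage_unit_infos` are taken to come from
    # process_zmq_server_info(), defined at the bottom of this module):
    #
    #   client.initialize_storage_manager(
    #       manager_type="AsyncSimpleStorageManager",
    #       config={
    #           "controller_info": controller_info,        # ZMQServerInfo
    #           "storage_unit_infos": storage_unit_infos,  # dict of ZMQServerInfo
    #       },
    #   )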

    # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
    @staticmethod
    def dynamic_socket(socket_name: str):
        """Decorator to auto-manage ZMQ sockets for Controller/Storage servers.

        Handles the socket lifecycle: create -> connect -> inject -> close.

        Args:
            socket_name: Port name from the server config to use for the ZMQ connection (e.g., "data_req_port")

        Decorated Function Requirements (see the sketch after this decorator):
            1. Must be an async class method (needs `self`)
            2. `self` must have:
               - `_controller`: Server registry
               - `client_id`: Unique client ID for socket identity
            3. Receives the ZMQ socket via the `socket` keyword argument (injected by the decorator)
        """

        def decorator(func: Callable):
            @wraps(func)
            async def wrapper(self, *args, **kwargs):
                server_info = self._controller
                if not server_info:
                    raise RuntimeError("No controller registered")

                context = zmq.asyncio.Context()
                address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}"
                identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
                sock = create_zmq_socket(context, zmq.DEALER, identity=identity)

                try:
                    sock.connect(address)
                    logger.info(
                        f"[{self.client_id}]: Connected to Controller {server_info.id} at {address} "
                        f"with identity {identity.decode()}"
                    )

                    kwargs["socket"] = sock
                    return await func(self, *args, **kwargs)
                except Exception as e:
                    logger.error(f"[{self.client_id}]: Error in socket operation with Controller {server_info.id}: {e}")
                    raise
                finally:
                    try:
                        if not sock.closed:
                            sock.close(linger=-1)
                    except Exception as e:
                        logger.warning(f"[{self.client_id}]: Error closing socket to Controller {server_info.id}: {e}")

                    context.term()

            return wrapper

        return decorator
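
    # A minimal sketch of a method conforming to the decorator contract above
    # (`ping_controller` is a hypothetical name and the raw payload is illustrative;
    # real methods exchange serialized ZMQMessage frames, as async_get_meta below does):
    #
    #   @dynamic_socket(socket_name="request_handle_socket")
    #   async def ping_controller(self, socket=None):
    #       await socket.send(b"ping")  # `socket` is injected by the decorator
    #       return await socket.recv()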

    @dynamic_socket(socket_name="request_handle_socket")
    async def async_get_meta(
        self,
        data_fields: list[str],
        batch_size: int,
        partition_id: str,
        mode: str = "fetch",
        task_name: Optional[str] = None,
        sampling_config: Optional[dict[str, Any]] = None,
        socket: Optional[zmq.asyncio.Socket] = None,
    ) -> BatchMeta:
        """Asynchronously fetch data metadata from the controller via ZMQ.

        Args:
            data_fields: List of data field names to retrieve metadata for
            batch_size: Number of samples to request in the batch
            partition_id: Current data partition id
            mode: Data fetch mode. Options:
                - 'fetch': Get ready data only
                - 'force_fetch': Get data regardless of readiness (may return unready samples)
                - 'insert': Internal usage - should not be used by users
            task_name: Optional task name associated with the request
            sampling_config: Optional sampling configuration for custom samplers.
                For GRPOGroupNSampler, should include "n_samples_per_prompt": int
            socket: ZMQ async socket for message transmission (injected by decorator)

        Returns:
            BatchMeta: Metadata object containing data structure, sample information, and readiness status

        Raises:
            RuntimeError: If communication fails or the controller returns an error response

        Example:
            >>> # Example 1: Basic metadata fetch
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     batch_size=4,
            ...     partition_id="train_0",
            ...     mode="fetch",
            ...     task_name="generate_sequences"
            ... ))
            >>> print(batch_meta.is_ready)  # True if all samples ready
            >>>
            >>> # Example 2: Fetch with a custom sampler (using GRPOGroupNSampler as an example)
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     batch_size=8,
            ...     partition_id="train_0",
            ...     mode="fetch",
            ...     task_name="generate_sequences",
            ...     sampling_config={"n_samples_per_prompt": 4}
            ... ))
            >>> print(batch_meta.is_ready)  # True if all samples ready
            >>>
            >>> # Example 3: Force-fetch metadata (bypasses the production status check and Sampler,
            >>> # so it may include unready samples. Consumed samples will not be fetched.)
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     data_fields=["input_ids", "attention_mask"],
            ...     batch_size=4,
            ...     partition_id="train_0",
            ...     mode="force_fetch",
            ...     task_name="generate_sequences"
            ... ))
            >>> print(batch_meta.is_ready)  # May be False if some samples are not ready
        """
        assert socket is not None
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_META,
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={
                "data_fields": data_fields,
                "batch_size": batch_size,
                "partition_id": partition_id,
                "mode": mode,
                "task_name": task_name,
                "sampling_config": sampling_config,
            },
        )

        try:
            await socket.send(request_msg.serialize())
            response = await socket.recv()
            response_msg = ZMQMessage.deserialize(response)
            logger.debug(
                f"[{self.client_id}]: Client get datameta response: {response_msg} "
                f"from controller {self._controller.id}"
            )

            if response_msg.request_type == ZMQRequestType.GET_META_RESPONSE:
                metadata = response_msg.body["metadata"]
                return metadata
            else:
                raise RuntimeError(
                    f"[{self.client_id}]: Failed to get metadata from controller {self._controller.id}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"[{self.client_id}]: Error in get_meta: {str(e)}") from e

    async def async_put(
        self,
        data: TensorDict,
        metadata: Optional[BatchMeta] = None,
        partition_id: Optional[str] = None,
    ):
        """Asynchronously write data to storage units based on metadata.

        If metadata is not provided, it will be created automatically using insert mode
        with the provided data fields and partition_id.

        Note:
            When using multiple workers for distributed execution, there may be data
            ordering inconsistencies between workers during put operations.

        Args:
            data: Data to write as a TensorDict
            metadata: Records the metadata of a batch of data samples, containing index and
                storage unit information. If None, metadata will be auto-generated.
            partition_id: Target data partition id (required if metadata is not provided)

        Raises:
            ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
            RuntimeError: If the storage operation fails

        Example:
            >>> batch_size = 4
            >>> seq_len = 16
            >>> current_partition_id = "train_0"
            >>> # Example 1: Normal usage with existing metadata
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     data_fields=["prompts", "attention_mask"],
            ...     batch_size=batch_size,
            ...     partition_id=current_partition_id,
            ...     mode="fetch",
            ...     task_name="generate_sequences",
            ... ))
            >>> batch = asyncio.run(client.async_get_data(batch_meta))
            >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
            >>> asyncio.run(client.async_put(data=output, metadata=batch_meta))
            >>>
            >>> # Example 2: Initial data insertion without pre-existing metadata
            >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
            >>> # Make sure the corresponding partition_id is empty before calling async_put()
            >>> # without metadata.
            >>> # Currently, all data for a given partition id must be put in a single call; if
            >>> # n_samples > 1, repeat-interleave the initial data before calling async_put().
            >>> original_prompts = torch.randn(batch_size, seq_len)
            >>> n_samples = 4
            >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
            >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
            >>> # This will create metadata in "insert" mode internally.
            >>> asyncio.run(client.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))
        """
        if not hasattr(self, "storage_manager") or self.storage_manager is None:
            raise RuntimeError(
                f"[{self.client_id}]: Storage manager not initialized. "
                "Call initialize_storage_manager() before performing storage operations."
            )

        if metadata is None:
            if partition_id is None:
                raise ValueError("partition_id must be provided if metadata is not given")

            metadata = await self.async_get_meta(
                data_fields=list(data.keys()),
                batch_size=data.batch_size[0],
                partition_id=partition_id,
                mode="insert",
            )

        if not metadata or metadata.size == 0:
            raise ValueError("metadata cannot be None or empty")
        logger.debug(f"[{self.client_id}]: Put data: {data}")

        await self.storage_manager.put_data(data, metadata)

        logger.info(
            f"[{self.client_id}]: partition {partition_id} put {metadata.size} samples to storage units successfully."
        )

    async def async_get_data(self, metadata: BatchMeta) -> TensorDict:
        """Asynchronously fetch data from storage units and organize it into a TensorDict.

        Args:
            metadata: Batch metadata containing data location information and global indexes

        Returns:
            TensorDict containing:
                - Requested data fields (e.g., "prompts", "attention_mask")

        Example:
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     data_fields=["prompts", "attention_mask"],
            ...     batch_size=4,
            ...     partition_id="train_0",
            ...     mode="fetch",
            ...     task_name="generate_sequences",
            ... ))
            >>> batch = asyncio.run(client.async_get_data(batch_meta))
            >>> print(batch)
            >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
        """
        if not hasattr(self, "storage_manager") or self.storage_manager is None:
            raise RuntimeError(
                f"[{self.client_id}]: Storage manager not initialized. "
                "Call initialize_storage_manager() before performing storage operations."
            )

        if not metadata or metadata.size == 0:
            return TensorDict({}, batch_size=0)

        results = await self.storage_manager.get_data(metadata)

        return results

    async def async_clear(self, partition_id: str):
        """Asynchronously clear data from all storage units and controller metadata.

        Args:
            partition_id: The partition id to clear data for

        Raises:
            RuntimeError: If the clear operation fails
        """
        try:
            if not hasattr(self, "storage_manager") or self.storage_manager is None:
                raise RuntimeError(
                    f"[{self.client_id}]: Storage manager not initialized. "
                    "Call initialize_storage_manager() before performing storage operations."
                )

            if not self._controller:
                raise RuntimeError("No controller registered")

            metadata = await self._get_clear_meta(partition_id)

            # Clear the controller metadata
            await self._clear_controller(partition_id)

            # Clear storage unit data
            await self.storage_manager.clear_data(metadata)

            logger.info(f"[{self.client_id}]: Clear operation for partition_id {partition_id} completed.")
        except Exception as e:
            raise RuntimeError(f"Error in clear operation: {str(e)}") from e

    @dynamic_socket(socket_name="request_handle_socket")
    async def _get_clear_meta(self, partition_id: str, socket=None) -> BatchMeta:
        """Get the metadata required for a clear operation from the controller.

        Args:
            partition_id: Partition id to get clear metadata for
            socket: ZMQ socket (injected by decorator)

        Returns:
            BatchMeta: Records the metadata of a batch of data samples.

        Raises:
            RuntimeError: If the controller returns an error response
        """
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_CLEAR_META,
            sender_id=self.client_id,
            receiver_id=self._controller.id,
            body={"partition_id": partition_id},
        )

        await socket.send(request_msg.serialize())
        serialized = await socket.recv()
        response_msg = ZMQMessage.deserialize(serialized)

        if response_msg.request_type != ZMQRequestType.GET_CLEAR_META_RESPONSE:
            raise RuntimeError(
                f"Failed to get metadata for clear operation: {response_msg.body.get('message', 'Unknown error')}"
            )

        return response_msg.body["metadata"]

    @dynamic_socket(socket_name="request_handle_socket")
    async def _clear_controller(self, partition_id, socket=None):
        """Clear metadata from the controller.

        Args:
            partition_id: Partition id to clear metadata for
            socket: ZMQ socket (injected by decorator)

        Raises:
            RuntimeError: If the clear operation fails
        """
        try:
            request_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_META,
                sender_id=self.client_id,
                receiver_id=self._controller.id,
                body={"partition_id": partition_id},
            )

            await socket.send(request_msg.serialize())
            serialized_msg = await socket.recv()
            response_msg = ZMQMessage.deserialize(serialized_msg)

            if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE:
                raise RuntimeError(
                    f"Failed to clear controller {self._controller.id}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )

            logger.info(
                f"[{self.client_id}]: Successfully cleared controller {self._controller.id} for partition_id "
                f"{partition_id}"
            )
        except Exception as e:
            logger.error(f"[{self.client_id}]: Error clearing controller {self._controller.id}: {str(e)}")
            raise

    @dynamic_socket(socket_name="request_handle_socket")
    async def check_data_consumption_status(self, task_name: str, partition_id: str):
        """Check whether all samples for the current step have been consumed.

        Args:
            task_name: Name of the task to check consumption for
            partition_id: Partition id to check consumption status for
        """
        # TODO: Implement this method to check whether all samples for the current step have been consumed
        pass

    @dynamic_socket(socket_name="request_handle_socket")
    async def check_data_production_status(self, data_fields: list[str], partition_id: str):
        """Check whether all samples for the current partition are ready for consumption.

        Args:
            data_fields: Data fields to check production status for
            partition_id: Partition id to check production status for
        """
        # TODO: Implement this method to check whether all samples for the current step are ready for consumption
        pass

    def close(self) -> None:
        """Close the client and clean up resources, including the storage manager."""
        try:
            if hasattr(self, "storage_manager") and self.storage_manager:
                if hasattr(self.storage_manager, "close"):
                    self.storage_manager.close()
        except Exception as e:
            logger.warning(f"Error closing storage manager: {e}")


class TransferQueueClient(AsyncTransferQueueClient):
    """Synchronous client wrapper for TransferQueue.

    Provides synchronous versions of all async methods for convenience.
    """

    def __init__(
        self,
        client_id: str,
        controller_info: ZMQServerInfo,
    ):
        """Initialize the synchronous TransferQueue client.

        Args:
            client_id: Unique identifier for this client instance
            controller_info: Single controller ZMQ server information
        """
        super().__init__(
            client_id,
            controller_info,
        )

    def put(self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None):
        """Synchronously write data to storage units.

        Args:
            data: Data to write as TensorDict
            metadata: Optional metadata containing index and storage unit information
            partition_id: Target data partition id (required if metadata is not provided)
        """
        return asyncio.run(self.async_put(data, metadata, partition_id))

    def get_meta(
        self,
        data_fields: list[str],
        batch_size: int,
        partition_id: str,
        task_name: Optional[str] = None,
        sampling_config: Optional[dict[str, Any]] = None,
    ) -> BatchMeta:
        """Synchronously fetch data metadata from controller.

        Args:
            data_fields: List of data field names to retrieve metadata for
            batch_size: Number of samples to request in the batch
            partition_id: Target data partition id
            task_name: Optional task name associated with the request
            sampling_config: Optional sampling configuration for custom samplers.
                For GRPOGroupNSampler, should include "n_samples_per_prompt": int

        Returns:
            BatchMeta: Batch metadata containing data location information
        """
        return asyncio.run(
            self.async_get_meta(
                data_fields=data_fields,
                batch_size=batch_size,
                partition_id=partition_id,
                task_name=task_name,
                sampling_config=sampling_config,
            )
        )

    def get_data(self, metadata: BatchMeta) -> TensorDict:
        """Synchronously fetch data from storage units.

        Args:
            metadata: Batch metadata containing data location information

        Returns:
            TensorDict containing requested data fields
        """
        return asyncio.run(self.async_get_data(metadata))

    def clear(self, partition_id: str):
        """Synchronously clear data from storage units and controller metadata.

        Args:
            partition_id: The partition id to clear data for
        """
        return asyncio.run(self.async_clear(partition_id))
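
    # Usage sketch for the synchronous wrapper (assumes a constructed client whose
    # storage manager is initialized; field names and partition ids are illustrative):
    #
    #   client.put(data=TensorDict({"prompts": prompts}), partition_id="train_0")
    #   meta = client.get_meta(data_fields=["prompts"], batch_size=4,
    #                          partition_id="train_0", task_name="generate_sequences")
    #   batch = client.get_data(meta)
    #   client.clear("train_0")
    #   client.close()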


def process_zmq_server_info(
    handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]]
    | Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"],
):  # noqa: UP007
    """Extract ZMQ server information from handler objects.

    Args:
        handlers: Dictionary of handler objects (controllers, storage managers, or storage units),
            or a single handler object

    Returns:
        If handlers is a dictionary: Dictionary mapping handler names to their ZMQ server information
        If handlers is a single object: ZMQ server information for that object

    Examples:
        >>> # Single handler
        >>> controller = TransferQueueController.remote(...)
        >>> info = process_zmq_server_info(controller)
        >>>
        >>> # Multiple handlers
        >>> handlers = {"storage_0": storage_0, "storage_1": storage_1}
        >>> info_dict = process_zmq_server_info(handlers)
    """
    # Handle single handler object case
    if not isinstance(handlers, dict):
        return ray.get(handlers.get_zmq_server_info.remote())  # type: ignore[attr-defined]
    else:
        # Handle dictionary case
        server_info = {}
        for name, handler in handlers.items():
            server_info[name] = ray.get(handler.get_zmq_server_info.remote())  # type: ignore[attr-defined]
        return server_info
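
A minimal end-to-end wiring sketch for this module (illustrative, not part of the
package: the Ray actor construction is elided as in the docstring above, and the
"AsyncSimpleStorageManager" config keys follow the initialize_storage_manager()
docstring; `controller` and `storage_units` are assumed pre-created actor handles):

    from transfer_queue.client import TransferQueueClient, process_zmq_server_info

    controller_info = process_zmq_server_info(controller)        # single actor handle
    storage_unit_infos = process_zmq_server_info(storage_units)  # dict of actor handles

    client = TransferQueueClient(client_id="client_0", controller_info=controller_info)
    client.initialize_storage_manager(
        manager_type="AsyncSimpleStorageManager",
        config={"controller_info": controller_info, "storage_unit_infos": storage_unit_infos},
    )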