TransferQueue 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,663 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import logging
17
+ import os
18
+ from functools import wraps
19
+ from typing import Any, Callable, Optional, Union
20
+ from uuid import uuid4
21
+
22
+ import ray
23
+ import torch
24
+ import zmq
25
+ import zmq.asyncio
26
+ from tensordict import NonTensorStack, TensorDict
27
+
28
+ from transfer_queue.controller import TransferQueueController
29
+ from transfer_queue.metadata import (
30
+ BatchMeta,
31
+ StorageMetaGroup,
32
+ )
33
+ from transfer_queue.storage import TransferQueueStorageSimpleUnit
34
+ from transfer_queue.utils.utils import (
35
+ TransferQueueRole,
36
+ )
37
+ from transfer_queue.utils.zmq_utils import (
38
+ ZMQMessage,
39
+ ZMQRequestType,
40
+ ZMQServerInfo,
41
+ create_zmq_socket,
42
+ )
43
+
44
logger = logging.getLogger(__name__)
# Verbosity is overridable via the TQ_LOGGING_LEVEL env var. setLevel accepts
# either an int or an upper-case level-name string (e.g. "DEBUG"), so both the
# env value and the int default work.
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
46
+
47
+
48
class AsyncTransferQueueClient:
    """Asynchronous client for the TransferQueue system.

    Talks to controller servers (metadata: what data exists, what is ready)
    and storage servers (the actual per-sample data) over per-request ZMQ
    DEALER sockets managed by the `dynamic_socket` decorator.
    """

    def __init__(
        self,
        client_id: str,
        controller_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
        storage_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
    ):
        """Register the given controller and storage servers for this client.

        Args:
            client_id: Unique identifier for this client; used to build ZMQ
                socket identities and to tag log messages.
            controller_infos: A single ZMQServerInfo or a mapping of them.
            storage_infos: A single ZMQServerInfo or a mapping of them.
        """
        self.client_id = client_id

        # server-id -> ZMQServerInfo registries, filled by _register_servers.
        self._controllers: dict[str, ZMQServerInfo] = {}
        self._storages: dict[str, ZMQServerInfo] = {}
        self._register_servers(TransferQueueRole.CONTROLLER, controller_infos)
        self._register_servers(TransferQueueRole.STORAGE, storage_infos)
61
+
62
+ def _register_servers(
63
+ self,
64
+ role: TransferQueueRole,
65
+ server_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
66
+ ):
67
+ mapping = self._controllers if role == TransferQueueRole.CONTROLLER else self._storages
68
+
69
+ if not isinstance(server_infos, dict):
70
+ server_infos = {server_infos.id: server_infos}
71
+
72
+ for info in server_infos.values():
73
+ if not isinstance(info, ZMQServerInfo):
74
+ raise ValueError(f"Invalid server info for {role} {info.id}")
75
+
76
+ if info.id not in mapping:
77
+ mapping[info.id] = info
78
+ logger.info(f"[{self.client_id}]: Registered {role} server {info.id} at {info.ip}")
79
+ else:
80
+ logger.warning(f"[{self.client_id}]: Server {info.id} already registered, skipping")
81
+
82
+ @staticmethod
83
+ def dynamic_socket(target_role: TransferQueueRole, socket_name: str):
84
+ """Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close).
85
+
86
+ Args:
87
+ target_role (TransferQueueRole): Server type to connect to. Must be one of:
88
+ - `TransferQueueRole.CONTROLLER`
89
+ - `TransferQueueRole.STORAGE`
90
+ socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port").
91
+
92
+ Decorated Function Rules:
93
+ 1. Must be an async class method (needs `self`).
94
+ 2. `self` requires:
95
+ - `_controllers`/`_storages`: Server registries (match `target_role`).
96
+ - `client_id`: Unique client ID (for socket identity).
97
+ 3. Specify target server via:
98
+ - `target_controller` (for Controller) or `target_storage` (for Storage) arg.
99
+ - Controller role: Uses first registered server if no ID is given.
100
+ 4. Receives ZMQ socket via `socket` keyword arg (injected by decorator).
101
+ """
102
+
103
+ def decorator(func: Callable):
104
+ @wraps(func)
105
+ async def wrapper(self, *args, **kwargs):
106
+ if target_role == TransferQueueRole.CONTROLLER:
107
+ servers = self._controllers
108
+ target = "target_controller"
109
+ elif target_role == TransferQueueRole.STORAGE:
110
+ servers = self._storages
111
+ target = "target_storage"
112
+ else:
113
+ raise ValueError("Invalid target_role, must be CONTROLLER or STORAGE")
114
+
115
+ server_key = kwargs.get(target)
116
+ if server_key is None:
117
+ for arg in args:
118
+ if isinstance(arg, str) and arg in servers.keys():
119
+ server_key = arg
120
+ break
121
+ if server_key is None and target == "target_controller":
122
+ server_key = next(iter(servers.keys()))
123
+
124
+ server_info = servers.get(server_key)
125
+ if not server_info:
126
+ raise RuntimeError(f"Server {server_key} not found in registered {target_role} servers")
127
+
128
+ context = zmq.asyncio.Context()
129
+ address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}"
130
+ identity = f"{self.client_id}_to_{server_info.id}_{uuid4()}".encode()
131
+ sock = create_zmq_socket(context, zmq.DEALER, identity=identity)
132
+
133
+ try:
134
+ sock.connect(address)
135
+ logger.info(
136
+ f"[{self.client_id}]: Connected to {target_role} {server_info.id} at {address} "
137
+ f"with identity {identity.decode()}"
138
+ )
139
+
140
+ kwargs["socket"] = sock
141
+ return await func(self, *args, **kwargs)
142
+ except Exception as e:
143
+ logger.error(
144
+ f"[{self.client_id}]: Error in socket operation with {target_role} {server_info.id}: {e}"
145
+ )
146
+ raise
147
+ finally:
148
+ try:
149
+ if not sock.closed:
150
+ sock.setsockopt(zmq.LINGER, -1)
151
+ sock.close()
152
+ sock.close(linger=0)
153
+ except Exception as e:
154
+ logger.warning(
155
+ f"[{self.client_id}]: Error closing socket to {target_role} {server_info.id}: {e}"
156
+ )
157
+
158
+ context.term()
159
+
160
+ return wrapper
161
+
162
+ return decorator
163
+
164
    @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
    async def async_get_meta(
        self,
        data_fields: list[str],
        batch_size: int,
        global_step: int,
        mode: str = "fetch",
        get_n_samples: bool = False,
        task_name: Optional[str] = None,
        target_controller: Optional[str] = None,
        socket: Optional[zmq.asyncio.Socket] = None,
    ) -> BatchMeta:
        """Asynchronously fetches data metadata via ZMQ from the target controller.

        Args:
            data_fields (list[str]): List of fields to retrieve metadata for
            batch_size (int): Processing batch size
            global_step (int): Current training/processing step
            mode (str): Data fetch mode. 'fetch' to get ready data, 'force_fetch' to get data regardless of readiness.
                'insert' IS AN INTERNAL USAGE THAT SHOULD NOT BE USED BY USERS.
            get_n_samples (bool): If True, we arrange the samples of the same prompt in contiguous order. In 'fetch'
                mode, only the samples of the same prompt that are all ready will be returned.
            task_name (str): Optional task name associated with the request
            target_controller (str): ID of the target controller to send the request to
            socket (zmq.asyncio.Socket): ZMQ async socket for message transmission

        Example:
            >>> batch_size = 4
            >>> current_step = 0
            >>> # Example 1: "fetch" a batch of metadata that has been produced
            >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"],
            >>>                                                batch_size=batch_size,
            >>>                                                global_step=current_step,
            >>>                                                mode="fetch",
            >>>                                                get_n_samples=False,
            >>>                                                task_name="generate_sequences",
            >>>                                                ))
            >>> print(batch_meta.is_ready)  # you should get a batch_meta with is_ready=True
            >>> print([sample_meta.is_ready for sample_meta in batch_meta.samples])  # [True, True, True, True]
            >>>
            >>> # Example 2: "force_fetch" a batch of metadata, ignoring their production status (but we still make
            >>> # sure the corresponding data has not been consumed)
            >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["input_ids", "attention_mask"],
            >>>                                                batch_size=batch_size,
            >>>                                                global_step=current_step,
            >>>                                                mode="force_fetch",
            >>>                                                get_n_samples=False,
            >>>                                                task_name="generate_sequences",
            >>>                                                ))
            >>> print(batch_meta.is_ready)  # you may get a batch_meta with is_ready=False
            >>> print([sample_meta.is_ready for sample_meta in batch_meta.samples])  # [True, False, False, True]

        Returns:
            BatchMeta: Metadata object containing data structure, sample info, etc.
        """
        # The socket is injected by the dynamic_socket decorator; it is only
        # None in the raw (undecorated) signature.
        assert socket is not None
        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_META,
            sender_id=self.client_id,
            receiver_id=target_controller,
            body={
                "data_fields": data_fields,
                "batch_size": batch_size,
                "global_step": global_step,
                "mode": mode,
                "get_n_samples": get_n_samples,
                "task_name": task_name,
            },
        )

        try:
            # Single request/response round-trip over the injected DEALER socket.
            await socket.send(request_msg.serialize())
            response = await socket.recv()
            response_msg = ZMQMessage.deserialize(response)
            logger.debug(
                f"[{self.client_id}]: Client get datameta response: {response_msg} from controller {target_controller}"
            )

            if response_msg.request_type == ZMQRequestType.GET_META_RESPONSE:
                metadata = response_msg.body["metadata"]
                return metadata
            else:
                # Controller signalled failure; surface its message to the caller.
                raise RuntimeError(
                    f"[{self.client_id}]: Failed to get metadata from controller {target_controller}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"[{self.client_id}]: Error in get_meta: {str(e)}") from e
252
+
253
    async def async_put(
        self,
        data: TensorDict,
        metadata: Optional[BatchMeta] = None,
        global_step: Optional[int] = None,
    ):
        """Asynchronously writes data to appropriate Storage Units based on metadata.

        If metadata isn't provided, it will be created automatically using the insert mode
        with the provided data_columns and global_step.

        Args:
            data (torch.Tensor | tensordict.TensorDict): Data to write, either a Tensor or TensorDict
            metadata (BatchMeta, optional): Optional metadata containing index and storage unit information
            global_step (int, optional): Current step (required if no metadata is provided)

        Raises:
            ValueError: If the (given or auto-created) metadata is empty.
            RuntimeError: If any storage unit rejects the write (propagated from
                _put_to_storage via asyncio.gather).

        Example:
            >>> batch_size = 4
            >>> seq_len = 16
            >>> current_step = 0
            >>> # Example 1: normal usage
            >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"],
            >>>                                                batch_size=batch_size,
            >>>                                                global_step=current_step,
            >>>                                                mode="fetch",
            >>>                                                get_n_samples=False,
            >>>                                                task_name="generate_sequences",
            >>>                                                ))
            >>> batch = asyncio.run(client.async_get_data(batch_meta))
            >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
            >>> asyncio.run(client.async_put(data=output, metadata=batch_meta))
            >>>
            >>> # Example 2: put the initial data into the system without pre-existing metadata
            >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given global_step!
            >>> # Please make sure the corresponding global_step is empty before calling the async_put()
            >>> # without metadata.
            >>> # Now we only support put all the data of the corresponding global step in once. You should repeat with
            >>> # interleave the initial data if n_sample > 1 before calling the async_put().
            >>> original_prompts = torch.randn(batch_size, seq_len)
            >>> n_samples = 4
            >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
            >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
            >>> # This will create metadata in "insert" mode internally.
            >>> asyncio.run(client.async_put(data=prompts_repeated_batch, global_step=current_step))

        """
        if metadata is None:
            assert global_step is not None, "global_steps must be provided if metadata is not given"

            # No placement info given: ask the controller to allocate slots for
            # this step ("insert" mode is internal-only, see async_get_meta).
            metadata = await self.async_get_meta(
                data_fields=list(data.keys()),
                batch_size=data.batch_size[0],
                global_step=global_step,
                get_n_samples=True,
                mode="insert",
            )

        if not metadata or metadata.size == 0:
            raise ValueError("metadata cannot be none or empty")
        logger.debug(f"[{self.client_id}]: Put data with data: {data}")
        # Fan the batch out to each storage unit concurrently; get_transfer_info
        # slices out exactly the samples/fields each unit owns.
        tasks = [
            self._put_to_storage(get_transfer_info(meta_group, data), target_storage=storage_id)
            for storage_id, meta_group in metadata.storage_meta_groups.items()
        ]
        await asyncio.gather(*tasks)

        logger.info(
            f"[{self.client_id}]: step {global_step} put {metadata.size} samples to storage units successfully."
        )
322
+
323
    @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket")
    async def _put_to_storage(self, storage_unit_data, target_storage=None, socket=None):
        """
        Send data to a specific storage unit.

        Args:
            storage_unit_data (dict): Transfer dict produced by get_transfer_info,
                with "global_indexes", "local_indexes" and "field_data"
                (field name -> list of per-sample values).
            target_storage (str, optional): ID of the storage unit to write to.
            socket (zmq.asyncio.Socket, optional): Injected by dynamic_socket.

        Raises:
            RuntimeError: If the storage unit rejects the request or transport fails.
        """
        global_indexes = storage_unit_data["global_indexes"]
        local_indexes = storage_unit_data["local_indexes"]
        # Pack each field: if every per-sample value is a Tensor, use a nested
        # tensor (samples may differ in shape); otherwise wrap the raw Python
        # objects in a NonTensorStack.
        field_data = TensorDict(
            {
                field: (
                    torch.nested.as_nested_tensor(storage_unit_data["field_data"][field])
                    if storage_unit_data["field_data"][field]
                    and all(isinstance(x, torch.Tensor) for x in storage_unit_data["field_data"][field])
                    else NonTensorStack(*storage_unit_data["field_data"][field])
                )
                for field in storage_unit_data["field_data"]
            }
        )

        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.PUT_DATA,
            sender_id=self.client_id,
            receiver_id=target_storage,
            body={"global_indexes": global_indexes, "local_indexes": local_indexes, "field_data": field_data},
        )
        try:
            await socket.send(request_msg.serialize())
            serialized = await socket.recv()
            response_msg = ZMQMessage.deserialize(serialized)

            if response_msg.request_type != ZMQRequestType.PUT_DATA_RESPONSE:
                raise RuntimeError(
                    f"Failed to put data to storage unit {target_storage}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"Error in put to storage unit {target_storage}: {str(e)}") from e
360
+
361
    @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket")
    async def _get_from_storage(self, index_data, target_storage=None, socket=None):
        """Fetch the requested fields for a set of samples from one storage unit.

        Args:
            index_data (dict): Transfer dict with "global_indexes",
                "local_indexes" and "fields" describing what to read.
            target_storage (str, optional): ID of the storage unit to read from.
            socket (zmq.asyncio.Socket, optional): Injected by dynamic_socket.

        Returns:
            tuple: (global_indexes, fields, storage_unit_data). The storage unit
            only sees local_indexes; global_indexes are echoed back locally so
            the caller can merge results from several units.

        Raises:
            RuntimeError: If the storage unit reports an error or transport fails.
        """
        global_indexes = index_data["global_indexes"]
        local_indexes = index_data["local_indexes"]
        fields = index_data["fields"]

        request_msg = ZMQMessage.create(
            request_type=ZMQRequestType.GET_DATA,
            sender_id=self.client_id,
            receiver_id=target_storage,
            body={"local_indexes": local_indexes, "fields": fields},
        )

        try:
            await socket.send(request_msg.serialize())
            serialized = await socket.recv()
            response_msg = ZMQMessage.deserialize(serialized)
            # NOTE(review): this logs the full response payload at INFO level on
            # every get — consider whether DEBUG would be more appropriate.
            logger.info(f"[{self.client_id}]: get data response from storage unit {target_storage}: {response_msg}")

            if response_msg.request_type == ZMQRequestType.GET_DATA_RESPONSE:
                # Return data and index information from this storage unit
                storage_unit_data = response_msg.body["data"]
                return global_indexes, fields, storage_unit_data
            else:
                raise RuntimeError(
                    f"Failed to get data from storage unit {target_storage}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )
        except Exception as e:
            raise RuntimeError(f"Error getting data from storage unit {target_storage}: {str(e)}") from e
391
+
392
    async def async_get_data(self, metadata: BatchMeta) -> TensorDict:
        """Asynchronously fetches data via Storage Units and organizes it into a TensorDict.

        Args:
            metadata (BatchMeta): Object containing:
                - Data location info (which Storage Units hold the data)
                - `global_indexes` to determine the ordering of merged results

        Returns:
            tensordict.TensorDict with:
                - Requested data fields (e.g., "prompt_token_ids", "response_token_ids").
                - "global_indexes" key: Maps each sample to its original global index.

        Example:
            >>> batch_size = 4
            >>> seq_len = 16
            >>> current_step = 0
            >>> batch_meta = asyncio.run(client.async_get_meta(data_fields=["prompts", "attention_mask"],
            >>>                                                batch_size=batch_size,
            >>>                                                global_step=current_step,
            >>>                                                mode="fetch",
            >>>                                                get_n_samples=False,
            >>>                                                task_name="generate_sequences",
            >>>                                                ))
            >>> batch = asyncio.run(client.async_get_data(batch_meta))
            >>> print(batch)
            >>> # this is a TensorDict with fields "prompts" and "attention_mask".
            >>> # The order of samples in the TensorDict matches the order of global_indexes in batch_meta

        Note:
            Why track `global_indexes`?
            - Batches may be rearranged during task processing. `global_indexes` retains the original
              mapping to Storage Units, enabling correct data writing back to Storage Units later.

        """
        if not metadata or metadata.size == 0:
            return TensorDict({}, batch_size=0)

        # Use optimized retrieval with direct storage group access
        tasks = [
            self._get_from_storage(meta_group.get_transfer_info(), target_storage=storage_id)
            for storage_id, meta_group in metadata.storage_meta_groups.items()
        ]

        results = await asyncio.gather(*tasks)

        # global_index: {field1: value, field2: value, ...}
        storage_data: dict[int, dict[str, torch.Tensor]] = {}
        for global_indexes, fields, storage_unit_data in results:
            extracted_data = {field: storage_unit_data[field] for field in fields}

            # Re-key each unit's positional results by global index.
            for idx, global_idx in enumerate(global_indexes):
                if global_idx not in storage_data:
                    storage_data[global_idx] = {}
                for field in fields:
                    storage_data[global_idx][field] = extracted_data[field][idx]

        # Re-order samples to match metadata.global_indexes exactly.
        ordered_data: dict[str, torch.Tensor] = {field: [] for field in metadata.field_names}
        for global_idx in metadata.global_indexes:
            for field in metadata.field_names:
                ordered_data[field].append(storage_data[global_idx][field])

        # Per-field packing: a regular stacked tensor when all samples share a
        # shape; a nested tensor when shapes differ; NonTensorStack otherwise.
        tensor_data = {
            field: (
                torch.stack(torch.nested.as_nested_tensor(v).unbind())
                if v
                and all(isinstance(item, torch.Tensor) for item in v)
                and all(item.shape == v[0].shape for item in v)
                else (
                    torch.nested.as_nested_tensor(v)
                    if v and all(isinstance(item, torch.Tensor) for item in v)
                    else NonTensorStack(*v)
                )
            )
            for field, v in ordered_data.items()
        }
        tensor_data["global_indexes"] = torch.tensor(metadata.global_indexes)

        # NOTE(review): batch_size here is the number of *distinct* global
        # indexes; if metadata.global_indexes ever contained duplicates this
        # would disagree with the per-field length — confirm indexes are unique.
        return TensorDict(tensor_data, batch_size=len(storage_data))
471
+
472
    async def async_clear(self, global_step: int):
        """Asynchronously clears data from all storage units and controller metadata.

        Args:
            global_step (int): The training step associated with the clear operation

        Raises:
            RuntimeError: If the clear metadata cannot be fetched. Individual
                per-server clear failures are logged (not raised), because the
                gather below uses return_exceptions=True.
        """
        try:
            # Any controller can describe what to clear; ask the first one.
            target_controller = next(iter(self._controllers.keys()))
            metadata = await self._get_clear_meta(global_step, target_controller)

            tasks = []

            # Metadata for the step must be cleared on every controller.
            for target_controller in self._controllers.keys():
                tasks.append(self._clear_controller(global_step, target_controller))

            # Group samples by storage unit for clearing
            for target_storage, group in metadata.storage_meta_groups.items():
                group_info = group.get_transfer_info()
                if target_storage not in self._storages:
                    logger.warning(
                        f"[{self.client_id}]: Storage unit {target_storage} not registered, skipping clear operation."
                    )
                    continue
                tasks.append(
                    self._clear_storage_unit(
                        group_info["local_indexes"],
                        target_storage,
                    )
                )

            # Run all clears concurrently; collect failures instead of aborting
            # on the first one.
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for i, result in enumerate(results):
                if isinstance(result, Exception):
                    logger.error(f"[{self.client_id}]: Error in clear operation task {i}: {result}")

            logger.info(f"[{self.client_id}]: Clear operation for global_step {global_step} completed.")
        except Exception as e:
            raise RuntimeError(f"Error in clear operation: {str(e)}") from e
512
+
513
+ @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
514
+ async def _get_clear_meta(self, global_step: int, target_controller=None, socket=None):
515
+ request_msg = ZMQMessage.create(
516
+ request_type=ZMQRequestType.GET_CLEAR_META,
517
+ sender_id=self.client_id,
518
+ receiver_id=target_controller,
519
+ body={"global_step": global_step},
520
+ )
521
+
522
+ await socket.send(request_msg.serialize())
523
+ serialized = await socket.recv()
524
+ response_msg = ZMQMessage.deserialize(serialized)
525
+
526
+ if response_msg.request_type != ZMQRequestType.GET_CLEAR_META_RESPONSE:
527
+ raise RuntimeError(
528
+ f"Failed to get metadata for clear operation: {response_msg.body.get('message', 'Unknown error')}"
529
+ )
530
+
531
+ return response_msg.body["metadata"]
532
+
533
    @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
    async def _clear_controller(self, global_step, target_controller=None, socket=None):
        """Ask one controller to drop its metadata for *global_step*.

        Args:
            global_step: Step whose metadata should be cleared.
            target_controller (str, optional): Controller id; resolved by the
                dynamic_socket decorator.
            socket (zmq.asyncio.Socket, optional): Injected by dynamic_socket.

        Raises:
            RuntimeError: If the controller answers with anything other than
                CLEAR_META_RESPONSE (re-raised after logging).
        """
        try:
            request_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_META,
                sender_id=self.client_id,
                receiver_id=target_controller,
                body={"global_step": global_step},
            )

            await socket.send(request_msg.serialize())
            serialized_msg = await socket.recv()
            response_msg = ZMQMessage.deserialize(serialized_msg)

            if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE:
                raise RuntimeError(
                    f"Failed to clear controller {target_controller}: "
                    f"{response_msg.body.get('message', 'Unknown error')}"
                )

            logger.info(
                f"[{self.client_id}]: Successfully clear controller {target_controller} for global_step {global_step}"
            )
        except Exception as e:
            # Log here so async_clear's gather(return_exceptions=True) callers
            # still see a record even though the exception is swallowed there.
            logger.error(f"[{self.client_id}]: Error clearing controller {target_controller}: {str(e)}")
            raise
559
+
560
+ @dynamic_socket(target_role=TransferQueueRole.STORAGE, socket_name="put_get_socket")
561
+ async def _clear_storage_unit(self, local_indexes, target_storage=None, socket=None):
562
+ try:
563
+ request_msg = ZMQMessage.create(
564
+ request_type=ZMQRequestType.CLEAR_DATA,
565
+ sender_id=self.client_id,
566
+ receiver_id=target_storage,
567
+ body={"local_indexes": local_indexes},
568
+ )
569
+
570
+ await socket.send(request_msg.serialize())
571
+ serialized_msg = await socket.recv()
572
+ response_msg = ZMQMessage.deserialize(serialized_msg)
573
+
574
+ if response_msg.request_type != ZMQRequestType.CLEAR_DATA_RESPONSE:
575
+ raise RuntimeError(
576
+ f"Failed to clear storage {target_storage}: {response_msg.body.get('message', 'Unknown error')}"
577
+ )
578
+
579
+ logger.info(f"[{self.client_id}]: Successfully clear storage unit {target_storage}")
580
+ except Exception as e:
581
+ logger.error(f"[{self.client_id}]: Error clearing storage unit {target_storage}: {str(e)}")
582
+ raise
583
+
584
+ @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
585
+ def check_current_step_consumption(self, task_name: str, global_step: int):
586
+ # TODO: Implement this method to check if all samples for the current step has been consumed
587
+ pass
588
+
589
+ @dynamic_socket(target_role=TransferQueueRole.CONTROLLER, socket_name="request_handle_socket")
590
+ def check_current_step_production(self, data_fields: list[str], global_step: int):
591
+ # TODO: Implement this method to check if all samples for the current step is ready for consumption
592
+ pass
593
+
594
+
595
class TransferQueueClient(AsyncTransferQueueClient):
    """Blocking facade over AsyncTransferQueueClient.

    Every call drives the corresponding coroutine to completion with
    asyncio.run, so these methods must not be invoked from inside an
    already-running event loop.
    """

    def __init__(
        self,
        client_id: str,
        controller_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
        storage_infos: ZMQServerInfo | dict[Any, ZMQServerInfo],
    ):
        super().__init__(client_id, controller_infos, storage_infos)

    def put(self, data: TensorDict, metadata: Optional[BatchMeta] = None, global_step: Optional[int] = None):
        """Blocking wrapper around async_put."""
        coro = self.async_put(data, metadata, global_step)
        return asyncio.run(coro)

    def get_meta(
        self,
        data_fields: list[str],
        batch_size: int,
        global_step: int,
        get_n_samples: bool = False,
        task_name: Optional[str] = None,
    ) -> BatchMeta:
        """Blocking wrapper around async_get_meta (always default 'fetch' mode)."""
        coro = self.async_get_meta(
            data_fields=data_fields,
            batch_size=batch_size,
            global_step=global_step,
            get_n_samples=get_n_samples,
            task_name=task_name,
        )
        return asyncio.run(coro)

    def get_data(self, metadata: BatchMeta) -> TensorDict:
        """Blocking wrapper around async_get_data."""
        return asyncio.run(self.async_get_data(metadata))

    def clear(self, global_step: int):
        """Blocking wrapper around async_clear."""
        return asyncio.run(self.async_clear(global_step))
634
+
635
+
636
def _add_field_data(
    transfer_dict: dict[str, Any], storage_meta_group: StorageMetaGroup, data: TensorDict
) -> dict[str, Any]:
    """Helper function to add field data to the transfer dictionary"""
    sample_metas = storage_meta_group.sample_metas
    for field_name in transfer_dict["fields"]:
        if field_name not in data.keys():
            # Requested field absent from this batch: leave it out entirely.
            continue
        transfer_dict["field_data"][field_name] = [
            data[field_name][meta.batch_index] for meta in sample_metas
        ]
    return transfer_dict
647
+
648
+
649
def get_transfer_info(
    storage_meta_group: StorageMetaGroup,
    data: TensorDict,
) -> dict[str, Any]:
    """Convert to dictionary format with field data for put operations"""
    # Build the index skeleton from the meta group, then attach the actual
    # per-sample values for every field present in `data`.
    transfer_info = storage_meta_group.get_transfer_info(field_names=data.keys())
    return _add_field_data(transfer_info, storage_meta_group, data)
657
+
658
+
659
def process_zmq_server_info(handlers: dict[Any, Union[TransferQueueController, TransferQueueStorageSimpleUnit]]):  # noqa: UP007
    """Resolve the ZMQ server info of each Ray actor handle, keyed by handler name."""
    return {
        name: ray.get(handler.get_zmq_server_info.remote())  # type: ignore[attr-defined]
        for name, handler in handlers.items()
    }
663
+ return server_info