TransferQueue 0.1.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
@@ -0,0 +1,611 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import logging
17
+ import os
18
+ from functools import wraps
19
+ from typing import Any, Callable, Optional, Union
20
+ from uuid import uuid4
21
+
22
+ import ray
23
+ import zmq
24
+ import zmq.asyncio
25
+ from tensordict import TensorDict
26
+
27
+ from transfer_queue.controller import TransferQueueController
28
+ from transfer_queue.metadata import (
29
+ BatchMeta,
30
+ )
31
+ from transfer_queue.storage import (
32
+ SimpleStorageUnit,
33
+ TransferQueueStorageManager,
34
+ TransferQueueStorageManagerFactory,
35
+ )
36
+ from transfer_queue.utils.zmq_utils import (
37
+ ZMQMessage,
38
+ ZMQRequestType,
39
+ ZMQServerInfo,
40
+ create_zmq_socket,
41
+ )
42
+
43
# Module-level logger; verbosity is driven by the TQ_LOGGING_LEVEL env var
# (a level name such as "DEBUG" is accepted by setLevel); defaults to WARNING when unset.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
45
+
46
+
47
class AsyncTransferQueueClient:
    """Asynchronous client for interacting with TransferQueue controller and storage systems.

    This client provides async methods for data transfer operations including getting metadata,
    reading data from storage, writing data to storage, and clearing data.
    """

    def __init__(
        self,
        client_id: str,
        controller_info: ZMQServerInfo,
    ):
        """Initialize the asynchronous TransferQueue client.

        Args:
            client_id: Unique identifier for this client instance
            controller_info: Single controller ZMQ server information

        Raises:
            ValueError: If controller_info is None
            TypeError: If controller_info is not a ZMQServerInfo instance
        """
        if controller_info is None:
            raise ValueError("controller_info cannot be None")
        if not isinstance(controller_info, ZMQServerInfo):
            raise TypeError(f"controller_info must be ZMQServerInfo, got {type(controller_info)}")
        self.client_id = client_id
        # Single controller endpoint; all metadata requests go through this server.
        self._controller: ZMQServerInfo = controller_info
        logger.info(f"[{self.client_id}]: Registered Controller server {controller_info.id} at {controller_info.ip}")
72
+
73
+ def initialize_storage_manager(
74
+ self,
75
+ manager_type: str,
76
+ config: dict[str, Any],
77
+ ):
78
+ """Initialize the storage manager.
79
+
80
+ Args:
81
+ manager_type: Type of storage manager to create. Supported types include:
82
+ AsyncSimpleStorageManager, KVStorageManager (under development), etc.
83
+ config: Configuration dictionary for the storage manager.
84
+ For AsyncSimpleStorageManager, must contain the following required keys:
85
+ - controller_info: ZMQ server information about the controller
86
+ - storage_unit_infos: ZMQ server information about the storage units
87
+
88
+ """
89
+ self.storage_manager = TransferQueueStorageManagerFactory.create(manager_type, config)
90
+
91
    # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
    @staticmethod
    def dynamic_socket(socket_name: str):
        """Decorator to auto-manage ZMQ sockets for Controller/Storage servers.

        Handles socket lifecycle: create -> connect -> inject -> close.

        Args:
            socket_name: Port name from server config to use for ZMQ connection (e.g., "data_req_port")

        Decorated Function Requirements:
            1. Must be an async class method (needs `self`)
            2. `self` must have:
                - `_controller`: Server registry
                - `client_id`: Unique client ID for socket identity
            3. Receives ZMQ socket via `socket` keyword argument (injected by decorator)

        NOTE(review): applying this staticmethod as a decorator inside the class body
        relies on staticmethod objects being callable (Python 3.10+) — confirm the
        project's minimum supported Python version.
        """

        def decorator(func: Callable):
            @wraps(func)
            async def wrapper(self, *args, **kwargs):
                server_info = self._controller
                if not server_info:
                    raise RuntimeError("No controller registered")

                # Fresh context/socket pair per call: simple and safe for concurrent
                # coroutines, at the cost of reconnecting on every request.
                context = zmq.asyncio.Context()
                address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}"
                # Random suffix keeps DEALER identities unique across concurrent calls.
                identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
                sock = create_zmq_socket(context, zmq.DEALER, identity=identity)

                try:
                    sock.connect(address)
                    logger.info(
                        f"[{self.client_id}]: Connected to Controller {server_info.id} at {address} "
                        f"with identity {identity.decode()}"
                    )

                    # Inject the connected socket for the wrapped coroutine to use.
                    kwargs["socket"] = sock
                    return await func(self, *args, **kwargs)
                except Exception as e:
                    logger.error(f"[{self.client_id}]: Error in socket operation with Controller {server_info.id}: {e}")
                    raise
                finally:
                    try:
                        if not sock.closed:
                            # NOTE(review): linger=-1 waits indefinitely for pending
                            # messages on close — confirm this blocking behavior is intended.
                            sock.close(linger=-1)
                    except Exception as e:
                        logger.warning(f"[{self.client_id}]: Error closing socket to Controller {server_info.id}: {e}")

                    context.term()

            return wrapper

        return decorator
145
+
146
+ @dynamic_socket(socket_name="request_handle_socket")
147
+ async def async_get_meta(
148
+ self,
149
+ data_fields: list[str],
150
+ batch_size: int,
151
+ partition_id: str,
152
+ mode: str = "fetch",
153
+ task_name: Optional[str] = None,
154
+ sampling_config: Optional[dict[str, Any]] = None,
155
+ socket: Optional[zmq.asyncio.Socket] = None,
156
+ ) -> BatchMeta:
157
+ """Asynchronously fetch data metadata from the controller via ZMQ.
158
+
159
+ Args:
160
+ data_fields: List of data field names to retrieve metadata for
161
+ batch_size: Number of samples to request in the batch
162
+ partition_id: Current data partition id
163
+ mode: Data fetch mode. Options:
164
+ - 'fetch': Get ready data only
165
+ - 'force_fetch': Get data regardless of readiness (may return unready samples)
166
+ - 'insert': Internal usage - should not be used by users
167
+ task_name: Optional task name associated with the request
168
+ sampling_config: Optional sampling configuration for custom samplers.
169
+ For GRPOGroupNSampler, should include "n_samples_per_prompt": int
170
+ socket: ZMQ async socket for message transmission (injected by decorator)
171
+
172
+ Returns:
173
+ BatchMeta: Metadata object containing data structure, sample information, and readiness status
174
+
175
+ Raises:
176
+ RuntimeError: If communication fails or controller returns error response
177
+
178
+ Example:
179
+ >>> # Example 1: Basic fetch metadata
180
+ >>> batch_meta = asyncio.run(client.async_get_meta(
181
+ ... data_fields=["input_ids", "attention_mask"],
182
+ ... batch_size=4,
183
+ ... partition_id="train_0",
184
+ ... mode="fetch",
185
+ ... task_name="generate_sequences"
186
+ ... ))
187
+ >>> print(batch_meta.is_ready) # True if all samples ready
188
+ >>>
189
+ >>> # Example 2: Fetch with self-defined samplers (using GRPOGroupNSampler as an example)
190
+ >>> batch_meta = asyncio.run(client.async_get_meta(
191
+ ... data_fields=["input_ids", "attention_mask"],
192
+ ... batch_size=8,
193
+ ... partition_id="train_0",
194
+ ... mode="fetch",
195
+ ... task_name="generate_sequences",
196
+ ... sampling_config={"n_samples_per_prompt": 4}
197
+ ... ))
198
+ >>> print(batch_meta.is_ready) # True if all samples ready
199
+ >>>
200
+ >>> # Example 3: Force fetch metadata (bypass production status check and Sampler,
201
+ >>> so may include unready samples. Consumed samples will not be fetched.)
202
+ >>> batch_meta = asyncio.run(client.async_get_meta(
203
+ ... data_fields=["input_ids", "attention_mask"],
204
+ ... batch_size=4,
205
+ ... partition_id="train_0",
206
+ ... mode="force_fetch",
207
+ ... task_name="generate_sequences"
208
+ ... ))
209
+ >>> print(batch_meta.is_ready) # May be False if some samples not ready
210
+ """
211
+ assert socket is not None
212
+ request_msg = ZMQMessage.create(
213
+ request_type=ZMQRequestType.GET_META,
214
+ sender_id=self.client_id,
215
+ receiver_id=self._controller.id,
216
+ body={
217
+ "data_fields": data_fields,
218
+ "batch_size": batch_size,
219
+ "partition_id": partition_id,
220
+ "mode": mode,
221
+ "task_name": task_name,
222
+ "sampling_config": sampling_config,
223
+ },
224
+ )
225
+
226
+ try:
227
+ await socket.send(request_msg.serialize())
228
+ response = await socket.recv()
229
+ response_msg = ZMQMessage.deserialize(response)
230
+ logger.debug(
231
+ f"[{self.client_id}]: Client get datameta response: {response_msg} "
232
+ f"from controller {self._controller.id}"
233
+ )
234
+
235
+ if response_msg.request_type == ZMQRequestType.GET_META_RESPONSE:
236
+ metadata = response_msg.body["metadata"]
237
+ return metadata
238
+ else:
239
+ raise RuntimeError(
240
+ f"[{self.client_id}]: Failed to get metadata from controller {self._controller.id}: "
241
+ f"{response_msg.body.get('message', 'Unknown error')}"
242
+ )
243
+ except Exception as e:
244
+ raise RuntimeError(f"[{self.client_id}]: Error in get_meta: {str(e)}") from e
245
+
246
    async def async_put(
        self,
        data: TensorDict,
        metadata: Optional[BatchMeta] = None,
        partition_id: Optional[str] = None,
    ):
        """Asynchronously write data to storage units based on metadata.

        If metadata is not provided, it will be created automatically using insert mode
        with the provided data fields and partition_id.

        Note:
            When using multiple workers for distributed execution, there may be data
            ordering inconsistencies between workers during put operations.

        Args:
            data: Data to write as TensorDict
            metadata: Records the metadata of a batch of data samples, containing index and
                storage unit information. If None, metadata will be auto-generated.
            partition_id: Target data partition id (required if metadata is not provided)

        Raises:
            ValueError: If metadata is None or empty, or if partition_id is None when metadata is not provided
            RuntimeError: If storage operation fails

        Example:
            >>> batch_size = 4
            >>> seq_len = 16
            >>> current_partition_id = "train_0"
            >>> # Example 1: Normal usage with existing metadata
            >>> batch_meta = asyncio.run(client.async_get_meta(
            ...     data_fields=["prompts", "attention_mask"],
            ...     batch_size=batch_size,
            ...     partition_id=current_partition_id,
            ...     mode="fetch",
            ...     task_name="generate_sequences",
            ... ))
            >>> batch = asyncio.run(client.async_get_data(batch_meta))
            >>> output = TensorDict({"response": torch.randn(batch_size, seq_len)})
            >>> asyncio.run(client.async_put(data=output, metadata=batch_meta))
            >>>
            >>> # Example 2: Initial data insertion without pre-existing metadata
            >>> # BE CAREFUL: this usage may overwrite any unconsumed data in the given partition_id!
            >>> # Make sure the corresponding partition_id is empty before calling async_put()
            >>> # without metadata.
            >>> # Currently all data for a partition must be put in a single call; repeat-interleave
            >>> # the initial data yourself if n_samples > 1 before calling async_put().
            >>> original_prompts = torch.randn(batch_size, seq_len)
            >>> n_samples = 4
            >>> prompts_repeated = torch.repeat_interleave(original_prompts, n_samples, dim=0)
            >>> prompts_repeated_batch = TensorDict({"prompts": prompts_repeated})
            >>> # This will create metadata in "insert" mode internally.
            >>> asyncio.run(client.async_put(data=prompts_repeated_batch, partition_id=current_partition_id))

        """

        if not hasattr(self, "storage_manager") or self.storage_manager is None:
            raise RuntimeError(
                f"[{self.client_id}]: Storage manager not initialized. "
                "Call initialize_storage_manager() before performing storage operations."
            )

        if metadata is None:
            if partition_id is None:
                raise ValueError("partition_id must be provided if metadata is not given")

            # Auto-generate insert-mode metadata covering every field in `data`.
            # assumes data.batch_size is non-empty (1-D batch) — TODO confirm
            metadata = await self.async_get_meta(
                data_fields=list(data.keys()),
                batch_size=data.batch_size[0],
                partition_id=partition_id,
                mode="insert",
            )

        if not metadata or metadata.size == 0:
            raise ValueError("metadata cannot be none or empty")
        logger.debug(f"[{self.client_id}]: Put data with data: {data}")

        await self.storage_manager.put_data(data, metadata)

        # NOTE(review): when called with explicit metadata, partition_id may be None
        # and this log line prints "partition None" — consider deriving it from metadata.
        logger.info(
            f"[{self.client_id}]: partition {partition_id} put {metadata.size} samples to storage units successfully."
        )
328
+
329
+ async def async_get_data(self, metadata: BatchMeta) -> TensorDict:
330
+ """Asynchronously fetch data from storage units and organize into TensorDict.
331
+
332
+ Args:
333
+ metadata: Batch metadata containing data location information and global indexes
334
+
335
+ Returns:
336
+ TensorDict containing:
337
+ - Requested data fields (e.g., "prompts", "attention_mask")
338
+
339
+ Example:
340
+ >>> batch_meta = asyncio.run(client.async_get_meta(
341
+ ... data_fields=["prompts", "attention_mask"],
342
+ ... batch_size=4,
343
+ ... partition_id="train_0",
344
+ ... mode="fetch",
345
+ ... task_name="generate_sequences",
346
+ ... ))
347
+ >>> batch = asyncio.run(client.async_get_data(batch_meta))
348
+ >>> print(batch)
349
+ >>> # TensorDict with fields "prompts", "attention_mask", and sample order matching metadata global_indexes
350
+
351
+ """
352
+
353
+ if not hasattr(self, "storage_manager") or self.storage_manager is None:
354
+ raise RuntimeError(
355
+ f"[{self.client_id}]: Storage manager not initialized. "
356
+ "Call initialize_storage_manager() before performing storage operations."
357
+ )
358
+
359
+ if not metadata or metadata.size == 0:
360
+ return TensorDict({}, batch_size=0)
361
+
362
+ results = await self.storage_manager.get_data(metadata)
363
+
364
+ return results
365
+
366
+ async def async_clear(self, partition_id: str):
367
+ """Asynchronously clear data from all storage units and controller metadata.
368
+
369
+ Args:
370
+ partition_id: The partition id to clear data for
371
+
372
+ Raises:
373
+ RuntimeError: If clear operation fails
374
+ """
375
+ try:
376
+ if not hasattr(self, "storage_manager") or self.storage_manager is None:
377
+ raise RuntimeError(
378
+ f"[{self.client_id}]: Storage manager not initialized. "
379
+ "Call initialize_storage_manager() before performing storage operations."
380
+ )
381
+
382
+ if not self._controller:
383
+ raise RuntimeError("No controller registered")
384
+
385
+ metadata = await self._get_clear_meta(partition_id)
386
+
387
+ # Clear the controller metadata
388
+ await self._clear_controller(partition_id)
389
+
390
+ # Clear storage unit data
391
+ await self.storage_manager.clear_data(metadata)
392
+
393
+ logger.info(f"[{self.client_id}]: Clear operation for partition_id {partition_id} completed.")
394
+ except Exception as e:
395
+ raise RuntimeError(f"Error in clear operation: {str(e)}") from e
396
+
397
+ @dynamic_socket(socket_name="request_handle_socket")
398
+ async def _get_clear_meta(self, partition_id: str, socket=None) -> BatchMeta:
399
+ """Get metadata required for clear operation from controller.
400
+
401
+ Args:
402
+ partition_id: Partition id to get clear metadata for
403
+ socket: ZMQ socket (injected by decorator)
404
+
405
+ Returns:
406
+ BatchMeta: Records the metadata of a batch of data samples.
407
+
408
+ Raises:
409
+ RuntimeError: If controller returns error response
410
+ """
411
+ request_msg = ZMQMessage.create(
412
+ request_type=ZMQRequestType.GET_CLEAR_META,
413
+ sender_id=self.client_id,
414
+ receiver_id=self._controller.id,
415
+ body={"partition_id": partition_id},
416
+ )
417
+
418
+ await socket.send(request_msg.serialize())
419
+ serialized = await socket.recv()
420
+ response_msg = ZMQMessage.deserialize(serialized)
421
+
422
+ if response_msg.request_type != ZMQRequestType.GET_CLEAR_META_RESPONSE:
423
+ raise RuntimeError(
424
+ f"Failed to get metadata for clear operation: {response_msg.body.get('message', 'Unknown error')}"
425
+ )
426
+
427
+ return response_msg.body["metadata"]
428
+
429
+ @dynamic_socket(socket_name="request_handle_socket")
430
+ async def _clear_controller(self, partition_id, socket=None):
431
+ """Clear metadata from controller.
432
+
433
+ Args:
434
+ partition_id: Partition id to clear metadata for
435
+ socket: ZMQ socket (injected by decorator)
436
+
437
+ Raises:
438
+ RuntimeError: If clear operation fails
439
+ """
440
+ try:
441
+ request_msg = ZMQMessage.create(
442
+ request_type=ZMQRequestType.CLEAR_META,
443
+ sender_id=self.client_id,
444
+ receiver_id=self._controller.id,
445
+ body={"partition_id": partition_id},
446
+ )
447
+
448
+ await socket.send(request_msg.serialize())
449
+ serialized_msg = await socket.recv()
450
+ response_msg = ZMQMessage.deserialize(serialized_msg)
451
+
452
+ if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE:
453
+ raise RuntimeError(
454
+ f"Failed to clear controller {self._controller.id}: "
455
+ f"{response_msg.body.get('message', 'Unknown error')}"
456
+ )
457
+
458
+ logger.info(
459
+ f"[{self.client_id}]: Successfully clear controller {self._controller.id} for partition_id "
460
+ f"{partition_id}"
461
+ )
462
+ except Exception as e:
463
+ logger.error(f"[{self.client_id}]: Error clearing controller {self._controller.id}: {str(e)}")
464
+ raise
465
+
466
    @dynamic_socket(socket_name="request_handle_socket")
    async def check_data_consumption_status(self, task_name: str, partition_id: str):
        """Check if all samples for the current step have been consumed.

        Args:
            task_name: Name of the task to check consumption for
            partition_id: Partition id to check consumption status for

        NOTE: Not yet implemented — currently a no-op placeholder.
        """
        # TODO: Implement this method to check if all samples for the current step have been consumed
        pass
476
+
477
    @dynamic_socket(socket_name="request_handle_socket")
    async def check_data_production_status(self, data_fields: list[str], partition_id: str):
        """Check if all samples for the current partition are ready for consumption.

        Args:
            data_fields: Data fields to check production status for
            partition_id: Partition id to check production status for

        NOTE: Not yet implemented — currently a no-op placeholder.
        """
        # TODO: Implement this method to check if all samples for the current step are ready for consumption
        pass
487
+
488
+ def close(self) -> None:
489
+ """Close the client and cleanup resources including storage manager."""
490
+ try:
491
+ if hasattr(self, "storage_manager") and self.storage_manager:
492
+ if hasattr(self.storage_manager, "close"):
493
+ self.storage_manager.close()
494
+ except Exception as e:
495
+ logger.warning(f"Error closing storage manager: {e}")
496
+
497
+
498
class TransferQueueClient(AsyncTransferQueueClient):
    """Synchronous client wrapper for TransferQueue.

    Provides synchronous versions of all async methods for convenience. Each
    method drives its async counterpart to completion via ``asyncio.run`` and
    therefore must not be called from inside an already-running event loop.
    """

    def __init__(
        self,
        client_id: str,
        controller_info: ZMQServerInfo,
    ):
        """Initialize the synchronous TransferQueue client.

        Args:
            client_id: Unique identifier for this client instance
            controller_info: Single controller ZMQ server information
        """
        super().__init__(
            client_id,
            controller_info,
        )

    def put(self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None):
        """Synchronously write data to storage units.

        Args:
            data: Data to write as TensorDict
            metadata: Optional metadata containing index and storage unit information
            partition_id: Target data partition id (required if metadata is not provided)
        """
        return asyncio.run(self.async_put(data, metadata, partition_id))

    def get_meta(
        self,
        data_fields: list[str],
        batch_size: int,
        partition_id: str,
        task_name: Optional[str] = None,
        sampling_config: Optional[dict[str, Any]] = None,
        mode: str = "fetch",
    ) -> BatchMeta:
        """Synchronously fetch data metadata from controller.

        Args:
            data_fields: List of data field names to retrieve metadata for
            batch_size: Number of samples to request in the batch
            partition_id: Target data partition id
            task_name: Optional task name associated with the request
            sampling_config: Optional sampling configuration for custom samplers.
                For GRPOGroupNSampler, should include "n_samples_per_prompt": int
            mode: Data fetch mode ('fetch' or 'force_fetch'). Defaults to 'fetch'.
                Added (trailing, defaulted) for parity with async_get_meta, which
                previously could not be reached with 'force_fetch' from the sync API.

        Returns:
            BatchMeta: Batch metadata containing data location information
        """
        return asyncio.run(
            self.async_get_meta(
                data_fields=data_fields,
                batch_size=batch_size,
                partition_id=partition_id,
                mode=mode,
                task_name=task_name,
                sampling_config=sampling_config,
            )
        )

    def get_data(self, metadata: BatchMeta) -> TensorDict:
        """Synchronously fetch data from storage units.

        Args:
            metadata: Batch metadata containing data location information

        Returns:
            TensorDict containing requested data fields
        """
        return asyncio.run(self.async_get_data(metadata))

    def clear(self, partition_id: str):
        """Synchronously clear data from storage units and controller metadata.

        Args:
            partition_id: The partition id to clear data for
        """
        return asyncio.run(self.async_clear(partition_id))
579
+
580
+
581
+ def process_zmq_server_info(
582
+ handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]]
583
+ | Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"],
584
+ ): # noqa: UP007
585
+ """Extract ZMQ server information from handler objects.
586
+
587
+ Args:
588
+ handlers: Dictionary of handler objects (controllers, storage managers, or storage units),
589
+ or a single handler object
590
+
591
+ Returns:
592
+ If handlers is a dictionary: Dictionary mapping handler names to their ZMQ server information
593
+ If handlers is a single object: ZMQ server information for that object
594
+
595
+ Examples:
596
+ >>> # Single handler
597
+ >>> controller = TransferQueueController.remote(...)
598
+ >>> info = process_zmq_server_info(controller)
599
+ >>>
600
+ >>> # Multiple handlers
601
+ >>> handlers = {"storage_0": storage_0, "storage_1": storage_1}
602
+ >>> info_dict = process_zmq_server_info(handlers)"""
603
+ # Handle single handler object case
604
+ if not isinstance(handlers, dict):
605
+ return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[attr-defined]
606
+ else:
607
+ # Handle dictionary case
608
+ server_info = {}
609
+ for name, handler in handlers.items():
610
+ server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[attr-defined]
611
+ return server_info