TransferQueue 0.1.1.dev0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
transfer_queue/storage/managers/base.py
@@ -0,0 +1,460 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import itertools
+ import logging
+ import os
+ import time
+ from abc import ABC, abstractmethod
+ from typing import Any
+ from uuid import uuid4
+
+ import torch
+ import zmq
+ from tensordict import TensorDict
+ from torch import Tensor
+
+ from transfer_queue.metadata import BatchMeta
+ from transfer_queue.storage.clients.factory import StorageClientFactory
+ from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
+
+
+ # ZMQ timeouts (in seconds) and retry configurations
+ TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5))
+ TQ_STORAGE_HANDSHAKE_TIMEOUT = int(os.environ.get("TQ_STORAGE_HANDSHAKE_TIMEOUT", 30))
+ TQ_STORAGE_HANDSHAKE_RETRY_INTERVAL = int(os.environ.get("TQ_STORAGE_HANDSHAKE_RETRY_INTERVAL", 1))
+ TQ_STORAGE_HANDSHAKE_MAX_RETRIES = int(os.environ.get("TQ_STORAGE_HANDSHAKE_MAX_RETRIES", 3))
+ TQ_DATA_UPDATE_RESPONSE_TIMEOUT = int(os.environ.get("TQ_DATA_UPDATE_RESPONSE_TIMEOUT", 30))
+
+
+ class TransferQueueStorageManager(ABC):
+     """Base class for storage layer. It defines the interface for data operations and
+     generally provides handshake & notification capabilities."""
+
+     def __init__(self, config: dict[str, Any]):
+         self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}"
+         self.config = config
+         self.controller_info = config.get("controller_info", None)  # type: ZMQServerInfo
+
+         self.data_status_update_socket = None
+         self.controller_handshake_socket = None
+
+         self.zmq_context = None
+         self._connect_to_controller()
+
+     def _connect_to_controller(self) -> None:
+         """Initialize ZMQ sockets between storage unit and controller for handshake."""
+         if not isinstance(self.controller_info, ZMQServerInfo):
+             raise ValueError(f"controller_info should be ZMQServerInfo, but got {type(self.controller_info)}")
+
+         try:
+             # create zmq context
+             self.zmq_context = zmq.Context()
+
+             # create zmq sockets for handshake and data status update
+             self.controller_handshake_socket = create_zmq_socket(
+                 self.zmq_context,
+                 zmq.DEALER,
+                 identity=f"{self.storage_manager_id}-controller_handshake_socket-{uuid4().hex[:8]}".encode(),
+             )
+             self.data_status_update_socket = create_zmq_socket(
+                 self.zmq_context,
+                 zmq.DEALER,
+                 identity=f"{self.storage_manager_id}-data_status_update_socket-{uuid4().hex[:8]}".encode(),
+             )
+             self.data_status_update_socket.connect(self.controller_info.to_addr("data_status_update_socket"))
+
+             # do handshake with controller
+             self._do_handshake_with_controller()
+
+         except Exception as e:
+             logger.error(f"Failed to connect to controller: {e}")
+             raise
+
+     def _do_handshake_with_controller(self) -> None:
+         """Handshake with controller to establish connection with retransmission mechanism."""
+         is_connected: bool = False
+         pending_connection: bool = True
+         handshake_retries: int = 0
+
+         # Create zmq poller for handshake confirmation between controller and storage manager
+         poller = zmq.Poller()
+
+         self.controller_handshake_socket.connect(self.controller_info.to_addr("handshake_socket"))
+         logger.debug(
+             f"[{self.storage_manager_id}]: Handshake connection from storage manager id #{self.storage_manager_id} "
+             f"to controller id #{self.controller_info.id} establish successfully."
+         )
+         poller.register(self.controller_handshake_socket, zmq.POLLIN)
+
+         # Initial handshake request send
+         self._send_handshake_requests()
+
+         start_time = time.time()
+         last_retry_time = time.time()
+
+         while (
+             not is_connected  # Only one controller to connect to
+             and time.time() - start_time < TQ_STORAGE_HANDSHAKE_TIMEOUT
+         ):
+             # Check for timeout and retransmission
+             current_time = time.time()
+             if pending_connection:
+                 if (
+                     current_time - last_retry_time >= TQ_STORAGE_HANDSHAKE_RETRY_INTERVAL
+                     and handshake_retries < TQ_STORAGE_HANDSHAKE_MAX_RETRIES
+                 ):
+                     logger.warning(
+                         f"[{self.storage_manager_id}]: Retransmitting handshake "
+                         f"to controller {self.controller_info.id}, "
+                         f"attempt {handshake_retries + 1}/{TQ_STORAGE_HANDSHAKE_MAX_RETRIES}"
+                     )
+                     self._send_handshake_requests()
+                     last_retry_time = current_time
+                     handshake_retries += 1
+                 elif handshake_retries >= TQ_STORAGE_HANDSHAKE_MAX_RETRIES:
+                     raise TimeoutError(
+                         f"[{self.storage_manager_id}]: Handshake with controller {self.controller_info.id} "
+                         f"({self.controller_info.ip}) failed after "
+                         f"{TQ_STORAGE_HANDSHAKE_MAX_RETRIES} attempts."
+                     )
+
+             # Use shorter poll timeout for more responsive retry timing
+             # while maintaining overall handshake timeout behavior
+             poll_timeout = min(TQ_STORAGE_POLLER_TIMEOUT * 1000, 500)  # Max 500ms
+             socks = dict(poller.poll(poll_timeout))
+
+             if (socks.get(self.controller_handshake_socket, 0) & zmq.POLLIN) and pending_connection:
+                 try:
+                     response_msg = ZMQMessage.deserialize(self.controller_handshake_socket.recv())
+
+                     if response_msg.request_type == ZMQRequestType.HANDSHAKE_ACK:
+                         is_connected = True
+                         pending_connection = False
+                         logger.debug(
+                             f"[{self.storage_manager_id}]: Get handshake ACK response from "
+                             f"controller id #{str(response_msg.sender_id)} to storage manager id "
+                             f"#{self.storage_manager_id} successfully."
+                         )
+                 except Exception as e:
+                     logger.warning(
+                         f"[{self.storage_manager_id}]: Error receiving handshake "
+                         f"response from {self.controller_info.id}: {e}"
+                     )
+
+     def _send_handshake_requests(self) -> None:
+         """Send handshake request to controller."""
+         request_msg = ZMQMessage.create(
+             request_type=ZMQRequestType.HANDSHAKE,
+             sender_id=self.storage_manager_id,
+             body={
+                 "storage_manager_id": self.storage_manager_id,
+                 "storage_manager_type": self.__class__.__name__,
+             },
+         ).serialize()
+
+         self.controller_handshake_socket.send(request_msg)
+         logger.debug(
+             f"[{self.storage_manager_id}]: Send handshake request from storage manager id "
+             f"{self.storage_manager_id} to controller id #{self.controller_info.id} successfully."
+         )
+
+     async def notify_data_update(
+         self,
+         partition_id: str,
+         fields: list[str],
+         global_indexes: list[int],
+         dtypes: dict[int, dict[str, Any]],
+         shapes: dict[int, dict[str, Any]],
+     ) -> None:
+         """
+         Notify controller that new data is ready.
+
+         Args:
+             partition_id: Current data partition id.
+             fields: Data update related fields.
+             global_indexes: Data update related global_indexes.
+             dtypes: Per-field dtypes for each field, in {global_index: {field: dtype}} format.
+             shapes: Per-field shapes for each field, in {global_index: {field: shape}} format.
+         """
+         # Create zmq poller for notifying data update information
+
+         if not self.controller_info:
+             logger.warning(f"No controller connected for storage manager {self.storage_manager_id}")
+             return
+
+         # Create zmq poller for notifying data update information
+         poller = zmq.Poller()
+         # Note: data_status_update_socket is already connected during initialization
+
+         try:
+             poller.register(self.data_status_update_socket, zmq.POLLIN)
+
+             request_msg = ZMQMessage.create(
+                 request_type=ZMQRequestType.NOTIFY_DATA_UPDATE,
+                 sender_id=self.storage_manager_id,
+                 body={
+                     "partition_id": partition_id,
+                     "fields": fields,
+                     "global_indexes": global_indexes,
+                     "dtypes": dtypes,
+                     "shapes": shapes,
+                 },
+             ).serialize()
+
+             self.data_status_update_socket.send(request_msg)
+             logger.debug(
+                 f"[{self.storage_manager_id}]: Send data status update request "
+                 f"from storage manager id #{self.storage_manager_id} "
+                 f"to controller id #{self.controller_info.id} successfully."
+             )
+         except Exception as e:
+             request_msg = ZMQMessage.create(
+                 request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR,
+                 sender_id=self.storage_manager_id,
+                 body={
+                     "message": f"Failed to notify data status update information from "
+                     f"storage manager id #{self.storage_manager_id}, "
+                     f"detail error message: {str(e)}"
+                 },
+             ).serialize()
+
+             self.data_status_update_socket.send(request_msg)
+
+         # Make sure controller successfully receives data status update information.
+         response_received: bool = False
+         start_time = time.time()
+
+         while (
+             not response_received  # Only one controller to get response from
+             and time.time() - start_time < TQ_DATA_UPDATE_RESPONSE_TIMEOUT
+         ):
+             socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))
+
+             if self.data_status_update_socket in socks:
+                 response_msg = ZMQMessage.deserialize(self.data_status_update_socket.recv())
+
+                 if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK:
+                     response_received = True
+                     logger.debug(
+                         f"[{self.storage_manager_id}]: Get data status update ACK response "
+                         f"from controller id #{response_msg.sender_id} "
+                         f"to storage manager id #{self.storage_manager_id} successfully."
+                     )
+
+         if not response_received:
+             logger.error(
+                 f"[{self.storage_manager_id}]: Storage manager id #{self.storage_manager_id} "
+                 f"did not receive data status update ACK response from controller."
+             )
+
+     @abstractmethod
+     async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
+         raise NotImplementedError("Subclasses must implement put_data")
+
+     @abstractmethod
+     async def get_data(self, metadata: BatchMeta) -> TensorDict:
+         raise NotImplementedError("Subclasses must implement get_data")
+
+     @abstractmethod
+     async def clear_data(self, metadata: BatchMeta) -> None:
+         raise NotImplementedError("Subclasses must implement clear_data")
+
+     def close(self) -> None:
+         """Close all ZMQ sockets and context to prevent resource leaks."""
+         for sock in (self.controller_handshake_socket, self.data_status_update_socket):
+             try:
+                 if sock and not sock.closed:
+                     sock.close(linger=0)
+             except Exception as e:
+                 logger.error(f"[{self.storage_manager_id}]: Error closing socket {sock}: {str(e)}")
+
+         try:
+             if self.zmq_context:
+                 self.zmq_context.term()
+         except Exception as e:
+             logger.error(f"[{self.storage_manager_id}]: Error terminating zmq_context: {str(e)}")
+
+     def __del__(self):
+         """Destructor to ensure resources are cleaned up."""
+         try:
+             self.close()
+         except Exception as e:
+             logger.error(f"[{self.storage_manager_id}]: Exception during __del__: {str(e)}")
+
+
+ class KVStorageManager(TransferQueueStorageManager):
+     """
+     A storage manager that uses a key-value (KV) backend (e.g., YuanRong) to store and retrieve tensor data.
+     It maps structured metadata (BatchMeta) to flat lists of keys and values for efficient KV operations.
+     """
+
+     def __init__(self, config: dict[str, Any]):
+         """
+         Initialize the KVStorageManager with configuration.
+         """
+         super().__init__(config)
+         client_name = config.get("client_name", "Yuanrong")
+         self.storage_client = StorageClientFactory.create(client_name, config)
+
+     @staticmethod
+     def _generate_keys(metadata: BatchMeta) -> list[str]:
+         """
+         Generate KV keys in the format 'global_index@field_name' for all sample-field pairs.
+         Keys are generated in sorted order by field name first, then by global index,
+         ensuring consistent ordering for batched operations.
+
+         Args:
+             metadata (BatchMeta): Metadata containing global indexes and field names.
+         Returns:
+             list[str]: List of keys, e.g., ['0@field_a', '1@field_a', '0@field_b', ...]
+         """
+         return [
+             f"{index}@{field}"
+             for field, index in itertools.product(sorted(metadata.field_names), metadata.global_indexes)
+         ]
+
+     @staticmethod
+     def _generate_values(data: TensorDict) -> list[Tensor]:
+         """
+         Extract and flatten tensor values from a TensorDict in field-major order.
+         Values are ordered by sorted field names, then by row (sample) order within each field.
+         This matches the key order generated by `_generate_keys`.
+
+         Args:
+             data (TensorDict): Input data where keys are field names and values are tensors.
+         Returns:
+             list[Tensor]: Flattened list of tensors, e.g.,
+                 [data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...]
+         """
+         # TODO: We will support more complex data types (NonTensorStack/ NonTensorData/ NestedTensor)
+         for v in data.values():
+             if not torch.is_tensor(v):
+                 raise TypeError(f"TensorDict values must be torch.Tensor, but got {type(v)}")
+
+         return [row_data for field in sorted(data.keys()) for row_data in data[field]]
+
+     @staticmethod
+     def _merge_tensors_to_tensordict(metadata: BatchMeta, values: list[Tensor]) -> TensorDict:
+         """
+         Reconstruct a TensorDict from a list of values using metadata.
+         The values list is assumed to be in the same order as keys generated by `_generate_keys`.
+         According to field names and global indexes in metadata, this method can determine
+         which dict key and which row this tensor belongs to. Then it reshapes the flat tensors list
+         back into a structured TensorDict.
+
+         Args:
+             metadata (BatchMeta): Metadata containing global indexes and field names.
+             values (list[Tensor]): List of tensors in field-major order.
+         Returns:
+             TensorDict: Reconstructed tensor dictionary with batch size equal to number of samples.
+         """
+         global_indexes = metadata.global_indexes
+         field_names = sorted(metadata.field_names)
+         expected_length = len(global_indexes) * len(field_names)
+         if len(values) != expected_length:
+             raise ValueError(f"Length of values ({len(values)}) does not match expected ({expected_length})")
+
+         if len(values) == 0:
+             return TensorDict({}, batch_size=len(global_indexes))
+
+         merged_data: dict[str, list[Tensor]] = {field: [] for field in field_names}
+
+         # Group values by field_name
+         value_idx = 0
+         for field in field_names:
+             for _ in range(len(global_indexes)):
+                 merged_data[field].append(values[value_idx])
+                 value_idx += 1
+
+         # Stack or nest tensors per field
+         tensor_data = {}
+         for field, tensor_list in merged_data.items():
+             try:
+                 tensor_data[field] = torch.stack(tensor_list)
+             except RuntimeError:
+                 # Fallback to nested tensor if shapes are irregular
+                 tensor_data[field] = torch.nested.as_nested_tensor(tensor_list)
+
+         return TensorDict(tensor_data, batch_size=len(global_indexes))
+
+     @staticmethod
+     def _get_shape_type_list(metadata: BatchMeta):
+         """
+         Extract the expected shape and dtype for each field-sample pair in metadata.
+         The order matches the key/value order: sorted by field name, then by global index.
+
+         Args:
+             metadata (BatchMeta): Metadata containing sample and field information.
+         Returns:
+             tuple[list[torch.Size], list[torch.dtype]]: Two lists containing the shape and dtype
+                 for each tensor to be retrieved.
+         """
+         shapes = []
+         dtypes = []
+         for field_name in sorted(metadata.field_names):
+             for index in range(len(metadata)):
+                 field = metadata.samples[index].get_field_by_name(field_name)
+                 shapes.append(field.shape)
+                 dtypes.append(field.dtype)
+         return shapes, dtypes
+
+     # TODO: Test put_data/get_data/clear_data with YuanrongStorageClient
+     async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
+         """
+         Store tensor data in the backend storage and notify the controller.
+
+         Serializes the input tensors, stores them using the storage client,
+         extracts per-sample dtype and shape information, and sends a notification
+         to the controller that new data is available.
+         """
+         keys = self._generate_keys(metadata)
+         values = self._generate_values(data)
+         self.storage_client.put(keys=keys, values=values)
+         per_field_dtypes = {}
+         per_field_shapes = {}
+
+         # Initialize the data structure for each global index
+         for global_idx in metadata.global_indexes:
+             per_field_dtypes[global_idx] = {}
+             per_field_shapes[global_idx] = {}
+
+         # For each field, extract dtype and shape for each sample
+         for field in data.keys():
+             for i, data_item in enumerate(data[field]):
+                 global_idx = metadata.global_indexes[i]
+                 per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None
+                 per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None
+
+         # notify controller that new data is ready
+         await self.notify_data_update(list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes)
+
+     async def get_data(self, metadata: BatchMeta) -> TensorDict:
+         """
+         Retrieve tensor data from the backend storage.
+
+         Fetches tensors using the provided metadata, reconstructs them with the
+         correct shapes and dtypes, and merge them as a TensorDict according to metadata.
+         """
+         keys = self._generate_keys(metadata)
+         shapes, dtypes = self._get_shape_type_list(metadata)
+         values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes)
+         return self._merge_tensors_to_tensordict(metadata, values)
+
+     async def clear_data(self, metadata: BatchMeta) -> None:
+         """Remove stored data associated with the given metadata."""
+         keys = self._generate_keys(metadata)
+         self.storage_client.clear(keys=keys)
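
Note (not part of the package): the docstrings above describe a field-major key/value layout, where keys are sorted by field name first and then by global index. The following minimal sketch reproduces that ordering with plain itertools/tensordict calls; the field names "field_a"/"field_b" and the global indexes 4 and 7 are invented for illustration only.

    # Illustration of KVStorageManager's field-major key/value ordering.
    import itertools

    import torch
    from tensordict import TensorDict

    global_indexes = [4, 7]  # hypothetical sample rows tracked by BatchMeta
    data = TensorDict(
        {
            "field_b": torch.zeros(2, 3),
            "field_a": torch.ones(2, 5),
        },
        batch_size=2,
    )

    # Keys: sorted by field name first, then by global index (cf. _generate_keys).
    keys = [
        f"{index}@{field}"
        for field, index in itertools.product(sorted(data.keys()), global_indexes)
    ]
    # Values: rows in the same field-major order (cf. _generate_values).
    values = [row for field in sorted(data.keys()) for row in data[field]]

    print(keys)         # ['4@field_a', '7@field_a', '4@field_b', '7@field_b']
    print(len(values))  # 4 tensors, aligned one-to-one with keys

Because the same sorted-field, then-index traversal is used on both the put and get paths, the i-th key always lines up with the i-th value, which is what lets _merge_tensors_to_tensordict rebuild the original TensorDict.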
transfer_queue/storage/managers/factory.py
@@ -0,0 +1,43 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Any
+
+ from transfer_queue.storage.managers.base import TransferQueueStorageManager
+
+
+ class TransferQueueStorageManagerFactory:
+     """Factory that creates a StorageManager instance."""
+
+     _registry: dict[str, type[TransferQueueStorageManager]] = {}
+
+     @classmethod
+     def register(cls, manager_type: str):
+         def decorator(manager_cls: type[TransferQueueStorageManager]):
+             if not issubclass(manager_cls, TransferQueueStorageManager):
+                 raise TypeError(
+                     f"manager_cls {getattr(manager_cls, '__name__', repr(manager_cls))} must be "
+                     f"a subclass of TransferQueueStorageManager"
+                 )
+             cls._registry[manager_type] = manager_cls
+             return manager_cls
+
+         return decorator
+
+     @classmethod
+     def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager:
+         if manager_type not in cls._registry:
+             raise ValueError(
+                 f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}"
+             )
+         return cls._registry[manager_type](config)
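
Note (not part of the package): the factory follows a register-by-decorator pattern, so a manager class becomes available to create() once it is decorated with register(<name>). The sketch below shows that flow with a hypothetical "in_memory" manager invented for this example; it skips the ZMQ handshake that real managers perform in TransferQueueStorageManager.__init__ so it can run without a controller.

    # Illustration of registering and creating a manager via the factory.
    from typing import Any

    from tensordict import TensorDict

    from transfer_queue.metadata import BatchMeta
    from transfer_queue.storage.managers.base import TransferQueueStorageManager
    from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory


    @TransferQueueStorageManagerFactory.register("in_memory")
    class InMemoryStorageManager(TransferQueueStorageManager):
        def __init__(self, config: dict[str, Any]):
            # Skip the base-class ZMQ handshake so this sketch runs without a controller.
            self.storage_manager_id = "TQ_STORAGE_example"
            self.config = config
            self.controller_handshake_socket = None
            self.data_status_update_socket = None
            self.zmq_context = None
            self._store: dict[str, TensorDict] = {}

        async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
            self._store[str(metadata.global_indexes)] = data

        async def get_data(self, metadata: BatchMeta) -> TensorDict:
            return self._store[str(metadata.global_indexes)]

        async def clear_data(self, metadata: BatchMeta) -> None:
            self._store.pop(str(metadata.global_indexes), None)


    manager = TransferQueueStorageManagerFactory.create("in_memory", config={})
    print(type(manager).__name__)  # InMemoryStorageManager

Passing an unregistered name to create() raises the ValueError shown in the factory, listing the manager types currently in the registry.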