TransferQueue 0.1.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
transfer_queue/storage/managers/simple_backend_manager.py
@@ -0,0 +1,611 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import asyncio
+ import logging
+ import os
+ from collections.abc import Mapping
+ from functools import wraps
+ from operator import itemgetter
+ from typing import Any, Callable
+ from uuid import uuid4
+
+ import torch
+ import zmq
+ from tensordict import NonTensorStack, TensorDict
+
+ from transfer_queue.metadata import BatchMeta
+ from transfer_queue.storage.managers.base import TransferQueueStorageManager
+ from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
+ from transfer_queue.storage.simple_backend import StorageMetaGroup
+ from transfer_queue.utils.utils import limit_pytorch_auto_parallel_threads
+ from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
+
+
+ @TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager")
+ class AsyncSimpleStorageManager(TransferQueueStorageManager):
+     """Asynchronous storage manager that handles multiple storage units.
+
+     This manager provides async put/get/clear operations across multiple SimpleStorageUnit
+     instances using ZMQ communication and dynamic socket management.
+     """
+
+     def __init__(self, config: dict[str, Any]):
+         super().__init__(config)
+
+         self.config = config
+         server_infos = config.get("storage_unit_infos", None)  # type: ZMQServerInfo | dict[str, ZMQServerInfo]
+
+         if server_infos is None:
+             raise ValueError("AsyncSimpleStorageManager requires non-empty 'storage_unit_infos' in config.")
+
+         self.storage_unit_infos = self._register_servers(server_infos)
+         self._build_storage_mapping_functions()
57
+
58
+ def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerInfo]"):
59
+ """Register and validate server information.
60
+
61
+ Args:
62
+ server_infos: ZMQServerInfo | dict[Any, ZMQServerInfo])
63
+ ZMQServerInfo or dict of server infos to register.
64
+
65
+ Returns:
66
+ Dictionary with server IDs as keys and ZMQServerInfo objects as values.
67
+
68
+ Raises:
69
+ ValueError: If server_infos format is invalid.
70
+ """
71
+ server_infos_transform = {}
72
+
73
+ if isinstance(server_infos, ZMQServerInfo):
74
+ server_infos_transform[server_infos.id] = server_infos
75
+ elif isinstance(server_infos, Mapping):
76
+ for k, v in server_infos.items():
77
+ if not isinstance(v, ZMQServerInfo):
78
+ raise ValueError(f"Invalid server info for key {k}: {v}")
79
+ server_infos_transform[v.id] = v
80
+ else:
81
+ raise ValueError(f"Invalid server infos: {server_infos}")
82
+
83
+ return server_infos_transform
+
+     def _build_storage_mapping_functions(self):
+         """Build mapping functions from global index to storage unit and local index.
+
+         Creates round-robin mapping functions to distribute data across storage units.
+         """
+         self.global_index_storage_unit_mapping = lambda x: list(self.storage_unit_infos.keys())[
+             x % len(self.storage_unit_infos)
+         ]
+         self.global_index_local_index_mapping = lambda x: x // len(self.storage_unit_infos)
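+         # Example of the round-robin layout (illustrative unit names): with three registered
+         # units ["su0", "su1", "su2"], global index 7 lands on "su1" (7 % 3 == 1) at local
+         # index 2 (7 // 3), while global index 9 lands on "su0" at local index 3.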
+
+     # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
+     @staticmethod
+     def dynamic_storage_manager_socket(socket_name: str):
+         """Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close).
+
+         Args:
+             socket_name (str): Port name (from the server config) to use for the ZMQ connection (e.g., "data_req_port").
+
+         Decorated Function Rules:
+             1. Must be an async class method (needs `self`).
+             2. `self` requires:
+                 - `storage_unit_infos`: storage unit infos (ZMQServerInfo | dict[Any, ZMQServerInfo]).
+             3. Specify the target server via:
+                 - the `target_storage_unit` arg.
+             4. Receives the ZMQ socket via the `socket` keyword arg (injected by the decorator).
+         """
+
+         def decorator(func: Callable):
+             @wraps(func)
+             async def wrapper(self, *args, **kwargs):
+                 server_key = kwargs.get("target_storage_unit")
+                 if server_key is None:
+                     for arg in args:
+                         if isinstance(arg, str) and arg in self.storage_unit_infos.keys():
+                             server_key = arg
+                             break
+
+                 server_info = self.storage_unit_infos.get(server_key)
+
+                 if not server_info:
+                     raise RuntimeError(f"Server {server_key} not found in registered servers")
+
+                 context = zmq.asyncio.Context()
+                 address = f"tcp://{server_info.ip}:{server_info.ports.get(socket_name)}"
+                 identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
+                 sock = create_zmq_socket(context, zmq.DEALER, identity=identity)
+
+                 try:
+                     sock.connect(address)
+                     # Timeouts to avoid indefinite await on recv/send
+                     sock.setsockopt(zmq.RCVTIMEO, 10_000)  # 10s
+                     sock.setsockopt(zmq.SNDTIMEO, 10_000)  # 10s
+                     logger.info(
+                         f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} "
+                         f"with identity {identity.decode()}"
+                     )
+
+                     kwargs["socket"] = sock
+                     return await func(self, *args, **kwargs)
+                 except Exception as e:
+                     logger.error(
+                         f"[{self.storage_manager_id}]: Error in socket operation with StorageUnit {server_info.id}: {e}"
+                     )
+                     raise
+                 finally:
+                     try:
+                         if not sock.closed:
+                             sock.close(linger=-1)
+                     except Exception as e:
+                         logger.warning(
+                             f"[{self.storage_manager_id}]: Error closing socket to StorageUnit {server_info.id}: {e}"
+                         )
+
+                     context.term()
+
+             return wrapper
+
+         return decorator
+
+     async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
+         """
+         Send data to remote StorageUnits based on metadata.
+
+         Args:
+             data: TensorDict containing the data to store.
+             metadata: BatchMeta containing storage location information.
+         """
+
+         # group samples by storage unit
+         storage_meta_groups = build_storage_meta_groups(
+             metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
+         )
+
+         # send data to each storage unit
+         tasks = [
+             self._put_to_single_storage_unit(get_transfer_data(meta_group, data), target_storage_unit=storage_id)
+             for storage_id, meta_group in storage_meta_groups.items()
+         ]
+         await asyncio.gather(*tasks)
+
+         # Gather per-field dtype and shape information for each sample;
+         # global_indexes, local_indexes, and field_data correspond one-to-one
+         per_field_dtypes = {}
+         per_field_shapes = {}
+
+         # Initialize the data structure for each global index
+         for global_idx in metadata.global_indexes:
+             per_field_dtypes[global_idx] = {}
+             per_field_shapes[global_idx] = {}
+
+         # For each field, extract dtype and shape for each sample
+         for field in data.keys():
+             for i, data_item in enumerate(data[field]):
+                 global_idx = metadata.global_indexes[i]
+                 per_field_dtypes[global_idx][field] = data_item.dtype if hasattr(data_item, "dtype") else None
+                 per_field_shapes[global_idx][field] = data_item.shape if hasattr(data_item, "shape") else None
+
+         # Get the current data partition id
+         # Note: Currently we only support putting to & getting data from a single data partition at a time,
+         # but in the future we may support putting to & getting data from multiple data partitions concurrently.
+         partition_id = metadata.samples[0].partition_id
+
+         # notify the controller that new data is ready
+         await self.notify_data_update(
+             partition_id, list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes
+         )
+
+     @dynamic_storage_manager_socket(socket_name="put_get_socket")
+     async def _put_to_single_storage_unit(self, transfer_data: dict[str, Any], target_storage_unit=None, socket=None):
+         """
+         Send data to a specific storage unit.
+         """
+         local_indexes = transfer_data["local_indexes"]
+
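+         # Pack each field for transport: lists made up entirely of tensors are packed as a
+         # nested tensor (shapes may differ across samples); anything else is wrapped in a
+         # NonTensorStack so it can still travel inside the TensorDict payload.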
+         tensordict_data = TensorDict(
+             {
+                 field: (
+                     torch.nested.as_nested_tensor(transfer_data["field_data"][field])
+                     if transfer_data["field_data"][field]
+                     and all(isinstance(x, torch.Tensor) for x in transfer_data["field_data"][field])
+                     else NonTensorStack(*transfer_data["field_data"][field])
+                 )
+                 for field in transfer_data["field_data"]
+             }
+         )
+
+         request_msg = ZMQMessage.create(
+             request_type=ZMQRequestType.PUT_DATA,
+             sender_id=self.storage_manager_id,
+             receiver_id=target_storage_unit,
+             body={"local_indexes": local_indexes, "data": tensordict_data},
+         )
+
+         try:
+             await socket.send(request_msg.serialize())
+             serialized = await socket.recv()
+             response_msg = ZMQMessage.deserialize(serialized)
+
+             if response_msg.request_type != ZMQRequestType.PUT_DATA_RESPONSE:
+                 raise RuntimeError(
+                     f"Failed to put data to storage unit {target_storage_unit}: "
+                     f"{response_msg.body.get('message', 'Unknown error')}"
+                 )
+         except Exception as e:
+             raise RuntimeError(f"Error in put to storage unit {target_storage_unit}: {str(e)}") from e
+
+     async def get_data(self, metadata: BatchMeta) -> TensorDict:
+         """
+         Retrieve data from remote StorageUnits based on metadata.
+
+         Args:
+             metadata: BatchMeta that contains metadata for data retrieval.
+
+         Returns:
+             TensorDict containing the retrieved data.
+         """
+
+         # group samples by storage unit
+         storage_meta_groups = build_storage_meta_groups(
+             metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
+         )
+
+         # retrieve data
+         tasks = [
+             self._get_from_single_storage_unit(meta_group.get_transfer_data(), target_storage_unit=storage_id)
+             for storage_id, meta_group in storage_meta_groups.items()
+         ]
+
+         results = await asyncio.gather(*tasks)
+
+         # post-process the per-storage-unit data segments into a single batch
+         merged_data: dict[int, dict[str, torch.Tensor]] = {}
+         for global_indexes, fields, data_from_single_storage_unit in results:
+             field_getter = itemgetter(*fields)
+             field_values = field_getter(data_from_single_storage_unit)
+
+             if len(fields) == 1:
+                 extracted_data = {fields[0]: field_values}
+             else:
+                 extracted_data = dict(zip(fields, field_values, strict=False))
+
+             for idx, global_idx in enumerate(global_indexes):
+                 if global_idx not in merged_data:
+                     merged_data[global_idx] = {}
+                 merged_data[global_idx].update({field: extracted_data[field][idx] for field in fields})
+
+         ordered_data: dict[str, list[torch.Tensor]] = {}
+         for field in metadata.field_names:
+             ordered_data[field] = [merged_data[global_idx][field] for global_idx in metadata.global_indexes]
+
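+         # Re-assemble each field in request order: equal-shape tensors are stacked into a
+         # regular batch tensor, variable-shape tensors become a nested tensor, and
+         # non-tensor items are returned as a NonTensorStack.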
+         with limit_pytorch_auto_parallel_threads():
+             tensor_data = {
+                 field: (
+                     torch.stack(torch.nested.as_nested_tensor(v).unbind())
+                     if v
+                     and all(isinstance(item, torch.Tensor) for item in v)
+                     and all(item.shape == v[0].shape for item in v)
+                     else (
+                         torch.nested.as_nested_tensor(v)
+                         if v and all(isinstance(item, torch.Tensor) for item in v)
+                         else NonTensorStack(*v)
+                     )
+                 )
+                 for field, v in ordered_data.items()
+             }
+
+         return TensorDict(tensor_data, batch_size=len(metadata))
+
+     @dynamic_storage_manager_socket(socket_name="put_get_socket")
+     async def _get_from_single_storage_unit(self, index_data, target_storage_unit=None, socket=None):
+         global_indexes = index_data["global_indexes"]
+         local_indexes = index_data["local_indexes"]
+         fields = index_data["fields"]
+
+         request_msg = ZMQMessage.create(
+             request_type=ZMQRequestType.GET_DATA,
+             sender_id=self.storage_manager_id,
+             receiver_id=target_storage_unit,
+             body={"local_indexes": local_indexes, "fields": fields},
+         )
+
+         try:
+             await socket.send(request_msg.serialize())
+             serialized = await socket.recv()
+             response_msg = ZMQMessage.deserialize(serialized)
+             logger.info(
+                 f"[{self.storage_manager_id}]: get data response from storage unit "
+                 f"{target_storage_unit}: {response_msg}"
+             )
+
+             if response_msg.request_type == ZMQRequestType.GET_DATA_RESPONSE:
+                 # Return data and index information from this storage unit
+                 storage_unit_data = response_msg.body["data"]
+                 return global_indexes, fields, storage_unit_data
+             else:
+                 raise RuntimeError(
+                     f"Failed to get data from storage unit {target_storage_unit}: "
+                     f"{response_msg.body.get('message', 'Unknown error')}"
+                 )
+         except Exception as e:
+             raise RuntimeError(f"Error getting data from storage unit {target_storage_unit}: {str(e)}") from e
+
+     async def clear_data(self, metadata: BatchMeta) -> None:
+         """Clear data in remote StorageUnits.
+
+         Args:
+             metadata: BatchMeta that contains metadata for data clearing.
+         """
+
+         # group samples by storage unit
+         storage_meta_groups = build_storage_meta_groups(
+             metadata, self.global_index_storage_unit_mapping, self.global_index_local_index_mapping
+         )
+
+         # clear data
+         tasks = [
+             self._clear_single_storage_unit(
+                 meta_group.get_transfer_data()["local_indexes"], target_storage_unit=storage_id
+             )
+             for storage_id, meta_group in storage_meta_groups.items()
+         ]
+
+         results = await asyncio.gather(*tasks, return_exceptions=True)
+
+         for i, result in enumerate(results):
+             if isinstance(result, Exception):
+                 logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}")
+
+     @dynamic_storage_manager_socket(socket_name="put_get_socket")
+     async def _clear_single_storage_unit(self, local_indexes, target_storage_unit=None, socket=None):
+         try:
+             request_msg = ZMQMessage.create(
+                 request_type=ZMQRequestType.CLEAR_DATA,
+                 sender_id=self.storage_manager_id,
+                 receiver_id=target_storage_unit,
+                 body={"local_indexes": local_indexes},
+             )
+
+             await socket.send(request_msg.serialize())
+             serialized_msg = await socket.recv()
+             response_msg = ZMQMessage.deserialize(serialized_msg)
+
+             if response_msg.request_type != ZMQRequestType.CLEAR_DATA_RESPONSE:
+                 raise RuntimeError(
+                     f"Failed to clear storage {target_storage_unit}: "
+                     f"{response_msg.body.get('message', 'Unknown error')}"
+                 )
+
+             logger.info(f"[{self.storage_manager_id}]: Successfully cleared storage unit {target_storage_unit}")
+         except Exception as e:
+             logger.error(f"[{self.storage_manager_id}]: Error clearing storage unit {target_storage_unit}: {str(e)}")
+             raise
+
+     def get_zmq_server_info(self) -> dict[str, ZMQServerInfo]:
+         """Get ZMQ server information for all storage units.
+
+         Returns:
+             Dictionary mapping storage unit IDs to their ZMQServerInfo.
+         """
+         return self.storage_unit_infos
+
+     def close(self) -> None:
+         """Close all ZMQ sockets and contexts to prevent resource leaks."""
+         super().close()
+
+
+ def get_transfer_data(
+     storage_meta_group: StorageMetaGroup,
+     data: TensorDict,
+ ) -> dict[str, Any]:
+     """Convert StorageMetaGroup and TensorDict to transfer format for put operations.
+
+     This function creates a bridge between the high-level metadata (StorageMetaGroup)
+     and the raw data (TensorDict), producing a transfer_dict that contains both
+     metadata structure and the actual field data needed for storage operations.
+
+     Key Data Flow:
+         1. storage_meta_group.get_transfer_data() creates metadata structure
+         2. _add_field_data() extracts data using sample_meta.batch_index as key
+         3. Final transfer_dict contains both metadata and correctly ordered data
+
+     Args:
+         storage_meta_group: StorageMetaGroup containing SampleMeta objects with:
+             - sample_meta.batch_index: Position in original TensorDict (0-based)
+             - sample_meta.global_index: Global unique identifier
+             - sample_meta.local_index: Position in target storage unit
+         data: Raw TensorDict with actual data values (as received from client):
+             Format: {"field_name": [data_at_index_0, data_at_index_1, ...]}
+
+     Returns:
+         Complete transfer dictionary ready for storage operations:
+             {
+                 "batch_indexes": [2, 0, 3],      # Original TensorDict positions
+                 "global_indexes": [10, 11, 12],  # Global identifiers
+                 "local_indexes": [4, 5, 6],      # Storage locations
+                 "fields": ["images", "labels"],
+                 "field_data": {
+                     "images": [img2, img0, img3],  # Extracted by batch_index
+                     "labels": [label2, label0, label3]
+                 }
+             }
+
+     Example:
+         >>> # Client data: TensorDict with 5 samples (indices 0-4)
+         >>> data = TensorDict({
+         ...     "images": [img0, img1, img2, img3, img4],
+         ...     "labels": [label0, label1, label2, label3, label4]
+         ... })
+         >>> # MetaGroup contains samples at positions 2, 0, 3 in original data
+         >>> group = StorageMetaGroup("storage1")
+         >>> group.add_sample_meta(SampleMeta(batch_index=2, global_index=10), 4)
+         >>> group.add_sample_meta(SampleMeta(batch_index=0, global_index=11), 5)
+         >>> group.add_sample_meta(SampleMeta(batch_index=3, global_index=12), 6)
+         >>> transfer_dict = get_transfer_data(group, data)
+         >>> transfer_dict["batch_indexes"]         # [2, 0, 3] - positions in original TensorDict
+         >>> transfer_dict["field_data"]["images"]  # [img2, img0, img3] - extracted data
+
+     Note:
+         The critical insight is that sample_meta.batch_index is used to index into
+         the original TensorDict to extract the correct data items. This ensures that
+         even when samples are reordered or distributed across storage units,
+         each sample's data is correctly mapped to its metadata.
+     """
+
+     result = storage_meta_group.get_transfer_data(field_names=list(data.keys()))
+     result = _add_field_data(result, storage_meta_group, data)
+     return result
+
+
+ def _add_field_data(
+     transfer_dict: dict[str, Any], storage_meta_group: StorageMetaGroup, data: TensorDict
+ ) -> dict[str, Any]:
+     """Extract field data from TensorDict using sample_meta.batch_index as index.
+
+     This function bridges the gap between raw TensorDict data and the transfer format
+     needed for storage operations. The transfer_dict contains metadata and structure
+     information, while the 'data' parameter contains the actual tensor values.
+
+     Key Concept: sample_meta.batch_index represents the position of each sample's data
+     in the original TensorDict (received from client). This function uses batch_index
+     to extract the correct data items for each sample in the storage_meta_group.
+
+     Args:
+         transfer_dict: Dictionary containing transfer metadata with structure like:
+             {
+                 "batch_indexes": [2, 0, 3],      # Positions in original TensorDict
+                 "global_indexes": [10, 11, 12],  # Global identifiers
+                 "local_indexes": [4, 5, 6],      # Storage locations
+                 "fields": ["field1", "field2"],
+                 "field_data": {}                 # Will be populated by this function
+             }
+         storage_meta_group: StorageMetaGroup containing SampleMeta objects with:
+             - sample_meta.batch_index: Position in original TensorDict
+             - sample_meta.local_index: Position in storage unit
+         data: Raw TensorDict with actual data (as received from client):
+             TensorDict({"field1": [t0, t1, t2, t3, t4], "field2": [t5, t6, t7, t8, t9]})
+
+     Returns:
+         Updated transfer dictionary with field_data populated:
+             {
+                 "batch_indexes": [2, 0, 3],
+                 "global_indexes": [10, 11, 12],
+                 "local_indexes": [4, 5, 6],
+                 "fields": ["field1", "field2"],
+                 "field_data": {
+                     "field1": [t2, t0, t3],  # Extracted by batch_index from original data
+                     "field2": [t7, t5, t8]
+                 }
+             }
+
+     Example:
+         >>> # Raw data from client (TensorDict index 0-4)
+         >>> data = TensorDict({"images": [img0, img1, img2, img3, img4]})
+         >>> # storage_meta_group contains samples with batch_index [2, 0, 3]
+         >>> transfer_dict = {
+         ...     "fields": ["images"],
+         ...     "batch_indexes": [2, 0, 3],
+         ...     "local_indexes": [4, 5, 6],
+         ...     "field_data": {}
+         ... }
+         >>> meta_group = StorageMetaGroup("storage1")
+         >>> meta_group.add_sample_meta(SampleMeta(batch_index=2), 4)  # Extract img2
+         >>> meta_group.add_sample_meta(SampleMeta(batch_index=0), 5)  # Extract img0
+         >>> meta_group.add_sample_meta(SampleMeta(batch_index=3), 6)  # Extract img3
+         >>> result = _add_field_data(transfer_dict, meta_group, data)
+         >>> result["field_data"]["images"]  # [img2, img0, img3] - extracted by batch_index
+     """
+     field_names = transfer_dict["fields"]
+     for fname in field_names:
+         if fname in data.keys():
+             index = [sample_meta.batch_index for sample_meta in storage_meta_group.sample_metas]
+
+             result = itemgetter(*index)(data[fname])
+             if not isinstance(result, tuple):
+                 result = (result,)
+             transfer_dict["field_data"][fname] = list(result)
+
+     return transfer_dict
543
+
544
+
545
+ def build_storage_meta_groups(
546
+ batch_meta: BatchMeta,
547
+ global_index_storage_unit_mapping: Callable,
548
+ global_index_local_index_mapping: Callable,
549
+ ) -> dict[str, StorageMetaGroup]:
550
+ """Build storage meta groups from batch metadata for distributed storage.
551
+
552
+ This function is the starting point of the data distribution workflow. It analyzes
553
+ BatchMeta containing SampleMeta objects (originating from client requests) and
554
+ groups them by target storage unit based on their global_index.
555
+
556
+ Key Data Flow:
557
+ 1. BatchMeta contains SampleMeta objects with batch_index (original TensorDict position)
558
+ 2. Each SampleMeta is assigned to a storage unit using global_index mapping
559
+ 3. Local storage positions are calculated for each sample
560
+ 4. Results in StorageMetaGroup objects ready for transfer operations
561
+
562
+ Args:
563
+ batch_meta: BatchMeta containing SampleMeta objects from client request.
564
+ Each SampleMeta has:
565
+ - batch_index: Position in original TensorDict (0-based)
566
+ - global_index: Global unique identifier across all storage
567
+ global_index_storage_unit_mapping: Function to map global_index to storage_unit_id.
568
+ Example: lambda x: storage_unit_ids[x % num_storage_units] (round-robin distribution)
569
+ global_index_local_index_mapping: Function to map global_index to local_index.
570
+ Example: lambda x: x // num_storage_units (local position within storage unit)
571
+
572
+ Returns:
573
+ Dictionary mapping storage_unit_id to StorageMetaGroup, where each group contains:
574
+ - storage_id: Target storage unit identifier
575
+ - sample_metas: List of SampleMeta objects assigned to this unit
576
+ - local_indexes: List of storage positions for each sample
577
+
578
+ Example:
579
+ >>> # Input: BatchMeta with samples at global_indexes [10, 11, 12]
580
+ >>> # 3 storage units available: storage_0, storage_1, storage_2
581
+ >>> batch_meta = BatchMeta(samples=[
582
+ ... SampleMeta(batch_index=0, global_index=10), # Original position 0
583
+ ... SampleMeta(batch_index=1, global_index=11), # Original position 1
584
+ ... SampleMeta(batch_index=2, global_index=12) # Original position 2
585
+ ... ])
586
+ >>> groups = build_storage_meta_groups(
587
+ ... batch_meta,
588
+ ... lambda x: f"storage_{x % 3}", # 10->storage_1, 11->storage_2, 12->storage_0
589
+ ... lambda x: x // 3 # 10->3, 11->3, 12->4
590
+ ... )
591
+ >>> groups["storage_1"].sample_metas[0].batch_index # 0 - original TensorDict position
592
+ >>> groups["storage_1"].sample_metas[0].local_index # 3 - storage position
593
+
594
+ Note:
595
+ This function preserves the crucial batch_index information that links each
596
+ SampleMeta back to its original position in the client's TensorDict.
597
+ This batch_index is later used by _add_field_data() to extract
598
+ the correct data items for storage.
599
+ """
600
+ storage_meta_groups: dict[str, StorageMetaGroup] = {}
601
+
602
+ for sample in batch_meta.samples:
603
+ storage_id = global_index_storage_unit_mapping(sample.global_index)
604
+ local_index = global_index_local_index_mapping(sample.global_index)
605
+ if storage_id not in storage_meta_groups:
606
+ storage_meta_groups[storage_id] = StorageMetaGroup(storage_id=storage_id)
607
+
608
+ # Use add_sample_meta to store SampleMeta references directly
609
+ storage_meta_groups[storage_id].add_sample_meta(sample, local_index)
610
+
611
+ return storage_meta_groups
transfer_queue/storage/managers/yuanrong_manager.py
@@ -0,0 +1,18 @@
+ from typing import Any
+
+ from transfer_queue.storage.managers.base import KVStorageManager
+
+
+ class YuanrongStorageManager(KVStorageManager):
+     def __init__(self, config: dict[str, Any]):
+         host = config.get("host", None)
+         port = config.get("port", None)
+         device_id = config.get("device_id", None)
+         if host is None or not isinstance(host, str):
+             raise ValueError("Missing or invalid 'host' in config")
+         if port is None or not isinstance(port, int):
+             raise ValueError("Missing or invalid 'port' in config")
+         # TODO: device_id may be a list[int]
+         if device_id is None or not isinstance(device_id, int):
+             raise ValueError("Missing or invalid 'device_id' in config")
+         super().__init__(config)
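+
+ # Example config accepted by YuanrongStorageManager (illustrative values only):
+ #     {"host": "127.0.0.1", "port": 31501, "device_id": 0}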