TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0

transfer_queue/storage/simple_backend.py
@@ -0,0 +1,451 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import logging
import os
from dataclasses import dataclass
from operator import itemgetter
from threading import Thread
from typing import Any, Optional
from uuid import uuid4

import ray
import torch
import zmq
from ray.util import get_node_ip_address
from tensordict import NonTensorStack, TensorDict

from transfer_queue.metadata import SampleMeta
from transfer_queue.utils.utils import TransferQueueRole
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket, get_free_port

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# ZMQ timeouts (in seconds) and retry configurations
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5))


class StorageUnitData:
    """Storage unit for managing a 2D data structure (samples × fields).

    This class provides efficient storage and retrieval of data in a 2D matrix format
    where rows represent samples (indexed by local_index) and columns represent fields.
    Each field contains a list of data items indexed by their local position.

    Data Structure Example:
        ┌─────────────┬─────────────┬─────────────┬─────────┐
        │ local_index │ field_name1 │ field_name2 │   ...   │
        ├─────────────┼─────────────┼─────────────┼─────────┤
        │      0      │    item1    │    item2    │   ...   │
        │      1      │    item3    │    item4    │   ...   │
        │      2      │    item5    │    item6    │   ...   │
        └─────────────┴─────────────┴─────────────┴─────────┘
    """

    def __init__(self, storage_size: int):
        # Dict containing field names and the corresponding data in each field
        # Format: {"field_name": [data_at_index_0, data_at_index_1, ...]}
        self.field_data: dict[str, list] = {}

        # Maximum number of elements stored in the storage unit
        self.storage_size = storage_size

    def get_data(self, fields: list[str], local_indexes: list[int]) -> TensorDict[str, list]:
        """
        Get data from the storage unit according to the given fields and local_indexes.

        Args:
            fields: Field names used for getting data.
            local_indexes: Local indexes used for getting data.

        Returns:
            TensorDict with field names as keys and the corresponding gathered data as values.
        """
        result: dict[str, list] = {}

        for field in fields:
            # Validate field name
            if field not in self.field_data:
                raise ValueError(
                    f"StorageUnitData get_data operation received invalid field: {field} beyond {self.field_data.keys()}"
                )

            if len(local_indexes) == 1:
                # The unsqueeze op changes the shape from (n,) to (1, n)
                gathered_item = self.field_data[field][local_indexes[0]]
                if not isinstance(gathered_item, torch.Tensor):
                    result[field] = NonTensorStack(gathered_item)
                else:
                    result[field] = gathered_item.unsqueeze(0)
            else:
                gathered_items = list(itemgetter(*local_indexes)(self.field_data[field]))

                if gathered_items:
                    all_tensors = all(isinstance(x, torch.Tensor) for x in gathered_items)
                    if all_tensors:
                        result[field] = torch.nested.as_nested_tensor(gathered_items)
                    else:
                        result[field] = NonTensorStack(*gathered_items)

        # Explicit batch size for stability
        batch_size = 0 if not fields or not local_indexes else len(local_indexes)
        return TensorDict(result, batch_size=batch_size)
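
    # Usage sketch (editor's illustration in the style of this module's
    # docstring examples; not part of the published file). Single-index reads
    # are unsqueezed to a leading batch dim, ragged tensor reads are packed as
    # a nested tensor, and non-tensor items fall back to NonTensorStack:
    #
    #     >>> unit = StorageUnitData(storage_size=4)
    #     >>> unit.field_data["obs"] = [torch.zeros(2), torch.ones(3), "a", "b"]
    #     >>> unit.get_data(["obs"], [0])      # "obs" has shape (1, 2) via unsqueeze(0)
    #     >>> unit.get_data(["obs"], [0, 1])   # nested tensor holding shapes (2,) and (3,)
    #     >>> unit.get_data(["obs"], [2, 3])   # NonTensorStack("a", "b")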

    def put_data(self, field_data: TensorDict[str, Any], local_indexes: list[int]) -> None:
        """
        Put or update data in the storage unit according to the given field_data and local_indexes.

        Args:
            field_data: Dict with field names as keys and the corresponding field data as values.
            local_indexes: Local indexes used for putting data.
        """
        extracted_data = field_data.to_dict()

        for f, values in extracted_data.items():
            if f not in self.field_data:
                self.field_data[f] = [None] * self.storage_size

            for i, idx in enumerate(local_indexes):
                if idx < 0 or idx >= self.storage_size:
                    raise ValueError(
                        f"StorageUnitData put_data operation received invalid local_index: {idx} beyond "
                        f"storage_size: {self.storage_size}"
                    )

                self.field_data[f][idx] = values[i]

    def clear(self, local_indexes: list[int]) -> None:
        """
        Clear data at the specified local_indexes by setting all related fields to None.

        Args:
            local_indexes: Local indexes to clear.
        """
        # Validate local_indexes
        for idx in local_indexes:
            if idx < 0 or idx >= self.storage_size:
                raise ValueError(
                    f"StorageUnitData clear operation received invalid local_index: {idx} beyond "
                    f"storage_size: {self.storage_size}"
                )

        # Clear data at the specified local_indexes
        for f in self.field_data:
            for idx in local_indexes:
                self.field_data[f][idx] = None
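

def _example_storage_unit_data_roundtrip() -> None:
    """Editor's sketch of the put/get/clear cycle (illustrative only, not part
    of the published wheel; it exercises the StorageUnitData class above)."""
    unit = StorageUnitData(storage_size=8)

    # put_data() consumes a TensorDict keyed by field, one value per local index.
    batch = TensorDict({"obs": torch.arange(6).reshape(3, 2)}, batch_size=3)
    unit.put_data(batch, local_indexes=[0, 1, 2])

    # get_data() gathers by (fields, local_indexes); rows 0 and 2 come back
    # as a batch of two.
    out = unit.get_data(["obs"], [0, 2])
    logger.debug(f"gathered batch: {out}")

    # clear() resets the slots to None but keeps the field columns allocated.
    unit.clear([0, 1, 2])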


@ray.remote(num_cpus=1)
class SimpleStorageUnit:
    """A storage unit that provides distributed data storage functionality.

    This class represents a storage unit that stores data in a 2D structure
    (samples × data fields) and provides ZMQ-based communication for put/get/clear operations.

    Note: We use the Ray decorator (@ray.remote) only for initialization purposes.
    We do NOT use Ray's .remote() call capabilities - the storage unit runs
    as a standalone process with its own ZMQ server socket.

    Attributes:
        storage_unit_id: Unique identifier for this storage unit.
        storage_unit_size: Maximum number of elements that can be stored.
        storage_data: Internal StorageUnitData instance for data management.
        zmq_server_info: ZMQ connection information for clients.
    """

    def __init__(self, storage_unit_size: int):
        """Initialize a SimpleStorageUnit with the specified size.

        Args:
            storage_unit_size: Maximum number of elements that can be stored in this storage unit.
        """
        self.storage_unit_id = f"TQ_STORAGE_UNIT_{uuid4().hex[:8]}"
        self.storage_unit_size = storage_unit_size

        self.storage_data = StorageUnitData(self.storage_unit_size)

        self.zmq_server_info = ZMQServerInfo(
            role=TransferQueueRole.STORAGE,
            id=str(self.storage_unit_id),
            ip=get_node_ip_address(),
            ports={"put_get_socket": get_free_port()},
        )
        self._init_zmq_socket()
        self._start_process_put_get()

    def _init_zmq_socket(self) -> None:
        """
        Initialize ZMQ socket connections between the storage unit and controller/clients:
        - put_get_socket:
            Handles put/get requests from clients.
        """
        self.zmq_context = zmq.Context()

        self.put_get_socket = create_zmq_socket(self.zmq_context, zmq.ROUTER)
        self.put_get_socket.bind(self.zmq_server_info.to_addr("put_get_socket"))

    def _start_process_put_get(self) -> None:
        """Create a daemon thread and start the put/get processing loop."""
        self.process_put_get_thread = Thread(
            target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.zmq_server_info.id}", daemon=True
        )
        self.process_put_get_thread.start()

    def _process_put_get(self) -> None:
        """Process put_get_socket requests."""
        poller = zmq.Poller()
        poller.register(self.put_get_socket, zmq.POLLIN)

        while True:
            socks = dict(poller.poll(TQ_STORAGE_POLLER_TIMEOUT * 1000))

            if self.put_get_socket in socks:
                identity, serialized_msg = self.put_get_socket.recv_multipart()

                try:
                    request_msg = ZMQMessage.deserialize(serialized_msg)
                    operation = request_msg.request_type
                    logger.debug(f"[{self.zmq_server_info.id}]: received operation: {operation}, message: {request_msg}")

                    if operation == ZMQRequestType.PUT_DATA:
                        response_msg = self._handle_put(request_msg)
                    elif operation == ZMQRequestType.GET_DATA:
                        response_msg = self._handle_get(request_msg)
                    elif operation == ZMQRequestType.CLEAR_DATA:
                        response_msg = self._handle_clear(request_msg)
                    else:
                        response_msg = ZMQMessage.create(
                            request_type=ZMQRequestType.PUT_GET_OPERATION_ERROR,
                            sender_id=self.zmq_server_info.id,
                            body={
                                "message": f"Storage unit id #{self.zmq_server_info.id} "
                                f"received invalid operation: {operation}."
                            },
                        )
                except Exception as e:
                    response_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.PUT_GET_ERROR,
                        sender_id=self.zmq_server_info.id,
                        body={
                            "message": f"Storage unit id #{self.zmq_server_info.id} encountered an error while "
                            f"processing a put/get/clear request, detailed error message: {str(e)}."
                        },
                    )

                self.put_get_socket.send_multipart([identity, response_msg.serialize()])
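
    # Client-side view of this request/response loop (editor's sketch; the
    # DEALER socket and the body layout are assumptions that mirror the
    # handlers below, written in the style of this module's docstring examples):
    #
    #     >>> ctx = zmq.Context()
    #     >>> sock = ctx.socket(zmq.DEALER)
    #     >>> sock.connect(server_info.to_addr("put_get_socket"))  # from get_zmq_server_info()
    #     >>> msg = ZMQMessage.create(
    #     ...     request_type=ZMQRequestType.GET_DATA,
    #     ...     sender_id="client_0",
    #     ...     body={"fields": ["obs"], "local_indexes": [0, 1]},
    #     ... )
    #     >>> sock.send(msg.serialize())
    #     >>> reply = ZMQMessage.deserialize(sock.recv())  # GET_DATA_RESPONSE
    #     >>> batch = reply.body["data"]                   # TensorDict of gathered fields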

    def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle a put request: add or update data in the storage unit.

        Args:
            data_parts: ZMQMessage from the client.

        Returns:
            Put-data success response ZMQMessage.
        """
        try:
            local_indexes = data_parts.body["local_indexes"]
            field_data = data_parts.body["data"]  # field_data should be a TensorDict.

            self.storage_data.put_data(field_data, local_indexes)

            # After the put operation finishes, send a message to the client
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.PUT_DATA_RESPONSE, sender_id=self.zmq_server_info.id, body={}
            )

            return response_msg
        except Exception as e:
            return ZMQMessage.create(
                request_type=ZMQRequestType.PUT_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to put data into storage unit id "
                    f"#{self.zmq_server_info.id}, detailed error message: {str(e)}"
                },
            )

    def _handle_get(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle a get request: return data from the storage unit.

        Args:
            data_parts: ZMQMessage from the client.

        Returns:
            Get-data success response ZMQMessage containing the target data.
        """
        try:
            fields = data_parts.body["fields"]
            local_indexes = data_parts.body["local_indexes"]

            result_data = self.storage_data.get_data(fields, local_indexes)

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_DATA_RESPONSE,
                sender_id=self.zmq_server_info.id,
                body={
                    "data": result_data,
                },
            )
        except Exception as e:
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to get data from storage unit id #{self.zmq_server_info.id}, "
                    f"detailed error message: {str(e)}"
                },
            )
        return response_msg

    def _handle_clear(self, data_parts: ZMQMessage) -> ZMQMessage:
        """
        Handle a clear request: clear data in the storage unit at the given local_indexes.

        Args:
            data_parts: ZMQMessage from the client, including the target local_indexes.

        Returns:
            Clear-data success response ZMQMessage.
        """
        try:
            local_indexes = data_parts.body["local_indexes"]

            self.storage_data.clear(local_indexes)

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA_RESPONSE,
                sender_id=self.zmq_server_info.id,
                body={"message": f"Cleared data in storage unit id #{self.zmq_server_info.id} successfully."},
            )
        except Exception as e:
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_DATA_ERROR,
                sender_id=self.zmq_server_info.id,
                body={
                    "message": f"Failed to clear data in storage unit id #{self.zmq_server_info.id}, "
                    f"detailed error message: {str(e)}"
                },
            )
        return response_msg

    def get_zmq_server_info(self) -> ZMQServerInfo:
        """Get the ZMQ server information for this storage unit.

        Returns:
            ZMQServerInfo containing connection details for this storage unit.
        """
        return self.zmq_server_info
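

def _example_launch_storage_unit() -> ZMQServerInfo:
    """Editor's sketch: spawn a unit via Ray, then address it over ZMQ
    (illustrative only; per the class docstring above, Ray is used just to
    start the process, and data traffic goes through the ZMQ endpoint)."""
    ray.init(ignore_reinit_error=True)
    unit = SimpleStorageUnit.remote(storage_unit_size=1024)

    # The only Ray call clients need: fetch the unit's ZMQ connection info.
    server_info = ray.get(unit.get_zmq_server_info.remote())
    return server_info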


@dataclass
class StorageMetaGroup:
    """
    Represents a group of samples stored in the same storage unit.
    Used to organize samples by their storage_id for efficient client operations.
    """

    storage_id: str
    sample_metas: list[SampleMeta] = dataclasses.field(default_factory=list)
    local_indexes: list[int] = dataclasses.field(default_factory=list)

    def add_sample_meta(self, sample_meta: SampleMeta, local_index: int) -> None:
        """Add a SampleMeta object to this storage group"""
        self.sample_metas.append(sample_meta)
        self.local_indexes.append(local_index)

    def get_batch_indexes(self) -> list[int]:
        """Get all batch indexes from the stored SampleMeta objects"""
        return [meta.batch_index for meta in self.sample_metas]

    def get_global_indexes(self) -> list[int]:
        """Get all global indexes from the stored SampleMeta objects"""
        return [meta.global_index for meta in self.sample_metas]

    def get_local_indexes(self) -> list[int]:
        """Get all local indexes for the stored SampleMeta objects"""
        return self.local_indexes

    def get_field_names(self) -> list[str]:
        """Get all unique field names from the stored SampleMeta objects"""
        all_fields: set[str] = set()
        for meta in self.sample_metas:
            all_fields.update(meta.fields.keys())
        return list(all_fields)

    def get_transfer_data(self, field_names: Optional[list[str]] = None) -> dict[str, list | dict]:
        """Convert metadata to the transfer dictionary format.

        Creates a transfer_dict structure containing indexing and field information
        but without the actual field data. The field_data placeholder will be
        populated by the _add_field_data() function.

        Args:
            field_names: Optional list of field names to include. If None, includes all fields.

        Returns:
            Transfer dictionary with metadata structure:
            {
                "batch_indexes": [batch_idx1, batch_idx2, ...],
                "global_indexes": [global_idx1, global_idx2, ...],
                "local_indexes": [local_idx1, local_idx2, ...],
                "fields": ["field1", "field2", ...],
                "field_data": {}  # Placeholder - actual data added by _add_field_data()
            }

        Example:
            >>> group = StorageMetaGroup("storage1")
            >>> # Add multiple samples with different batch/global indexes and storage locations
            >>> group.add_sample_meta(SampleMeta(batch_index=0, global_index=10, fields={"img": ...}), 4)
            >>> group.add_sample_meta(SampleMeta(batch_index=1, global_index=11, fields={"img": ...}), 5)
            >>> group.add_sample_meta(SampleMeta(batch_index=2, global_index=12, fields={"img": ...}), 6)
            >>> transfer_dict = group.get_transfer_data(["img"])
            >>> transfer_dict["local_indexes"]   # [4, 5, 6] - storage locations
            >>> transfer_dict["batch_indexes"]   # [0, 1, 2] - original data locations
            >>> transfer_dict["global_indexes"]  # [10, 11, 12] - global identifiers
        """
        if field_names is None:
            field_names = self.get_field_names()
        return {
            "batch_indexes": self.get_batch_indexes(),
            "global_indexes": self.get_global_indexes(),
            "local_indexes": self.get_local_indexes(),
            "fields": field_names,
            "field_data": {},  # Placeholder for field data to be filled later
        }

    @property
    def size(self) -> int:
        """Number of samples in this storage meta group"""
        return len(self.sample_metas)

    @property
    def is_empty(self) -> bool:
        """Check whether this storage meta group is empty"""
        return len(self.sample_metas) == 0

    def __len__(self) -> int:
        """Number of samples in this storage meta group"""
        return self.size

    def __bool__(self) -> bool:
        """Truthiness based on whether the group has samples"""
        return not self.is_empty

    def __str__(self) -> str:
        return f"StorageMetaGroup(storage_id='{self.storage_id}', size={self.size})"

transfer_queue/utils/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.