TransferQueue 0.0.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +307 -0
- recipe/simple_use_case/sync_demo.py +223 -0
- tests/test_client.py +390 -0
- tests/test_controller.py +268 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +479 -0
- transfer_queue/__init__.py +42 -0
- transfer_queue/client.py +663 -0
- transfer_queue/controller.py +772 -0
- transfer_queue/metadata.py +603 -0
- transfer_queue/storage.py +515 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +98 -0
- transfer_queue/utils/zmq_utils.py +175 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.0.1.dev0.dist-info/METADATA +15 -0
- transferqueue-0.0.1.dev0.dist-info/RECORD +21 -0
- transferqueue-0.0.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.0.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.0.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,603 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import dataclasses
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import Any, Optional
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
from tensordict import TensorDict
|
|
21
|
+
|
|
22
|
+
from transfer_queue.utils.utils import ProductionStatus
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class FieldMeta:
    """
    Metadata describing one named data field of a sample: name, dtype, shape and production status.
    """

    # Name of the field, e.g. 'prompt' or 'response'.
    name: str

    # Schema information, taken from the stored value when it exposes these attributes.
    dtype: Optional[Any]  # e.g. torch.float32 or numpy.float32; None when unavailable
    shape: Optional[Any]  # e.g. torch.Size([3, 224, 224]) or (3, 224, 224); None when unavailable

    # Production state of this field; defaults to NOT_PRODUCED.
    production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED

    def __str__(self) -> str:
        head = f"FieldMeta(name='{self.name}', dtype={self.dtype}, "
        tail = f"shape={self.shape}, production_status={self.production_status})"
        return head + tail

    @property
    def is_ready(self) -> bool:
        """Return True when this field's status equals READY_FOR_CONSUME."""
        ready_status = ProductionStatus.READY_FOR_CONSUME
        return self.production_status == ready_status
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class SampleMeta:
    """
    Records the metadata of a single data sample (stored as a row in the data system).

    The sample-level readiness flag is cached in ``_is_ready``. It is computed in
    ``__post_init__`` and recomputed by ``add_fields`` and ``union``. Mutating a
    FieldMeta inside ``fields`` directly does NOT refresh the cached flag.
    """

    # algorithm related info
    global_step: int  # global step, used for data versioning

    # data retrieval info
    global_index: int  # global row index, uniquely identifies a data sample
    storage_id: str  # storage unit id
    local_index: int  # local row index in the storage unit

    # data fields info
    # these fields may not contain all the fields of the sample, only the fields-of-interest
    fields: dict[str, FieldMeta]

    def __post_init__(self):
        """Compute the cached readiness flag from the initial fields."""
        self._refresh_ready()

    def _refresh_ready(self) -> None:
        """Recompute ``_is_ready``: True iff every field is ready (vacuously True with no fields)."""
        object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values()))

    def __str__(self) -> str:
        return (
            f"SampleMeta(global_step={self.global_step}, "
            f"global_index={self.global_index}, storage_id='{self.storage_id}', "
            f"local_index={self.local_index}, fields={self.fields})"
        )

    @property
    def field_names(self) -> list[str]:
        """Get list of field names for this sample, in insertion order."""
        return list(self.fields.keys())

    @property
    def batch_index(self) -> int:
        """Position of this sample in the BatchMeta that last claimed it; -1 if never set."""
        return getattr(self, "_batch_index", -1)

    def get_field_by_name(self, name: str) -> Optional[FieldMeta]:
        """Get FieldMeta by field name, or None if the field is absent."""
        return self.fields.get(name)

    def has_field(self, name: str) -> bool:
        """Check if this sample has a specific field"""
        return name in self.fields

    def is_field_ready(self, field_name: str) -> bool:
        """Check if a specific field is ready for consumption; False if the field is absent."""
        field = self.fields.get(field_name)
        return field.is_ready if field is not None else False

    def add_fields(self, fields: dict[str, FieldMeta]) -> "SampleMeta":
        """
        Merge ``fields`` into this sample in place.

        The given FieldMeta objects are stored as-is (no defaults are filled in): their
        dtype, shape and production_status come from the caller. When a name already
        exists, the incoming FieldMeta replaces the existing one.

        Args:
            fields: Mapping from field name to FieldMeta.

        Returns:
            self, to allow chaining.
        """
        self.fields = _union_fields(self.fields, fields)
        self._refresh_ready()
        return self

    def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta":
        """
        Merge another sample's fields into this sample in place.

        Both samples are expected to describe the same global index. Where field names
        overlap, the FieldMeta from ``other`` replaces the one in this sample.

        Args:
            other: Another SampleMeta whose fields are merged into this one.
            validate: When True, check that both samples share the same global index.

        Returns:
            self (modified in place), not a copy.

        Raises:
            ValueError: If ``validate`` is True and the global indexes differ.
        """
        if validate and self.global_index != other.global_index:
            raise ValueError(
                f"Error: Global indexes ({self.global_index} and {other.global_index}) do not match for union."
            )

        self.fields = _union_fields(self.fields, other.fields)
        self._refresh_ready()
        return self

    @property
    def is_ready(self) -> bool:
        """Cached flag: True if all fields were ready at the last refresh."""
        return getattr(self, "_is_ready", False)

    @property
    def production_status(self) -> dict[str, ProductionStatus]:
        """Get production status for all fields, keyed by field name (backward compatibility)."""
        return {name: field.production_status for name, field in self.fields.items()}
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
@dataclass
class StorageMetaGroup:
    """
    A group of samples that live in the same storage unit.

    Used to organize samples by their storage_id for efficient client operations.
    The group holds references to the SampleMeta objects (not copies).
    """

    storage_id: str
    sample_metas: list[SampleMeta] = dataclasses.field(default_factory=list)

    def add_sample_meta(self, sample_meta: SampleMeta) -> None:
        """Append a SampleMeta reference to this storage group."""
        self.sample_metas.append(sample_meta)

    def get_batch_indexes(self) -> list[int]:
        """Get each stored sample's position in its batch (``SampleMeta.batch_index``)."""
        return [meta.batch_index for meta in self.sample_metas]

    def get_global_indexes(self) -> list[int]:
        """Get all global indexes from stored SampleMeta objects"""
        return [meta.global_index for meta in self.sample_metas]

    def get_local_indexes(self) -> list[int]:
        """Get all local indexes from stored SampleMeta objects"""
        return [meta.local_index for meta in self.sample_metas]

    def get_field_names(self) -> list[str]:
        """
        Get all unique field names across the stored SampleMeta objects.

        The names are returned sorted so the order is deterministic across processes
        (set iteration order of strings depends on hash randomization) and matches
        the ordering used by ``BatchMeta.field_names``.
        """
        all_fields: set[str] = set()
        for meta in self.sample_metas:
            all_fields.update(meta.fields.keys())
        return sorted(all_fields)

    def get_transfer_info(self, field_names: Optional[list[str]] = None) -> dict[str, list | dict]:
        """
        Build the dictionary describing this group for a data transfer (backward-compatible format).

        Args:
            field_names: Fields to transfer; defaults to every field found in the group.

        Returns:
            Dict with keys ``batch_indexes``, ``global_indexes``, ``local_indexes``,
            ``fields`` and an empty ``field_data`` placeholder to be filled later.
        """
        if field_names is None:
            field_names = self.get_field_names()
        return {
            "batch_indexes": self.get_batch_indexes(),
            "global_indexes": self.get_global_indexes(),
            "local_indexes": self.get_local_indexes(),
            "fields": field_names,
            "field_data": {},  # Placeholder for field data to be filled later
        }

    @property
    def size(self) -> int:
        """Number of samples in this storage meta group"""
        return len(self.sample_metas)

    @property
    def is_empty(self) -> bool:
        """Check if this storage meta group is empty"""
        return len(self.sample_metas) == 0

    def __len__(self) -> int:
        """Number of samples in this storage meta group"""
        return self.size

    def __bool__(self) -> bool:
        """Truthiness based on whether group has samples"""
        return not self.is_empty

    def __str__(self) -> str:
        return f"StorageMetaGroup(storage_id='{self.storage_id}', size={self.size})"
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@dataclass
class BatchMeta:
    """
    Records the metadata of a batch of data samples.

    Derived values (size, index lists, field names, readiness, storage groups) are
    computed in ``__post_init__`` and cached on the instance. ``add_fields``, ``union``
    and ``reorder`` refresh them; direct mutation of ``samples`` does not.

    NOTE(review): SampleMeta objects are shared, not copied. Constructing a BatchMeta
    (including through ``__getitem__``, ``chunk``, ``concat`` and ``union``) overwrites
    each sample's ``_batch_index`` with its position in the new batch. As a result, a
    parent batch's ``batch_index`` values are stale after sub-batches are built from it.
    """

    samples: list[SampleMeta]
    extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        """Compute and cache all derived properties of the batch."""
        object.__setattr__(self, "_size", len(self.samples))
        self._refresh_index_cache()
        self._refresh_field_cache()
        # Group samples by storage unit for efficient client operations ({} when empty).
        object.__setattr__(self, "_storage_meta_groups", self._build_storage_meta_groups())

    def _refresh_index_cache(self) -> None:
        """Stamp each sample with its batch position and recompute the cached per-sample index lists."""
        for idx, sample in enumerate(self.samples):
            object.__setattr__(sample, "_batch_index", idx)
        object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])
        object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples])
        object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples])

    def _refresh_field_cache(self) -> None:
        """Recompute the cached field names and readiness flag from the current samples."""
        # Assumes all samples have the same fields, so the first sample is representative.
        field_names = sorted(self.samples[0].field_names) if self.samples else []
        object.__setattr__(self, "_field_names", field_names)
        object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))

    @property
    def size(self) -> int:
        """Return the number of samples in this batch"""
        return getattr(self, "_size", 0)

    @property
    def global_indexes(self) -> list[int]:
        """Get all global indexes in this batch"""
        return getattr(self, "_global_indexes", [])

    @property
    def field_names(self) -> list[str]:
        """Get the sorted field names of this batch (taken from the first sample)."""
        return getattr(self, "_field_names", [])

    @property
    def local_indexes(self) -> list[int]:
        """Get all local indexes in this batch"""
        return getattr(self, "_local_indexes", [])

    @property
    def storage_ids(self) -> list[str]:
        """Get all storage unit IDs in this batch"""
        return getattr(self, "_storage_ids", [])

    @property
    def is_ready(self) -> bool:
        """Check if all samples in this batch are ready for consumption"""
        # TODO: get ready status from controller realtime
        return getattr(self, "_is_ready", False)

    def _build_storage_meta_groups(self) -> dict[str, StorageMetaGroup]:
        """Build storage groups from samples during initialization"""
        storage_meta_groups: dict[str, StorageMetaGroup] = {}

        for sample in self.samples:
            storage_id = sample.storage_id
            if storage_id not in storage_meta_groups:
                storage_meta_groups[storage_id] = StorageMetaGroup(storage_id=storage_id)

            # Use add_sample_meta to store SampleMeta references directly
            storage_meta_groups[storage_id].add_sample_meta(sample)

        return storage_meta_groups

    @property
    def storage_meta_groups(self) -> dict[str, StorageMetaGroup]:
        """Get storage groups organized by storage_id"""
        return getattr(self, "_storage_meta_groups", {})

    @property
    def storage_unit_ids(self) -> list[str]:
        """Get list of all storage unit IDs"""
        return list(self.storage_meta_groups.keys())

    def get_storage_meta_groups(self, storage_id: str) -> Optional[StorageMetaGroup]:
        """Get the storage group for ``storage_id``, or None if no sample lives there."""
        return self.storage_meta_groups.get(storage_id)

    # Extra info interface methods
    def get_extra_info(self, key: str, default: Any = None) -> Any:
        """Get extra info by key"""
        return self.extra_info.get(key, default)

    def set_extra_info(self, key: str, value: Any) -> None:
        """Set extra info by key"""
        self.extra_info[key] = value

    def update_extra_info(self, info_dict: dict[str, Any]) -> None:
        """Update extra info with multiple key-value pairs"""
        self.extra_info.update(info_dict)

    def remove_extra_info(self, key: str) -> Any:
        """Remove extra info by key and return its value (None if absent)."""
        return self.extra_info.pop(key, None)

    def clear_extra_info(self) -> None:
        """Clear all extra info"""
        self.extra_info.clear()

    def has_extra_info(self, key: str) -> bool:
        """Check if extra info contains a specific key"""
        return key in self.extra_info

    def get_all_extra_info(self) -> dict[str, Any]:
        """Get a shallow copy of all extra info."""
        return self.extra_info.copy()

    def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta":
        """
        Add new fields from a TensorDict to all samples in this batch.

        This modifies each sample in place to include the new fields. Row ``i`` of the
        TensorDict is assigned to ``self.samples[i]``.

        Args:
            tensor_dict (TensorDict): The input TensorDict containing new fields.
            set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True.

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: If the TensorDict's batch size differs from the number of samples.
        """
        fields = _extract_field_metas(tensor_dict, set_all_ready)
        if fields:
            # A size mismatch would otherwise raise IndexError or silently drop trailing rows.
            if len(fields) != len(self.samples):
                raise ValueError(
                    f"Error: TensorDict batch size ({len(fields)}) does not match "
                    f"the number of samples ({len(self.samples)})."
                )
            for sample, sample_fields in zip(self.samples, fields):
                sample.add_fields(fields=sample_fields)

        # Update batch-level fields cache
        self._refresh_field_cache()
        return self

    def __len__(self) -> int:
        """Return the number of samples in this batch."""
        return len(self.samples)

    def __getitem__(self, item) -> "BatchMeta":
        """
        Select samples by position and return them as a new BatchMeta.

        Args:
            item: An integer position (Python int or NumPy integer; negatives allowed) or a slice.

        Returns:
            A BatchMeta holding the selected sample(s) and a shallow copy of ``extra_info``.

        Raises:
            IndexError: If an integer position is out of range (including on an empty batch).
            TypeError: For any other index type.
        """
        if isinstance(item, int | np.integer):
            selected = [self.samples[item]]
        elif isinstance(item, slice):
            selected = self.samples[item]
        else:
            raise TypeError(f"Indexing with {type(item)} is not supported now!")
        # Copy extra_info so the sub-batch cannot mutate the parent's dict (matches chunk()).
        return BatchMeta(samples=selected, extra_info=self.extra_info.copy())

    def chunk(self, chunks: int) -> list["BatchMeta"]:
        """
        Split this batch into ``chunks`` contiguous sub-batches.

        Chunk sizes differ by at most one: the first ``len(self) % chunks`` chunks hold
        one extra sample. If ``chunks`` exceeds the number of samples, the trailing
        chunks are empty. Each chunk gets a shallow copy of ``extra_info``.

        Args:
            chunks: Number of chunks; must be a positive integer.

        Returns:
            List of smaller BatchMeta chunks.

        Raises:
            ValueError: If ``chunks`` is less than 1.
        """
        if chunks <= 0:
            raise ValueError(f"Error: chunks must be a positive integer, got {chunks}.")

        base_size, remainder = divmod(len(self.samples), chunks)

        chunk_list = []
        start = 0
        for i in range(chunks):
            end = start + base_size + (1 if i < remainder else 0)
            chunk_list.append(BatchMeta(samples=self.samples[start:end], extra_info=self.extra_info.copy()))
            start = end
        return chunk_list

    @classmethod
    def concat(cls, data: list["BatchMeta"], validate: bool = True) -> Optional["BatchMeta"]:
        """
        Concatenate multiple BatchMeta chunks into one large batch.

        Args:
            data: List of BatchMeta chunks to concatenate.
            validate: Whether to check that all non-empty chunks have the same field names.

        Returns:
            Concatenated BatchMeta, or None if ``data`` is empty. ``extra_info`` dicts are
            merged in order, so later chunks win on key conflicts.

        Raises:
            ValueError: If validation fails (e.g., field names do not match).
        """
        if not data:
            return None

        if validate:
            # Empty chunks report no field names, so they are excluded from the comparison.
            non_empty = [chunk for chunk in data if len(chunk) > 0]
            if non_empty:
                base_fields = non_empty[0].field_names
                for chunk in non_empty:
                    if chunk.field_names != base_fields:
                        raise ValueError("Error: Field names do not match for concatenation.")

        all_samples = [sample for chunk in data for sample in chunk.samples]

        merged_extra_info: dict[str, Any] = {}
        for chunk in data:
            merged_extra_info.update(chunk.extra_info)
        return cls(samples=all_samples, extra_info=merged_extra_info)

    def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]:
        """
        Merge another batch's fields into this batch's samples and return the result as a BatchMeta.

        Both batches are expected to hold the same global indices. Where fields overlap,
        the field from ``other`` wins. Samples of ``other`` whose global index is absent
        from this batch are ignored.

        NOTE: SampleMeta.union works in place, so this batch's samples are modified and
        shared with the returned BatchMeta. This batch's cached field names and readiness
        are refreshed to match.

        Args:
            other: Another BatchMeta to union with.
            validate: Whether to validate union conditions.

        Returns:
            BatchMeta over the merged samples, with merged ``extra_info`` (``other`` wins).

        Raises:
            ValueError: If validation fails (e.g., batch sizes or global indexes do not match).
        """
        if validate:
            if self.size != other.size:
                raise ValueError("Error: Batch sizes do not match for union.")

            if sorted(self.global_indexes) != sorted(other.global_indexes):
                raise ValueError("Error: Global indexes do not match for union.")

        # Map global_index -> SampleMeta in the other batch
        other_sample_map = {sample.global_index: sample for sample in other.samples}

        merged_samples = []
        for sample in self.samples:
            other_sample = other_sample_map.get(sample.global_index)
            if other_sample is not None:
                sample = sample.union(other_sample, validate=validate)
            merged_samples.append(sample)

        # The samples above were mutated in place; keep this batch's caches consistent with them.
        self._refresh_field_cache()

        merged_extra_info = {**self.extra_info, **other.extra_info}
        return BatchMeta(samples=merged_samples, extra_info=merged_extra_info)

    def reorder(self, indices: list[int]) -> None:
        """
        Reorder the SampleMeta in the BatchMeta in place according to ``indices``.

        Args:
            indices: A permutation of ``range(len(self))``. ``indices[k]`` is the current
                position of the sample that should end up at position ``k``.

        Raises:
            ValueError: If ``indices`` is not a permutation of ``range(len(self))``. A
                subset or duplicate would desynchronize the cached size and storage groups.
        """
        if sorted(int(i) for i in indices) != list(range(len(self.samples))):
            raise ValueError("Error: indices must be a permutation of range(len(batch)) for reorder.")

        reordered_samples = [self.samples[i] for i in indices]
        object.__setattr__(self, "samples", reordered_samples)

        # Update necessary attributes
        self._update_after_reorder()

    def _update_after_reorder(self) -> None:
        """Update related attributes specifically for the reorder operation"""
        # Re-stamp batch_index and rebuild the cached index lists.
        self._refresh_index_cache()

        # Note: No need to rebuild storage_meta_groups. The groups hold the same SampleMeta
        # objects, whose storage_id is unchanged, and get_batch_indexes() reads the
        # freshly stamped batch_index from them.

        # Note: _size, _field_names and _is_ready are unaffected by a permutation.

    @classmethod
    def from_samples(
        cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None
    ) -> "BatchMeta":
        """
        Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects.

        Args:
            samples: A single SampleMeta or a list of SampleMeta objects
            extra_info: Optional additional information to store with the batch

        Returns:
            BatchMeta instance containing the provided sample(s)

        Example:
            >>> sample_meta = SampleMeta(...)
            >>> batch_meta = BatchMeta.from_samples(sample_meta)

            >>> sample_metas = [sample1, sample2, sample3]
            >>> batch_meta = BatchMeta.from_samples(sample_metas, extra_info={"source": "training"})
        """
        if extra_info is None:
            extra_info = {}

        if isinstance(samples, SampleMeta):
            samples = [samples]

        return cls(samples=samples, extra_info=extra_info)

    @classmethod
    def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta":
        """
        Create an empty BatchMeta with no samples.

        Args:
            extra_info: Optional additional information to store with the batch

        Returns:
            Empty BatchMeta instance

        Example:
            >>> empty_batch = BatchMeta.empty()
        """
        if extra_info is None:
            extra_info = {}
        return cls(samples=[], extra_info=extra_info)
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]:
|
|
568
|
+
"""Union two sample's fields. If fields overlap, the fields in fields1 will be replaced by fields2."""
|
|
569
|
+
for name in fields2.keys():
|
|
570
|
+
fields1[name] = fields2[name]
|
|
571
|
+
return fields1
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> list[dict[str, FieldMeta]]:
    """
    Build one ``{field name: FieldMeta}`` mapping per row of a TensorDict.

    A value that lacks a ``dtype`` or ``shape`` attribute gets None for that entry.

    Args:
        tensor_dict (TensorDict): The input TensorDict; rows are taken along ``batch_size[0]``.
        set_all_ready (bool): If True, every FieldMeta is marked READY_FOR_CONSUME,
            otherwise NOT_PRODUCED. Default is True.

    Returns:
        list[dict[str, FieldMeta]]: One mapping per row, in row order.
    """
    status = ProductionStatus.READY_FOR_CONSUME if set_all_ready else ProductionStatus.NOT_PRODUCED

    def _describe(field_name, field_value):
        # getattr with a default mirrors the hasattr/else-None lookup.
        return FieldMeta(
            name=field_name,
            dtype=getattr(field_value, "dtype", None),
            shape=getattr(field_value, "shape", None),
            production_status=status,
        )

    row_count = tensor_dict.batch_size[0]
    return [
        {field_name: _describe(field_name, field_value) for field_name, field_value in tensor_dict[row].items()}
        for row in range(row_count)
    ]
|