TransferQueue 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,603 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+ from dataclasses import dataclass
17
+ from typing import Any, Optional
18
+
19
+ import numpy as np
20
+ from tensordict import TensorDict
21
+
22
+ from transfer_queue.utils.utils import ProductionStatus
23
+
24
+
25
@dataclass
class FieldMeta:
    """Metadata describing a single data field: name, schema, and status."""

    # Field name (e.g., 'prompt', 'response', etc.)
    name: str

    # Schema info, taken from the data when it exposes these attributes
    # (e.g., torch.float32 / torch.Size([3, 224, 224])); otherwise None.
    dtype: Optional[Any]
    shape: Optional[Any]

    # Production status of this field; new fields start out not produced.
    production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED

    def __str__(self) -> str:
        head = f"FieldMeta(name='{self.name}', dtype={self.dtype}, "
        tail = f"shape={self.shape}, production_status={self.production_status})"
        return head + tail

    @property
    def is_ready(self) -> bool:
        """Whether this field is ready for consumption."""
        return self.production_status == ProductionStatus.READY_FOR_CONSUME
51
+
52
+
53
@dataclass
class SampleMeta:
    """
    Records the metadata of a single data sample (stored as a row in the data system).
    """

    # Algorithm-related info
    global_step: int  # global step, used for data versioning

    # Data retrieval info
    global_index: int  # global row index, uniquely identifies a data sample
    storage_id: str  # storage unit id
    local_index: int  # local row index in the storage unit

    # Fields-of-interest for this sample; may be a subset of all stored fields.
    fields: dict[str, FieldMeta]

    def __post_init__(self):
        """Cache the aggregate readiness of all fields on construction."""
        self._refresh_ready()

    def _refresh_ready(self) -> None:
        # A sample is ready only when every tracked field is ready.
        ready = all(meta.is_ready for meta in self.fields.values())
        object.__setattr__(self, "_is_ready", ready)

    def __str__(self) -> str:
        parts = [
            f"SampleMeta(global_step={self.global_step}, ",
            f"global_index={self.global_index}, storage_id='{self.storage_id}', ",
            f"local_index={self.local_index}, fields={self.fields})",
        ]
        return "".join(parts)

    @property
    def field_names(self) -> list[str]:
        """Names of the fields tracked for this sample."""
        return list(self.fields)

    @property
    def batch_index(self) -> int:
        """Position of this sample within a batch (set by BatchMeta; -1 if unset)."""
        return getattr(self, "_batch_index", -1)

    def get_field_by_name(self, name: str) -> Optional[FieldMeta]:
        """Return the FieldMeta for ``name``, or None if absent."""
        return self.fields.get(name)

    def has_field(self, name: str) -> bool:
        """Whether this sample tracks a field called ``name``."""
        return name in self.fields

    def is_field_ready(self, field_name: str) -> bool:
        """Whether the given field exists and is ready for consumption."""
        field = self.fields.get(field_name)
        return field is not None and field.is_ready

    def add_fields(self, fields: dict[str, FieldMeta]) -> "SampleMeta":
        """
        Merge ``fields`` into this sample in-place and return self.

        Entries with names that already exist are overwritten by the new
        FieldMeta objects; the cached readiness flag is then recomputed.
        """
        self.fields = _union_fields(self.fields, fields)
        self._refresh_ready()
        return self

    def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta":
        """
        Merge another sample's fields into this one (in-place) and return self.

        Both samples are assumed to describe the same row; where field names
        overlap, the entries from ``other`` win.

        Args:
            other: Another SampleMeta to union with
            validate: Whether to validate union conditions

        Returns:
            This SampleMeta, carrying the merged field set

        Raises:
            ValueError: If validation is enabled and global indexes differ
        """
        if validate and self.global_index != other.global_index:
            raise ValueError(
                f"Error: Global indexes ({self.global_index} and {other.global_index}) do not match for union."
            )

        self.fields = _union_fields(self.fields, other.fields)
        self._refresh_ready()
        return self

    @property
    def is_ready(self) -> bool:
        """Whether every field in this sample is ready for consumption."""
        return getattr(self, "_is_ready", False)

    @property
    def production_status(self) -> dict[str, ProductionStatus]:
        """Per-field production status (kept for backward compatibility)."""
        return {name: meta.production_status for name, meta in self.fields.items()}
153
+
154
+
155
@dataclass
class StorageMetaGroup:
    """
    Represents a group of samples stored in the same storage unit.
    Used to organize samples by their storage_id for efficient client operations.
    """

    storage_id: str
    sample_metas: list[SampleMeta] = dataclasses.field(default_factory=list)

    def add_sample_meta(self, sample_meta: SampleMeta) -> None:
        """Add a SampleMeta object to this storage group"""
        self.sample_metas.append(sample_meta)

    def get_batch_indexes(self) -> list[int]:
        """Get all batch indexes from stored SampleMeta objects"""
        return [meta.batch_index for meta in self.sample_metas]

    def get_global_indexes(self) -> list[int]:
        """Get all global indexes from stored SampleMeta objects"""
        return [meta.global_index for meta in self.sample_metas]

    def get_local_indexes(self) -> list[int]:
        """Get all local indexes from stored SampleMeta objects"""
        return [meta.local_index for meta in self.sample_metas]

    def get_field_names(self) -> list[str]:
        """
        Get all unique field names from stored SampleMeta objects.

        Returned in sorted order so that the result is deterministic
        (set iteration order is not stable across runs) and consistent with
        the sorted field names cached by BatchMeta.
        """
        all_fields: set[str] = set()
        for meta in self.sample_metas:
            all_fields.update(meta.fields.keys())
        return sorted(all_fields)

    def get_transfer_info(self, field_names: Optional[list[str]] = None) -> dict[str, list | dict]:
        """
        Convert to dictionary format for backward compatibility.

        Args:
            field_names: Fields to report; defaults to all field names in
                this group when None.

        Returns:
            Dict with index lists, the field-name list, and an empty
            "field_data" placeholder to be filled by the client later.
        """
        if field_names is None:
            field_names = self.get_field_names()
        return {
            "batch_indexes": self.get_batch_indexes(),
            "global_indexes": self.get_global_indexes(),
            "local_indexes": self.get_local_indexes(),
            "fields": field_names,
            "field_data": {},  # Placeholder for field data to be filled later
        }

    @property
    def size(self) -> int:
        """Number of samples in this storage meta group"""
        return len(self.sample_metas)

    @property
    def is_empty(self) -> bool:
        """Check if this storage meta group is empty"""
        return len(self.sample_metas) == 0

    def __len__(self) -> int:
        """Number of samples in this storage meta group"""
        return self.size

    def __bool__(self) -> bool:
        """Truthiness based on whether group has samples"""
        return not self.is_empty

    def __str__(self) -> str:
        return f"StorageMetaGroup(storage_id='{self.storage_id}', size={self.size})"
220
+
221
+
222
@dataclass
class BatchMeta:
    """
    Records the metadata of a batch of data samples.
    """

    samples: list[SampleMeta]
    extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        """Initialize all computed properties during initialization"""
        # Basic properties
        object.__setattr__(self, "_size", len(self.samples))
        object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))

        # Pre-compute all list properties for better performance
        if self.samples:
            # Stamp each sample with its position inside this batch.
            for idx, sample in enumerate(self.samples):
                object.__setattr__(sample, "_batch_index", idx)

            object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])
            object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples])
            object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples])

            # NOTE: assumes all samples share the same field set; only the
            # first sample is inspected.
            object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))

            # Group samples by storage unit for efficient client operations.
            storage_meta_groups = self._build_storage_meta_groups()
            object.__setattr__(self, "_storage_meta_groups", storage_meta_groups)
        else:
            object.__setattr__(self, "_global_indexes", [])
            object.__setattr__(self, "_local_indexes", [])
            object.__setattr__(self, "_storage_ids", [])
            object.__setattr__(self, "_field_names", [])
            object.__setattr__(self, "_storage_meta_groups", {})

    @property
    def size(self) -> int:
        """Return the number of samples in this batch"""
        return getattr(self, "_size", 0)

    @property
    def global_indexes(self) -> list[int]:
        """Get all global indexes in this batch"""
        return getattr(self, "_global_indexes", [])

    @property
    def field_names(self) -> list[str]:
        """Get all unique field names in this batch (sorted)"""
        return getattr(self, "_field_names", [])

    @property
    def local_indexes(self) -> list[int]:
        """Get all local indexes in this batch"""
        return getattr(self, "_local_indexes", [])

    @property
    def storage_ids(self) -> list[str]:
        """Get all storage unit IDs in this batch (one per sample)"""
        return getattr(self, "_storage_ids", [])

    @property
    def is_ready(self) -> bool:
        """Check if all samples in this batch are ready for consumption"""
        # TODO: get ready status from controller realtime
        return getattr(self, "_is_ready", False)

    def _build_storage_meta_groups(self) -> dict[str, StorageMetaGroup]:
        """Build storage groups from samples during initialization"""
        storage_meta_groups: dict[str, StorageMetaGroup] = {}

        for sample in self.samples:
            storage_id = sample.storage_id
            if storage_id not in storage_meta_groups:
                storage_meta_groups[storage_id] = StorageMetaGroup(storage_id=storage_id)

            # Store SampleMeta references directly (no copies) so that later
            # in-place updates to samples are visible through the groups.
            storage_meta_groups[storage_id].add_sample_meta(sample)

        return storage_meta_groups

    @property
    def storage_meta_groups(self) -> dict[str, StorageMetaGroup]:
        """Get storage groups organized by storage_id"""
        return getattr(self, "_storage_meta_groups", {})

    @property
    def storage_unit_ids(self) -> list[str]:
        """Get list of all storage unit IDs"""
        return list(self.storage_meta_groups.keys())

    def get_storage_meta_groups(self, storage_id: str) -> Optional[StorageMetaGroup]:
        """Get storage group by storage ID (None if unknown)"""
        return self.storage_meta_groups.get(storage_id)

    # Extra info interface methods
    def get_extra_info(self, key: str, default: Any = None) -> Any:
        """Get extra info by key"""
        return self.extra_info.get(key, default)

    def set_extra_info(self, key: str, value: Any) -> None:
        """Set extra info by key"""
        self.extra_info[key] = value

    def update_extra_info(self, info_dict: dict[str, Any]) -> None:
        """Update extra info with multiple key-value pairs"""
        self.extra_info.update(info_dict)

    def remove_extra_info(self, key: str) -> Any:
        """Remove extra info by key and return its value (None if absent)"""
        return self.extra_info.pop(key, None)

    def clear_extra_info(self) -> None:
        """Clear all extra info"""
        self.extra_info.clear()

    def has_extra_info(self, key: str) -> bool:
        """Check if extra info contains a specific key"""
        return key in self.extra_info

    def get_all_extra_info(self) -> dict[str, Any]:
        """Get all extra info as a (shallow-copied) dictionary"""
        return self.extra_info.copy()

    def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta":
        """
        Add new fields from a TensorDict to all samples in this batch.
        This modifies each sample in-place to include the new fields.

        Args:
            tensor_dict (TensorDict): The input TensorDict containing new fields.
                Its batch size is assumed to match the number of samples in
                this batch — TODO confirm callers guarantee this.
            set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True.
        """
        fields = _extract_field_metas(tensor_dict, set_all_ready)
        if len(fields) > 0:
            for idx, sample in enumerate(self.samples):
                sample.add_fields(fields=fields[idx])

            # Update batch-level caches to reflect the new field set.
            object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
            object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))
        return self

    def __len__(self) -> int:
        """Return the number of samples in this batch."""
        return len(self.samples)

    def __getitem__(self, item):
        """
        Return a new single-sample BatchMeta for an integer index.

        Raises:
            IndexError: If the index is out of range (including any index
                into an empty batch).
            TypeError: If ``item`` is not an integer.
        """
        if isinstance(item, int | np.integer):
            # Fix: index directly so an out-of-range item raises IndexError.
            # The previous `self.samples[item] if self.samples else []` built
            # BatchMeta(samples=[[]]) for empty batches, which then crashed in
            # __post_init__ with an unrelated AttributeError.
            return BatchMeta(samples=[self.samples[item]], extra_info=self.extra_info)
        else:
            raise TypeError(f"Indexing with {type(item)} is not supported now!")

    def chunk(self, chunks: int) -> list["BatchMeta"]:
        """
        Split this batch into smaller chunks.

        Args:
            chunks: number of chunks (must be a positive integer)

        Return:
            List of smaller BatchMeta chunks

        Raises:
            ValueError: If ``chunks`` is not positive.
        """
        # Fix: fail with a clear error instead of ZeroDivisionError for 0
        # (or silently odd results for negative values).
        if chunks <= 0:
            raise ValueError(f"chunks must be a positive integer, got {chunks}")

        chunk_list = []
        n = len(self.samples)

        # Calculate the base size and remainder of each chunk
        base_size = n // chunks
        remainder = n % chunks

        start = 0
        for i in range(chunks):
            # The first `remainder` chunks get one extra sample each.
            current_chunk_size = base_size + 1 if i < remainder else base_size
            end = start + current_chunk_size
            chunk_samples = self.samples[start:end]
            chunk = BatchMeta(samples=chunk_samples, extra_info=self.extra_info.copy())
            chunk_list.append(chunk)
            start = end
        return chunk_list

    @classmethod
    def concat(cls, data: list["BatchMeta"], validate: bool = True) -> Optional["BatchMeta"]:
        """
        Concatenate multiple BatchMeta chunks into one large batch.

        Args:
            data: List of BatchMeta chunks to concatenate
            validate: Whether to validate concatenation conditions

        Returns:
            Concatenated BatchMeta (None if ``data`` is empty)

        Raises:
            ValueError: If validation fails (e.g., field names do not match)
        """
        if not data:
            return None

        if validate:
            base_fields = data[0].field_names

            for chunk in data:
                if chunk.field_names != base_fields:
                    raise ValueError("Error: Field names do not match for concatenation.")

        # Combine all samples
        all_samples = []
        for chunk in data:
            all_samples.extend(chunk.samples)
        # Merge all extra_info dictionaries from the chunks (later chunks win
        # on key collisions).
        merged_extra_info = {}
        for chunk in data:
            merged_extra_info.update(chunk.extra_info)
        return BatchMeta(samples=all_samples, extra_info=merged_extra_info)

    def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]:
        """
        Create a union of this batch's fields with another batch's fields.
        Assume both batches have the same global indices. If fields overlap, the
        fields in this batch will be replaced by the other batch's fields.

        Note: the underlying SampleMeta objects are merged in-place.

        Args:
            other: Another BatchMeta to union with
            validate: Whether to validate union conditions

        Returns:
            New BatchMeta with unioned fields

        Raises:
            ValueError: If validation fails (e.g., batch sizes or global indexes do not match)
        """
        if validate:
            if self.size != other.size:
                raise ValueError("Error: Batch sizes do not match for union.")

            self_global_indexes = sorted(self.global_indexes)
            other_global_indexes = sorted(other.global_indexes)
            if self_global_indexes != other_global_indexes:
                raise ValueError("Error: Global indexes do not match for union.")

        # Create a mapping from global_index to SampleMeta in the other batch
        other_sample_map = {sample.global_index: sample for sample in other.samples}

        # Merge samples
        merged_samples = []
        for sample in self.samples:
            if sample.global_index in other_sample_map:
                other_sample = other_sample_map[sample.global_index]
                merged_sample = sample.union(other_sample, validate=validate)
                merged_samples.append(merged_sample)
            else:
                merged_samples.append(sample)

        # Merge extra info dictionaries (other's entries win on collisions)
        merged_extra_info = {**self.extra_info, **other.extra_info}

        return BatchMeta(samples=merged_samples, extra_info=merged_extra_info)

    def reorder(self, indices: list[int]):
        """
        Reorder the SampleMeta in the BatchMeta according to the given indices.

        The operation is performed in-place, modifying the current BatchMeta's SampleMeta order.

        Args:
            indices : list[int]
                A list of integers specifying the new order of SampleMeta. Each integer
                represents the current index of the SampleMeta in the BatchMeta.
        """
        # Reorder the samples
        reordered_samples = [self.samples[i] for i in indices]
        object.__setattr__(self, "samples", reordered_samples)

        # Update necessary attributes
        self._update_after_reorder()

    def _update_after_reorder(self) -> None:
        """Update related attributes specifically for the reorder operation"""
        # Update batch_index for each sample
        for idx, sample in enumerate(self.samples):
            object.__setattr__(sample, "_batch_index", idx)

        # Update cached index lists
        object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])
        object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples])
        object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples])

        # Note: No need to rebuild storage_meta_groups as samples' storage_id remain unchanged
        # and their order does not affect the grouping.
        # Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder

    @classmethod
    def from_samples(
        cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None
    ) -> "BatchMeta":
        """
        Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects.

        Args:
            samples: A single SampleMeta or a list of SampleMeta objects
            extra_info: Optional additional information to store with the batch

        Returns:
            BatchMeta instance containing the provided sample(s)

        Example:
            >>> sample_meta = SampleMeta(...)
            >>> batch_meta = BatchMeta.from_samples(sample_meta)

            >>> sample_metas = [sample1, sample2, sample3]
            >>> batch_meta = BatchMeta.from_samples(sample_metas, extra_info={"source": "training"})
        """
        if extra_info is None:
            extra_info = {}

        if isinstance(samples, SampleMeta):
            samples = [samples]

        return cls(samples=samples, extra_info=extra_info)

    @classmethod
    def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta":
        """
        Create an empty BatchMeta with no samples.

        Args:
            extra_info: Optional additional information to store with the batch

        Returns:
            Empty BatchMeta instance

        Example:
            >>> empty_batch = BatchMeta.empty()
        """
        if extra_info is None:
            extra_info = {}
        return cls(samples=[], extra_info=extra_info)
565
+
566
+
567
def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]:
    """
    Merge ``fields2`` into ``fields1`` in-place and return ``fields1``.

    Where names overlap, the entries from ``fields2`` win.
    """
    fields1.update(fields2)
    return fields1
572
+
573
+
574
def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> list[dict[str, FieldMeta]]:
    """
    Extract per-sample field metadata from a TensorDict.

    For each sample in the TensorDict, a mapping of field name to FieldMeta is
    produced. Values without a ``dtype`` or ``shape`` attribute get None for
    the missing attribute.

    Args:
        tensor_dict (TensorDict): The input TensorDict.
        set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME.
            Otherwise, set to NOT_PRODUCED. Default is True.

    Returns:
        all_fields (list[dict[str, FieldMeta]]): A list of dictionaries containing field metadata.
    """
    # All fields share one status, so resolve it once outside the loop.
    status = ProductionStatus.READY_FOR_CONSUME if set_all_ready else ProductionStatus.NOT_PRODUCED

    all_fields = []
    for idx in range(tensor_dict.batch_size[0]):
        sample = tensor_dict[idx]
        field_metas = {
            name: FieldMeta(
                name=name,
                dtype=getattr(value, "dtype", None),
                shape=getattr(value, "shape", None),
                production_status=status,
            )
            for name, value in sample.items()
        }
        all_fields.append(field_metas)

    return all_fields