atdata 0.2.2b1-py3-none-any.whl → 0.2.3b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atdata/dataset.py CHANGED
@@ -13,18 +13,16 @@ The implementation handles automatic conversion between numpy arrays and bytes
  during serialization, enabling efficient storage of numerical data in WebDataset
  archives.

- Example:
- ::
-
- >>> @packable
- ... class ImageSample:
- ... image: NDArray
- ... label: str
- ...
- >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
- >>> for batch in ds.shuffled(batch_size=32):
- ... images = batch.image # Stacked numpy array (32, H, W, C)
- ... labels = batch.label # List of 32 strings
+ Examples:
+ >>> @packable
+ ... class ImageSample:
+ ... image: NDArray
+ ... label: str
+ ...
+ >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
+ >>> for batch in ds.shuffled(batch_size=32):
+ ... images = batch.image # Stacked numpy array (32, H, W, C)
+ ... labels = batch.label # List of 32 strings
  """

  ##
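Throughout this release, docstring examples move from reST ``Example:`` / ``::`` literal blocks to Google-style ``Examples:`` sections with ``>>>`` prompts; the same mechanical change repeats in most hunks below. A side effect is that standard doctest tooling can now discover the snippets. A minimal collection sketch; it only counts docstrings, since most snippets reference names (``ds``, ``packable``) that a real run would have to inject via ``extraglobs``:

```python
import doctest

import atdata.dataset as dataset_module

# DocTestFinder picks up any ">>>" prompts in docstrings, including the
# new "Examples:" sections; this checks collection, not execution.
tests = doctest.DocTestFinder().find(dataset_module)
print(f"{len(tests)} docstring(s) with runnable-looking examples")
```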
@@ -43,7 +41,7 @@ from dataclasses import (
  )
  from abc import ABC

- from ._sources import URLSource, S3Source
+ from ._sources import URLSource
  from ._protocols import DataSource

  from tqdm import tqdm
@@ -66,6 +64,7 @@ from typing import (
  TypeVar,
  TypeAlias,
  dataclass_transform,
+ overload,
  )
  from numpy.typing import NDArray

@@ -85,30 +84,31 @@ WDSRawSample: TypeAlias = Dict[str, Any]
  WDSRawBatch: TypeAlias = Dict[str, Any]

  SampleExportRow: TypeAlias = Dict[str, Any]
- SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
+ SampleExportMap: TypeAlias = Callable[["PackableSample"], SampleExportRow]


  ##
  # Main base classes

- DT = TypeVar( 'DT' )
+ DT = TypeVar("DT")


- def _make_packable( x ):
+ def _make_packable(x):
  """Convert numpy arrays to bytes; pass through other values unchanged."""
- if isinstance( x, np.ndarray ):
- return eh.array_to_bytes( x )
+ if isinstance(x, np.ndarray):
+ return eh.array_to_bytes(x)
  return x


- def _is_possibly_ndarray_type( t ):
+ def _is_possibly_ndarray_type(t):
  """Return True if type annotation is NDArray or Optional[NDArray]."""
  if t == NDArray:
  return True
- if isinstance( t, types.UnionType ):
- return any( x == NDArray for x in t.__args__ )
+ if isinstance(t, types.UnionType):
+ return any(x == NDArray for x in t.__args__)
  return False

+
  class DictSample:
  """Dynamic sample type providing dict-like access to raw msgpack data.

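``_is_possibly_ndarray_type`` (reformatted above; logic unchanged) recognizes PEP 604 unions by checking for ``types.UnionType`` and scanning ``__args__``. A self-contained illustration of that runtime mechanism, using builtin types in place of the ``NDArray`` alias:

```python
import types

# A union written with "|" between runtime types is a types.UnionType,
# and its members are exposed via __args__ -- exactly what the helper
# above walks when looking for NDArray.
u = bytes | None
print(isinstance(u, types.UnionType))  # True
print(u.__args__)                      # (<class 'bytes'>, <class 'NoneType'>)
```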
@@ -126,24 +126,22 @@ class DictSample:
  ``@packable``-decorated class. Every ``@packable`` class automatically
  registers a lens from ``DictSample``, making this conversion seamless.

- Example:
- ::
-
- >>> ds = load_dataset("path/to/data.tar") # Returns Dataset[DictSample]
- >>> for sample in ds.ordered():
- ... print(sample.some_field) # Attribute access
- ... print(sample["other_field"]) # Dict access
- ... print(sample.keys()) # Inspect available fields
- ...
- >>> # Convert to typed schema
- >>> typed_ds = ds.as_type(MyTypedSample)
+ Examples:
+ >>> ds = load_dataset("path/to/data.tar") # Returns Dataset[DictSample]
+ >>> for sample in ds.ordered():
+ ... print(sample.some_field) # Attribute access
+ ... print(sample["other_field"]) # Dict access
+ ... print(sample.keys()) # Inspect available fields
+ ...
+ >>> # Convert to typed schema
+ >>> typed_ds = ds.as_type(MyTypedSample)

  Note:
  NDArray fields are stored as raw bytes in DictSample. They are only
  converted to numpy arrays when accessed through a typed sample class.
  """

- __slots__ = ('_data',)
+ __slots__ = ("_data",)

  def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None:
  """Create a DictSample from a dictionary or keyword arguments.
@@ -153,12 +151,12 @@ class DictSample:
  **kwargs: Field values if _data is not provided.
  """
  if _data is not None:
- object.__setattr__(self, '_data', _data)
+ object.__setattr__(self, "_data", _data)
  else:
- object.__setattr__(self, '_data', kwargs)
+ object.__setattr__(self, "_data", kwargs)

  @classmethod
- def from_data(cls, data: dict[str, Any]) -> 'DictSample':
+ def from_data(cls, data: dict[str, Any]) -> "DictSample":
  """Create a DictSample from unpacked msgpack data.

  Args:
@@ -170,7 +168,7 @@ class DictSample:
  return cls(_data=data)

  @classmethod
- def from_bytes(cls, bs: bytes) -> 'DictSample':
+ def from_bytes(cls, bs: bytes) -> "DictSample":
  """Create a DictSample from raw msgpack bytes.

  Args:
@@ -194,7 +192,7 @@ class DictSample:
  AttributeError: If the field doesn't exist.
  """
  # Avoid infinite recursion for _data lookup
- if name == '_data':
+ if name == "_data":
  raise AttributeError(name)
  try:
  return self._data[name]
@@ -260,24 +258,24 @@ class DictSample:
  return msgpack.packb(self._data)

  @property
- def as_wds(self) -> 'WDSRawSample':
+ def as_wds(self) -> "WDSRawSample":
  """Pack this sample's data for writing to WebDataset.

  Returns:
  A dictionary with ``__key__`` and ``msgpack`` fields.
  """
  return {
- '__key__': str(uuid.uuid1(0, 0)),
- 'msgpack': self.packed,
+ "__key__": str(uuid.uuid1(0, 0)),
+ "msgpack": self.packed,
  }

  def __repr__(self) -> str:
- fields = ', '.join(f'{k}=...' for k in self._data.keys())
- return f'DictSample({fields})'
+ fields = ", ".join(f"{k}=..." for k in self._data.keys())
+ return f"DictSample({fields})"


  @dataclass
- class PackableSample( ABC ):
+ class PackableSample(ABC):
  """Base class for samples that can be serialized with msgpack.

  This abstract base class provides automatic serialization/deserialization
@@ -289,54 +287,52 @@ class PackableSample( ABC ):
  1. Direct inheritance with the ``@dataclass`` decorator
  2. Using the ``@packable`` decorator (recommended)

- Example:
- ::
-
- >>> @packable
- ... class MyData:
- ... name: str
- ... embeddings: NDArray
- ...
- >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
- >>> packed = sample.packed # Serialize to bytes
- >>> restored = MyData.from_bytes(packed) # Deserialize
+ Examples:
+ >>> @packable
+ ... class MyData:
+ ... name: str
+ ... embeddings: NDArray
+ ...
+ >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
+ >>> packed = sample.packed # Serialize to bytes
+ >>> restored = MyData.from_bytes(packed) # Deserialize
  """

- def _ensure_good( self ):
+ def _ensure_good(self):
  """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""

  # Auto-convert known types when annotated
  # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
- for field in dataclasses.fields( self ):
+ for field in dataclasses.fields(self):
  var_name = field.name
  var_type = field.type

  # Annotation for this variable is to be an NDArray
- if _is_possibly_ndarray_type( var_type ):
+ if _is_possibly_ndarray_type(var_type):
  # ... so, we'll always auto-convert to numpy

- var_cur_value = getattr( self, var_name )
+ var_cur_value = getattr(self, var_name)

  # Execute the appropriate conversion for intermediate data
  # based on what is provided

- if isinstance( var_cur_value, np.ndarray ):
+ if isinstance(var_cur_value, np.ndarray):
  # Already the correct type, no conversion needed
  continue

- elif isinstance( var_cur_value, bytes ):
+ elif isinstance(var_cur_value, bytes):
  # Design note: bytes in NDArray-typed fields are always interpreted
  # as serialized arrays. This means raw bytes fields must not be
  # annotated as NDArray.
- setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
+ setattr(self, var_name, eh.bytes_to_array(var_cur_value))

- def __post_init__( self ):
+ def __post_init__(self):
  self._ensure_good()

  ##

  @classmethod
- def from_data( cls, data: WDSRawSample ) -> Self:
+ def from_data(cls, data: WDSRawSample) -> Self:
  """Create a sample instance from unpacked msgpack data.

  Args:
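``_ensure_good`` is the heart of the ``PackableSample`` contract: any field annotated ``NDArray`` (or ``NDArray | None``) may arrive as serialized bytes and is decoded back to an array in ``__post_init__``. A round-trip sketch of that contract, assuming the package is installed; the import path is the module shown in this diff:

```python
import numpy as np
from numpy.typing import NDArray

from atdata.dataset import packable  # module path as in this diff

@packable
class Reading:
    tag: str
    values: NDArray

r = Reading(tag="t0", values=np.arange(4, dtype=np.float32))
packed = r.packed                      # msgpack bytes; array serialized inline
restored = Reading.from_bytes(packed)  # bytes auto-decoded back to an ndarray
assert np.array_equal(r.values, restored.values)
```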
@@ -345,10 +341,10 @@ class PackableSample( ABC ):
  Returns:
  New instance with NDArray fields auto-converted from bytes.
  """
- return cls( **data )
-
+ return cls(**data)
+
  @classmethod
- def from_bytes( cls, bs: bytes ) -> Self:
+ def from_bytes(cls, bs: bytes) -> Self:
  """Create a sample instance from raw msgpack bytes.

  Args:
@@ -357,10 +353,10 @@ class PackableSample( ABC ):
  Returns:
  A new instance of this sample class deserialized from the bytes.
  """
- return cls.from_data( ormsgpack.unpackb( bs ) )
+ return cls.from_data(ormsgpack.unpackb(bs))

  @property
- def packed( self ) -> bytes:
+ def packed(self) -> bytes:
  """Pack this sample's data into msgpack bytes.

  NDArray fields are automatically converted to bytes before packing.
@@ -375,20 +371,17 @@ class PackableSample( ABC ):

  # Make sure that all of our (possibly unpackable) data is in a packable
  # format
- o = {
- k: _make_packable( v )
- for k, v in vars( self ).items()
- }
+ o = {k: _make_packable(v) for k, v in vars(self).items()}

- ret = msgpack.packb( o )
+ ret = msgpack.packb(o)

  if ret is None:
- raise RuntimeError( f'Failed to pack sample to bytes: {o}' )
+ raise RuntimeError(f"Failed to pack sample to bytes: {o}")

  return ret
-
+
  @property
- def as_wds( self ) -> WDSRawSample:
+ def as_wds(self) -> WDSRawSample:
  """Pack this sample's data for writing to WebDataset.

  Returns:
@@ -401,19 +394,21 @@ class PackableSample( ABC ):
  """
  return {
  # Generates a UUID that is timelike-sortable
- '__key__': str( uuid.uuid1( 0, 0 ) ),
- 'msgpack': self.packed,
+ "__key__": str(uuid.uuid1(0, 0)),
+ "msgpack": self.packed,
  }

- def _batch_aggregate( xs: Sequence ):
+
+ def _batch_aggregate(xs: Sequence):
  """Stack arrays into numpy array with batch dim; otherwise return list."""
  if not xs:
  return []
- if isinstance( xs[0], np.ndarray ):
- return np.array( list( xs ) )
- return list( xs )
+ if isinstance(xs[0], np.ndarray):
+ return np.array(list(xs))
+ return list(xs)
+

- class SampleBatch( Generic[DT] ):
+ class SampleBatch(Generic[DT]):
  """A batch of samples with automatic attribute aggregation.

  This class wraps a sequence of samples and provides magic ``__getattr__``
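``_batch_aggregate`` is small enough to exercise in isolation; its rule is the one ``SampleBatch`` attribute access relies on: arrays gain a leading batch dimension, everything else passes through as a list. The function body here is copied verbatim from the hunk above:

```python
import numpy as np

def _batch_aggregate(xs):
    # Verbatim logic from the diff: stack arrays, pass other values through.
    if not xs:
        return []
    if isinstance(xs[0], np.ndarray):
        return np.array(list(xs))
    return list(xs)

print(_batch_aggregate([np.zeros((2, 3)), np.ones((2, 3))]).shape)  # (2, 2, 3)
print(_batch_aggregate(["a", "b"]))                                 # ['a', 'b']
```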
@@ -430,12 +425,10 @@ class SampleBatch( Generic[DT] ):
  Attributes:
  samples: The list of sample instances in this batch.

- Example:
- ::
-
- >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
- >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
- >>> batch.names # Returns list of names
+ Examples:
+ >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
+ >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
+ >>> batch.names # Returns list of names

  Note:
  This class uses Python's ``__orig_class__`` mechanism to extract the
@@ -443,10 +436,11 @@ class SampleBatch( Generic[DT] ):
  subscripted syntax ``SampleBatch[MyType](samples)`` rather than
  calling the constructor directly with an unsubscripted class.
  """
+
  # Design note: The docstring uses "Parameters:" for type parameters because
  # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.

- def __init__( self, samples: Sequence[DT] ):
+ def __init__(self, samples: Sequence[DT]):
  """Create a batch from a sequence of samples.

  Args:
@@ -454,23 +448,23 @@ class SampleBatch( Generic[DT] ):
  Each sample must be an instance of a type derived from
  ``PackableSample``.
  """
- self.samples = list( samples )
+ self.samples = list(samples)
  self._aggregate_cache = dict()
  self._sample_type_cache: Type | None = None

  @property
- def sample_type( self ) -> Type:
+ def sample_type(self) -> Type:
  """The type of each sample in this batch.

  Returns:
  The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
  """
  if self._sample_type_cache is None:
- self._sample_type_cache = typing.get_args( self.__orig_class__)[0]
+ self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
  assert self._sample_type_cache is not None
  return self._sample_type_cache

- def __getattr__( self, name ):
+ def __getattr__(self, name):
  """Aggregate an attribute across all samples in the batch.

  This magic method enables attribute-style access to aggregated sample
@@ -487,20 +481,19 @@ class SampleBatch( Generic[DT] ):
  AttributeError: If the attribute doesn't exist on the sample type.
  """
  # Aggregate named params of sample type
- if name in vars( self.sample_type )['__annotations__']:
+ if name in vars(self.sample_type)["__annotations__"]:
  if name not in self._aggregate_cache:
  self._aggregate_cache[name] = _batch_aggregate(
- [ getattr( x, name )
- for x in self.samples ]
+ [getattr(x, name) for x in self.samples]
  )

  return self._aggregate_cache[name]

- raise AttributeError( f'No sample attribute named {name}' )
+ raise AttributeError(f"No sample attribute named {name}")


- ST = TypeVar( 'ST', bound = PackableSample )
- RT = TypeVar( 'RT', bound = PackableSample )
+ ST = TypeVar("ST", bound=PackableSample)
+ RT = TypeVar("RT", bound=PackableSample)


  class _ShardListStage(wds.utils.PipelineStage):
@@ -538,7 +531,7 @@ class _StreamOpenerStage(wds.utils.PipelineStage):
  yield sample


- class Dataset( Generic[ST] ):
+ class Dataset(Generic[ST]):
  """A typed dataset built on WebDataset with lens transformations.

  This class wraps WebDataset tar archives and provides type-safe iteration
@@ -557,16 +550,14 @@ class Dataset( Generic[ST] ):
  Attributes:
  url: WebDataset brace-notation URL for the tar file(s).

- Example:
- ::
-
- >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
- >>> for sample in ds.ordered(batch_size=32):
- ... # sample is SampleBatch[MyData] with batch_size samples
- ... embeddings = sample.embeddings # shape: (32, ...)
- ...
- >>> # Transform to a different view
- >>> ds_view = ds.as_type(MyDataView)
+ Examples:
+ >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
+ >>> for sample in ds.ordered(batch_size=32):
+ ... # sample is SampleBatch[MyData] with batch_size samples
+ ... embeddings = sample.embeddings # shape: (32, ...)
+ ...
+ >>> # Transform to a different view
+ >>> ds_view = ds.as_type(MyDataView)

  Note:
  This class uses Python's ``__orig_class__`` mechanism to extract the
@@ -574,22 +565,24 @@ class Dataset( Generic[ST] ):
  subscripted syntax ``Dataset[MyType](url)`` rather than calling the
  constructor directly with an unsubscripted class.
  """
+
  # Design note: The docstring uses "Parameters:" for type parameters because
  # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.

  @property
- def sample_type( self ) -> Type:
+ def sample_type(self) -> Type:
  """The type of each returned sample from this dataset's iterator.

  Returns:
  The type parameter ``ST`` used when creating this ``Dataset[ST]``.
  """
  if self._sample_type_cache is None:
- self._sample_type_cache = typing.get_args( self.__orig_class__ )[0]
+ self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
  assert self._sample_type_cache is not None
  return self._sample_type_cache
+
  @property
- def batch_type( self ) -> Type:
+ def batch_type(self) -> Type:
  """The type of batches produced by this dataset.

  Returns:
@@ -597,12 +590,13 @@ class Dataset( Generic[ST] ):
  """
  return SampleBatch[self.sample_type]

- def __init__( self,
- source: DataSource | str | None = None,
- metadata_url: str | None = None,
- *,
- url: str | None = None,
- ) -> None:
+ def __init__(
+ self,
+ source: DataSource | str | None = None,
+ metadata_url: str | None = None,
+ *,
+ url: str | None = None,
+ ) -> None:
  """Create a dataset from a DataSource or URL.

  Args:
@@ -650,7 +644,7 @@ class Dataset( Generic[ST] ):
  """The underlying data source for this dataset."""
  return self._source

- def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
+ def as_type(self, other: Type[RT]) -> "Dataset[RT]":
  """View this dataset through a different sample type using a registered lens.

  Args:
@@ -666,10 +660,10 @@ class Dataset( Generic[ST] ):
  ValueError: If no registered lens exists between the current
  sample type and the target type.
  """
- ret = Dataset[other]( self._source )
+ ret = Dataset[other](self._source)
  # Get the singleton lens registry
  lenses = LensNetwork()
- ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
+ ret._output_lens = lenses.transform(self.sample_type, ret.sample_type)
  return ret

  @property
@@ -679,11 +673,9 @@ class Dataset( Generic[ST] ):
  Yields:
  Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').

- Example:
- ::
-
- >>> for shard in ds.shards:
- ... print(f"Processing {shard}")
+ Examples:
+ >>> for shard in ds.shards:
+ ... print(f"Processing {shard}")
  """
  return iter(self._source.list_shards())

@@ -705,6 +697,7 @@ class Dataset( Generic[ST] ):
  Use :meth:`list_shards` instead.
  """
  import warnings
+
  warnings.warn(
  "shard_list is deprecated, use list_shards() instead",
  DeprecationWarning,
@@ -713,7 +706,7 @@ class Dataset( Generic[ST] ):
  return self.list_shards()

  @property
- def metadata( self ) -> dict[str, Any] | None:
+ def metadata(self) -> dict[str, Any] | None:
  """Fetch and cache metadata from metadata_url.

  Returns:
@@ -726,27 +719,47 @@ class Dataset( Generic[ST] ):
  return None

  if self._metadata is None:
- with requests.get( self.metadata_url, stream = True ) as response:
+ with requests.get(self.metadata_url, stream=True) as response:
  response.raise_for_status()
- self._metadata = msgpack.unpackb( response.content, raw = False )
-
+ self._metadata = msgpack.unpackb(response.content, raw=False)
+
  # Use our cached values
  return self._metadata
-
- def ordered( self,
- batch_size: int | None = None,
- ) -> Iterable[ST]:
- """Iterate over the dataset in order
+
+ @overload
+ def ordered(
+ self,
+ batch_size: None = None,
+ ) -> Iterable[ST]: ...
+
+ @overload
+ def ordered(
+ self,
+ batch_size: int,
+ ) -> Iterable[SampleBatch[ST]]: ...
+
+ def ordered(
+ self,
+ batch_size: int | None = None,
+ ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
+ """Iterate over the dataset in order.

  Args:
- batch_size (:obj:`int`, optional): The size of iterated batches.
- Default: None (unbatched). If ``None``, iterates over one
- sample at a time with no batch dimension.
+ batch_size: The size of iterated batches. Default: None (unbatched).
+ If ``None``, iterates over one sample at a time with no batch
+ dimension.

  Returns:
- :obj:`webdataset.DataPipeline` A data pipeline that iterates over
- the dataset in its original sample order
+ A data pipeline that iterates over the dataset in its original
+ sample order. When ``batch_size`` is ``None``, yields individual
+ samples of type ``ST``. When ``batch_size`` is an integer, yields
+ ``SampleBatch[ST]`` instances containing that many samples.
+
+ Examples:
+ >>> for sample in ds.ordered():
+ ... process(sample) # sample is ST
+ >>> for batch in ds.ordered(batch_size=32):
+ ... process(batch) # batch is SampleBatch[ST]
  """
  if batch_size is None:
  return wds.pipeline.DataPipeline(
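The ``@overload`` pair added in this hunk is the main typing improvement of the release: the return type of ``ordered`` now depends on whether ``batch_size`` is passed. A sketch of what a type checker infers under these overloads (``MyData`` mirrors the docstring examples and is defined here only for illustration):

```python
from typing import reveal_type  # Python 3.11+; typing_extensions before that

from numpy.typing import NDArray

from atdata.dataset import Dataset, packable

@packable
class MyData:
    name: str
    embeddings: NDArray

ds = Dataset[MyData]("data-{000000..000009}.tar")
reveal_type(ds.ordered())              # Iterable[MyData]
reveal_type(ds.ordered(batch_size=8))  # Iterable[SampleBatch[MyData]]
```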
@@ -755,7 +768,7 @@ class Dataset( Generic[ST] ):
  _StreamOpenerStage(self._source),
  wds.tariterators.tar_file_expander,
  wds.tariterators.group_by_keys,
- wds.filters.map( self.wrap ),
+ wds.filters.map(self.wrap),
  )

  return wds.pipeline.DataPipeline(
@@ -764,15 +777,33 @@ class Dataset( Generic[ST] ):
  _StreamOpenerStage(self._source),
  wds.tariterators.tar_file_expander,
  wds.tariterators.group_by_keys,
- wds.filters.batched( batch_size ),
- wds.filters.map( self.wrap_batch ),
+ wds.filters.batched(batch_size),
+ wds.filters.map(self.wrap_batch),
  )

- def shuffled( self,
- buffer_shards: int = 100,
- buffer_samples: int = 10_000,
- batch_size: int | None = None,
- ) -> Iterable[ST]:
+ @overload
+ def shuffled(
+ self,
+ buffer_shards: int = 100,
+ buffer_samples: int = 10_000,
+ batch_size: None = None,
+ ) -> Iterable[ST]: ...
+
+ @overload
+ def shuffled(
+ self,
+ buffer_shards: int = 100,
+ buffer_samples: int = 10_000,
+ *,
+ batch_size: int,
+ ) -> Iterable[SampleBatch[ST]]: ...
+
+ def shuffled(
+ self,
+ buffer_shards: int = 100,
+ buffer_samples: int = 10_000,
+ batch_size: int | None = None,
+ ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
  """Iterate over the dataset in random order.

  Args:
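Note one asymmetry with ``ordered``: the batched ``shuffled`` overload declares ``batch_size`` keyword-only (the bare ``*``), so only the keyword form selects the ``SampleBatch`` overload during type checking, even though the implementation still accepts a positional value at runtime. Continuing the sketch above:

```python
batches = ds.shuffled(batch_size=32)     # Iterable[SampleBatch[MyData]]
samples = ds.shuffled(buffer_shards=50)  # Iterable[MyData]
ds.shuffled(100, 10_000, 32)  # runs, but matches neither overload statically
```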
@@ -787,42 +818,50 @@ class Dataset( Generic[ST] ):
  dimension.

  Returns:
- A WebDataset data pipeline that iterates over the dataset in
- randomized order. If ``batch_size`` is not ``None``, yields
- ``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
- samples.
+ A data pipeline that iterates over the dataset in randomized order.
+ When ``batch_size`` is ``None``, yields individual samples of type
+ ``ST``. When ``batch_size`` is an integer, yields ``SampleBatch[ST]``
+ instances containing that many samples.
+
+ Examples:
+ >>> for sample in ds.shuffled():
+ ... process(sample) # sample is ST
+ >>> for batch in ds.shuffled(batch_size=32):
+ ... process(batch) # batch is SampleBatch[ST]
  """
  if batch_size is None:
  return wds.pipeline.DataPipeline(
  _ShardListStage(self._source),
- wds.filters.shuffle( buffer_shards ),
+ wds.filters.shuffle(buffer_shards),
  wds.shardlists.split_by_worker,
  _StreamOpenerStage(self._source),
  wds.tariterators.tar_file_expander,
  wds.tariterators.group_by_keys,
- wds.filters.shuffle( buffer_samples ),
- wds.filters.map( self.wrap ),
+ wds.filters.shuffle(buffer_samples),
+ wds.filters.map(self.wrap),
  )

  return wds.pipeline.DataPipeline(
  _ShardListStage(self._source),
- wds.filters.shuffle( buffer_shards ),
+ wds.filters.shuffle(buffer_shards),
  wds.shardlists.split_by_worker,
  _StreamOpenerStage(self._source),
  wds.tariterators.tar_file_expander,
  wds.tariterators.group_by_keys,
- wds.filters.shuffle( buffer_samples ),
- wds.filters.batched( batch_size ),
- wds.filters.map( self.wrap_batch ),
+ wds.filters.shuffle(buffer_samples),
+ wds.filters.batched(batch_size),
+ wds.filters.map(self.wrap_batch),
  )
-
+
  # Design note: Uses pandas for parquet export. Could be replaced with
  # direct fastparquet calls to reduce dependencies if needed.
- def to_parquet( self, path: Pathlike,
- sample_map: Optional[SampleExportMap] = None,
- maxcount: Optional[int] = None,
- **kwargs,
- ):
+ def to_parquet(
+ self,
+ path: Pathlike,
+ sample_map: Optional[SampleExportMap] = None,
+ maxcount: Optional[int] = None,
+ **kwargs,
+ ):
  """Export dataset contents to parquet format.

  Converts all samples to a pandas DataFrame and saves to parquet file(s).
@@ -851,63 +890,62 @@ class Dataset( Generic[ST] ):
  This creates multiple parquet files: ``output-000000.parquet``,
  ``output-000001.parquet``, etc.

- Example:
- ::
-
- >>> ds = Dataset[MySample]("data.tar")
- >>> # Small dataset - load all at once
- >>> ds.to_parquet("output.parquet")
- >>>
- >>> # Large dataset - process in chunks
- >>> ds.to_parquet("output.parquet", maxcount=50000)
+ Examples:
+ >>> ds = Dataset[MySample]("data.tar")
+ >>> # Small dataset - load all at once
+ >>> ds.to_parquet("output.parquet")
+ >>>
+ >>> # Large dataset - process in chunks
+ >>> ds.to_parquet("output.parquet", maxcount=50000)
  """
  ##

  # Normalize args
- path = Path( path )
+ path = Path(path)
  if sample_map is None:
  sample_map = asdict
-
- verbose = kwargs.get( 'verbose', False )

- it = self.ordered( batch_size = None )
+ verbose = kwargs.get("verbose", False)
+
+ it = self.ordered(batch_size=None)
  if verbose:
- it = tqdm( it )
+ it = tqdm(it)

  #

  if maxcount is None:
  # Load and save full dataset
- df = pd.DataFrame( [ sample_map( x )
- for x in self.ordered( batch_size = None ) ] )
- df.to_parquet( path, **kwargs )
-
+ df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)])
+ df.to_parquet(path, **kwargs)
+
  else:
  # Load and save dataset in segments of size `maxcount`

  cur_segment = 0
  cur_buffer = []
- path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix()
+ path_template = (
+ path.parent / f"{path.stem}-{{:06d}}{path.suffix}"
+ ).as_posix()

- for x in self.ordered( batch_size = None ):
- cur_buffer.append( sample_map( x ) )
+ for x in self.ordered(batch_size=None):
+ cur_buffer.append(sample_map(x))

- if len( cur_buffer ) >= maxcount:
+ if len(cur_buffer) >= maxcount:
  # Write current segment
- cur_path = path_template.format( cur_segment )
- df = pd.DataFrame( cur_buffer )
- df.to_parquet( cur_path, **kwargs )
+ cur_path = path_template.format(cur_segment)
+ df = pd.DataFrame(cur_buffer)
+ df.to_parquet(cur_path, **kwargs)

  cur_segment += 1
  cur_buffer = []
-
- if len( cur_buffer ) > 0:
+
+ if len(cur_buffer) > 0:
  # Write one last segment with remainder
- cur_path = path_template.format( cur_segment )
- df = pd.DataFrame( cur_buffer )
- df.to_parquet( cur_path, **kwargs )
+ cur_path = path_template.format(cur_segment)
+ df = pd.DataFrame(cur_buffer)
+ df.to_parquet(cur_path, **kwargs)

- def wrap( self, sample: WDSRawSample ) -> ST:
+ def wrap(self, sample: WDSRawSample) -> ST:
  """Wrap a raw msgpack sample into the appropriate dataset-specific type.

  Args:
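The segmented branch of ``to_parquet`` derives its shard names from a format template. Isolated, the naming logic introduced above behaves like this:

```python
from pathlib import Path

# The template logic from to_parquet, extracted for illustration:
path = Path("exports/output.parquet")
path_template = (path.parent / f"{path.stem}-{{:06d}}{path.suffix}").as_posix()
print(path_template.format(0))  # exports/output-000000.parquet
print(path_template.format(1))  # exports/output-000001.parquet
```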
@@ -918,18 +956,22 @@ class Dataset( Generic[ST] ):
  A deserialized sample of type ``ST``, optionally transformed through
  a lens if ``as_type()`` was called.
  """
- if 'msgpack' not in sample:
- raise ValueError(f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}")
- if not isinstance(sample['msgpack'], bytes):
- raise ValueError(f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}")
+ if "msgpack" not in sample:
+ raise ValueError(
+ f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}"
+ )
+ if not isinstance(sample["msgpack"], bytes):
+ raise ValueError(
+ f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}"
+ )

  if self._output_lens is None:
- return self.sample_type.from_bytes( sample['msgpack'] )
+ return self.sample_type.from_bytes(sample["msgpack"])

- source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
- return self._output_lens( source_sample )
+ source_sample = self._output_lens.source_type.from_bytes(sample["msgpack"])
+ return self._output_lens(source_sample)

- def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
+ def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]:
  """Wrap a batch of raw msgpack samples into a typed SampleBatch.

  Args:
@@ -945,26 +987,29 @@ class Dataset( Generic[ST] ):
  aggregates them into a batch.
  """

- if 'msgpack' not in batch:
- raise ValueError(f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}")
+ if "msgpack" not in batch:
+ raise ValueError(
+ f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}"
+ )

  if self._output_lens is None:
- batch_unpacked = [ self.sample_type.from_bytes( bs )
- for bs in batch['msgpack'] ]
- return SampleBatch[self.sample_type]( batch_unpacked )
+ batch_unpacked = [
+ self.sample_type.from_bytes(bs) for bs in batch["msgpack"]
+ ]
+ return SampleBatch[self.sample_type](batch_unpacked)

- batch_source = [ self._output_lens.source_type.from_bytes( bs )
- for bs in batch['msgpack'] ]
- batch_view = [ self._output_lens( s )
- for s in batch_source ]
- return SampleBatch[self.sample_type]( batch_view )
+ batch_source = [
+ self._output_lens.source_type.from_bytes(bs) for bs in batch["msgpack"]
+ ]
+ batch_view = [self._output_lens(s) for s in batch_source]
+ return SampleBatch[self.sample_type](batch_view)


- _T = TypeVar('_T')
+ _T = TypeVar("_T")


  @dataclass_transform()
- def packable( cls: type[_T] ) -> type[_T]:
+ def packable(cls: type[_T]) -> type[_T]:
  """Decorator to convert a regular class into a ``PackableSample``.

  This decorator transforms a class into a dataclass that inherits from
@@ -984,19 +1029,17 @@ def packable( cls: type[_T] ) -> type[_T]:
  ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.

  Examples:
- This is a test of the functionality::
-
- @packable
- class MyData:
- name: str
- values: NDArray
-
- sample = MyData(name="test", values=np.array([1, 2, 3]))
- bytes_data = sample.packed
- restored = MyData.from_bytes(bytes_data)
-
- # Works with Packable-typed APIs
- index.publish_schema(MyData, version="1.0.0") # Type-safe
+ >>> @packable
+ ... class MyData:
+ ... name: str
+ ... values: NDArray
+ ...
+ >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
+ >>> bytes_data = sample.packed
+ >>> restored = MyData.from_bytes(bytes_data)
+ >>>
+ >>> # Works with Packable-typed APIs
+ >>> index.publish_schema(MyData, version="1.0.0") # Type-safe
  """

  ##
@@ -1005,14 +1048,14 @@ def packable( cls: type[_T] ) -> type[_T]:
  class_annotations = cls.__annotations__

  # Add in dataclass niceness to original class
- as_dataclass = dataclass( cls )
+ as_dataclass = dataclass(cls)

  # This triggers a bunch of behind-the-scenes stuff for the newly annotated class
  @dataclass
- class as_packable( as_dataclass, PackableSample ):
- def __post_init__( self ):
- return PackableSample.__post_init__( self )
-
+ class as_packable(as_dataclass, PackableSample):
+ def __post_init__(self):
+ return PackableSample.__post_init__(self)
+
  # Restore original class identity for better repr/debugging
  as_packable.__name__ = class_name
  as_packable.__qualname__ = class_name
@@ -1023,10 +1066,10 @@ def packable( cls: type[_T] ) -> type[_T]:

  # Fix qualnames of dataclass-generated methods so they don't show
  # 'packable.<locals>.as_packable' in help() and IDE hints
- old_qualname_prefix = 'packable.<locals>.as_packable'
- for attr_name in ('__init__', '__repr__', '__eq__', '__post_init__'):
+ old_qualname_prefix = "packable.<locals>.as_packable"
+ for attr_name in ("__init__", "__repr__", "__eq__", "__post_init__"):
  attr = getattr(as_packable, attr_name, None)
- if attr is not None and hasattr(attr, '__qualname__'):
+ if attr is not None and hasattr(attr, "__qualname__"):
  if attr.__qualname__.startswith(old_qualname_prefix):
  attr.__qualname__ = attr.__qualname__.replace(
  old_qualname_prefix, class_name, 1
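The qualname repair above addresses a general Python wrinkle rather than anything atdata-specific: methods generated on a class defined inside a function carry the factory's ``<locals>`` prefix. A standalone illustration of the same technique:

```python
def make():
    class Inner:
        def ping(self):
            return "pong"
    return Inner

C = make()
print(C.ping.__qualname__)  # make.<locals>.Inner.ping

# The same string surgery packable() performs on dataclass-generated methods:
C.ping.__qualname__ = C.ping.__qualname__.replace("make.<locals>.Inner", "C", 1)
print(C.ping.__qualname__)  # C.ping
```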
@@ -1042,4 +1085,4 @@ def packable( cls: type[_T] ) -> type[_T]:

  ##

- return as_packable
+ return as_packable