atdata 0.2.0a1__py3-none-any.whl → 0.2.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atdata/dataset.py CHANGED
@@ -14,15 +14,17 @@ during serialization, enabling efficient storage of numerical data in WebDataset
  archives.

  Example:
- >>> @packable
- ... class ImageSample:
- ... image: NDArray
- ... label: str
- ...
- >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
- >>> for batch in ds.shuffled(batch_size=32):
- ... images = batch.image # Stacked numpy array (32, H, W, C)
- ... labels = batch.label # List of 32 strings
+ ::
+
+ >>> @packable
+ ... class ImageSample:
+ ... image: NDArray
+ ... label: str
+ ...
+ >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
+ >>> for batch in ds.shuffled(batch_size=32):
+ ... images = batch.image # Stacked numpy array (32, H, W, C)
+ ... labels = batch.label # List of 32 strings
  """

  ##
@@ -41,6 +43,9 @@ from dataclasses import (
  )
  from abc import ABC

+ from ._sources import URLSource, S3Source
+ from ._protocols import DataSource
+
  from tqdm import tqdm
  import numpy as np
  import pandas as pd
@@ -51,16 +56,16 @@ from typing import (
  Any,
  Optional,
  Dict,
+ Iterator,
  Sequence,
  Iterable,
  Callable,
- Union,
- #
  Self,
  Generic,
  Type,
  TypeVar,
  TypeAlias,
+ dataclass_transform,
  )
  from numpy.typing import NDArray

@@ -75,6 +80,7 @@ from .lens import Lens, LensNetwork

  Pathlike = str | Path

+ # WebDataset sample/batch dictionaries (contain __key__, msgpack, etc.)
  WDSRawSample: TypeAlias = Dict[str, Any]
  WDSRawBatch: TypeAlias = Dict[str, Any]

@@ -87,49 +93,189 @@ SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]

  DT = TypeVar( 'DT' )

- MsgpackRawSample: TypeAlias = Dict[str, Any]
-

  def _make_packable( x ):
- """Convert a value to a msgpack-compatible format.
-
- Args:
- x: A value to convert. If it's a numpy array, converts to bytes.
- Otherwise returns the value unchanged.
-
- Returns:
- The value in a format suitable for msgpack serialization.
- """
+ """Convert numpy arrays to bytes; pass through other values unchanged."""
  if isinstance( x, np.ndarray ):
  return eh.array_to_bytes( x )
  return x

- def _is_possibly_ndarray_type( t ):
- """Check if a type annotation is or contains NDArray.

- Args:
- t: A type annotation to check.
-
- Returns:
- ``True`` if the type is ``NDArray`` or a union containing ``NDArray``
- (e.g., ``NDArray | None``), ``False`` otherwise.
- """
-
- # Directly an NDArray
+ def _is_possibly_ndarray_type( t ):
+ """Return True if type annotation is NDArray or Optional[NDArray]."""
  if t == NDArray:
- # print( 'is an NDArray' )
  return True
-
- # Check for Optionals (i.e., NDArray | None)
  if isinstance( t, types.UnionType ):
- t_parts = t.__args__
- if any( x == NDArray
- for x in t_parts ):
- return True
-
- # Not an NDArray
+ return any( x == NDArray for x in t.__args__ )
  return False

+ class DictSample:
+ """Dynamic sample type providing dict-like access to raw msgpack data.
+
+ This class is the default sample type for datasets when no explicit type is
+ specified. It stores the raw unpacked msgpack data and provides both
+ attribute-style (``sample.field``) and dict-style (``sample["field"]``)
+ access to fields.
+
+ ``DictSample`` is useful for:
+ - Exploring datasets without defining a schema first
+ - Working with datasets that have variable schemas
+ - Prototyping before committing to a typed schema
+
+ To convert to a typed schema, use ``Dataset.as_type()`` with a
+ ``@packable``-decorated class. Every ``@packable`` class automatically
+ registers a lens from ``DictSample``, making this conversion seamless.
+
+ Example:
+ ::
+
+ >>> ds = load_dataset("path/to/data.tar") # Returns Dataset[DictSample]
+ >>> for sample in ds.ordered():
+ ... print(sample.some_field) # Attribute access
+ ... print(sample["other_field"]) # Dict access
+ ... print(sample.keys()) # Inspect available fields
+ ...
+ >>> # Convert to typed schema
+ >>> typed_ds = ds.as_type(MyTypedSample)
+
+ Note:
+ NDArray fields are stored as raw bytes in DictSample. They are only
+ converted to numpy arrays when accessed through a typed sample class.
+ """
+
+ __slots__ = ('_data',)
+
+ def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None:
+ """Create a DictSample from a dictionary or keyword arguments.
+
+ Args:
+ _data: Raw data dictionary. If provided, kwargs are ignored.
+ **kwargs: Field values if _data is not provided.
+ """
+ if _data is not None:
+ object.__setattr__(self, '_data', _data)
+ else:
+ object.__setattr__(self, '_data', kwargs)
+
+ @classmethod
+ def from_data(cls, data: dict[str, Any]) -> 'DictSample':
+ """Create a DictSample from unpacked msgpack data.
+
+ Args:
+ data: Dictionary with field names as keys.
+
+ Returns:
+ New DictSample instance wrapping the data.
+ """
+ return cls(_data=data)
+
+ @classmethod
+ def from_bytes(cls, bs: bytes) -> 'DictSample':
+ """Create a DictSample from raw msgpack bytes.
+
+ Args:
+ bs: Raw bytes from a msgpack-serialized sample.
+
+ Returns:
+ New DictSample instance with the unpacked data.
+ """
+ return cls.from_data(ormsgpack.unpackb(bs))
+
+ def __getattr__(self, name: str) -> Any:
+ """Access a field by attribute name.
+
+ Args:
+ name: Field name to access.
+
+ Returns:
+ The field value.
+
+ Raises:
+ AttributeError: If the field doesn't exist.
+ """
+ # Avoid infinite recursion for _data lookup
+ if name == '_data':
+ raise AttributeError(name)
+ try:
+ return self._data[name]
+ except KeyError:
+ raise AttributeError(
+ f"'{type(self).__name__}' has no field '{name}'. "
+ f"Available fields: {list(self._data.keys())}"
+ ) from None
+
+ def __getitem__(self, key: str) -> Any:
+ """Access a field by dict key.
+
+ Args:
+ key: Field name to access.
+
+ Returns:
+ The field value.
+
+ Raises:
+ KeyError: If the field doesn't exist.
+ """
+ return self._data[key]
+
+ def __contains__(self, key: str) -> bool:
+ """Check if a field exists."""
+ return key in self._data
+
+ def keys(self) -> list[str]:
+ """Return list of field names."""
+ return list(self._data.keys())
+
+ def values(self) -> list[Any]:
+ """Return list of field values."""
+ return list(self._data.values())
+
+ def items(self) -> list[tuple[str, Any]]:
+ """Return list of (field_name, value) tuples."""
+ return list(self._data.items())
+
+ def get(self, key: str, default: Any = None) -> Any:
+ """Get a field value with optional default.
+
+ Args:
+ key: Field name to access.
+ default: Value to return if field doesn't exist.
+
+ Returns:
+ The field value or default.
+ """
+ return self._data.get(key, default)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Return a copy of the underlying data dictionary."""
+ return dict(self._data)
+
+ @property
+ def packed(self) -> bytes:
+ """Pack this sample's data into msgpack bytes.
+
+ Returns:
+ Raw msgpack bytes representing this sample's data.
+ """
+ return msgpack.packb(self._data)
+
+ @property
+ def as_wds(self) -> 'WDSRawSample':
+ """Pack this sample's data for writing to WebDataset.
+
+ Returns:
+ A dictionary with ``__key__`` and ``msgpack`` fields.
+ """
+ return {
+ '__key__': str(uuid.uuid1(0, 0)),
+ 'msgpack': self.packed,
+ }
+
+ def __repr__(self) -> str:
+ fields = ', '.join(f'{k}=...' for k in self._data.keys())
+ return f'DictSample({fields})'
+
+
  @dataclass
  class PackableSample( ABC ):
  """Base class for samples that can be serialized with msgpack.
@@ -144,28 +290,20 @@ class PackableSample( ABC ):
  2. Using the ``@packable`` decorator (recommended)

  Example:
- >>> @packable
- ... class MyData:
- ... name: str
- ... embeddings: NDArray
- ...
- >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
- >>> packed = sample.packed # Serialize to bytes
- >>> restored = MyData.from_bytes(packed) # Deserialize
+ ::
+
+ >>> @packable
+ ... class MyData:
+ ... name: str
+ ... embeddings: NDArray
+ ...
+ >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
+ >>> packed = sample.packed # Serialize to bytes
+ >>> restored = MyData.from_bytes(packed) # Deserialize
  """

  def _ensure_good( self ):
- """Auto-convert annotated NDArray fields from bytes to numpy arrays.
-
- This method scans all dataclass fields and for any field annotated as
- ``NDArray`` or ``NDArray | None``, automatically converts bytes values
- to numpy arrays using the helper deserialization function. This enables
- transparent handling of array serialization in msgpack data.
-
- Note:
- This is called during ``__post_init__`` to ensure proper type
- conversion after deserialization.
- """
+ """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""

  # Auto-convert known types when annotated
  # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
@@ -187,9 +325,9 @@ class PackableSample( ABC ):
  continue

  elif isinstance( var_cur_value, bytes ):
- # TODO This does create a constraint that serialized bytes
- # in a field that might be an NDArray are always interpreted
- # as being the NDArray interpretation
+ # Design note: bytes in NDArray-typed fields are always interpreted
+ # as serialized arrays. This means raw bytes fields must not be
+ # annotated as NDArray.
  setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )

  def __post_init__( self ):
@@ -198,20 +336,16 @@ class PackableSample( ABC ):
  ##

  @classmethod
- def from_data( cls, data: MsgpackRawSample ) -> Self:
+ def from_data( cls, data: WDSRawSample ) -> Self:
  """Create a sample instance from unpacked msgpack data.

  Args:
- data: A dictionary of unpacked msgpack data with keys matching
- the sample's field names.
+ data: Dictionary with keys matching the sample's field names.

  Returns:
- A new instance of this sample class with fields populated from
- the data dictionary and NDArray fields auto-converted from bytes.
+ New instance with NDArray fields auto-converted from bytes.
  """
- ret = cls( **data )
- ret._ensure_good()
- return ret
+ return cls( **data )

  @classmethod
  def from_bytes( cls, bs: bytes ) -> Self:
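
The simplified from_data() above can drop the explicit _ensure_good() call because the removed docstring notes that _ensure_good runs during __post_init__. A sketch of the resulting round trip, assuming packable is importable from atdata.dataset and that packed bytes decode to a field dict (as DictSample.from_bytes does elsewhere in this diff):

    # Sketch only: the 'trace' field arrives as bytes in the unpacked dict and
    # is converted back to an ndarray by __post_init__ when from_data() builds
    # the instance.
    import numpy as np
    import ormsgpack
    from numpy.typing import NDArray
    from atdata.dataset import packable  # import path assumed

    @packable
    class Reading:
        sensor: str
        trace: NDArray

    original = Reading(sensor="imu-0", trace=np.arange(4, dtype=np.float32))

    raw = ormsgpack.unpackb(original.packed)   # dict of field values
    restored = Reading.from_data(raw)

    assert isinstance(restored.trace, np.ndarray)
    assert restored.sensor == "imu-0"
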
@@ -253,7 +387,6 @@ class PackableSample( ABC ):

  return ret

- # TODO Expand to allow for specifying explicit __key__
  @property
  def as_wds( self ) -> WDSRawSample:
  """Pack this sample's data for writing to WebDataset.
@@ -263,7 +396,8 @@ class PackableSample( ABC ):
  ``msgpack`` (packed sample data) fields suitable for WebDataset.

  Note:
- TODO: Expand to allow specifying explicit ``__key__`` values.
+ Keys are auto-generated as UUID v1 for time-sortable ordering.
+ Custom key specification is not currently supported.
  """
  return {
  # Generates a UUID that is timelike-sortable
@@ -272,25 +406,11 @@ class PackableSample( ABC ):
  }

  def _batch_aggregate( xs: Sequence ):
- """Aggregate a sequence of values into a batch-appropriate format.
-
- Args:
- xs: A sequence of values to aggregate. If the first element is a numpy
- array, all elements are stacked into a single array. Otherwise,
- returns a list.
-
- Returns:
- A numpy array (if elements are arrays) or a list (otherwise).
- """
-
+ """Stack arrays into numpy array with batch dim; otherwise return list."""
  if not xs:
- # Empty sequence
  return []
-
- # Aggregate
  if isinstance( xs[0], np.ndarray ):
  return np.array( list( xs ) )
-
  return list( xs )

  class SampleBatch( Generic[DT] ):
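
The aggregation rule _batch_aggregate applies per field is small enough to show directly; a sketch of the same rule using plain numpy (not calling the private helper):

    # Sketch: numpy elements are stacked with a new leading batch dimension,
    # everything else is collected into a plain list.
    import numpy as np

    arrays = [np.zeros((2, 2)), np.ones((2, 2)), np.full((2, 2), 2.0)]
    labels = ["a", "b", "c"]

    stacked = np.array(list(arrays))   # shape (3, 2, 2), as the helper produces
    collected = list(labels)           # ['a', 'b', 'c']

    assert stacked.shape == (3, 2, 2)
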
@@ -304,17 +424,27 @@ class SampleBatch( Generic[DT] ):
  NDArray fields are stacked into a numpy array with a batch dimension.
  Other fields are aggregated into a list.

- Type Parameters:
+ Parameters:
  DT: The sample type, must derive from ``PackableSample``.

  Attributes:
  samples: The list of sample instances in this batch.

  Example:
- >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
- >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
- >>> batch.names # Returns list of names
+ ::
+
+ >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
+ >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
+ >>> batch.names # Returns list of names
+
+ Note:
+ This class uses Python's ``__orig_class__`` mechanism to extract the
+ type parameter at runtime. Instances must be created using the
+ subscripted syntax ``SampleBatch[MyType](samples)`` rather than
+ calling the constructor directly with an unsubscripted class.
  """
+ # Design note: The docstring uses "Parameters:" for type parameters because
+ # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.

  def __init__( self, samples: Sequence[DT] ):
  """Create a batch from a sequence of samples.
@@ -326,6 +456,7 @@ class SampleBatch( Generic[DT] ):
  """
  self.samples = list( samples )
  self._aggregate_cache = dict()
+ self._sample_type_cache: Type | None = None

  @property
  def sample_type( self ) -> Type:
@@ -334,7 +465,10 @@ class SampleBatch( Generic[DT] ):
  Returns:
  The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
  """
- return typing.get_args( self.__orig_class__)[0]
+ if self._sample_type_cache is None:
+ self._sample_type_cache = typing.get_args( self.__orig_class__)[0]
+ assert self._sample_type_cache is not None
+ return self._sample_type_cache

  def __getattr__( self, name ):
  """Aggregate an attribute across all samples in the batch.
@@ -368,6 +502,42 @@ class SampleBatch( Generic[DT] ):
  ST = TypeVar( 'ST', bound = PackableSample )
  RT = TypeVar( 'RT', bound = PackableSample )

+
+ class _ShardListStage(wds.utils.PipelineStage):
+ """Pipeline stage that yields {url: shard_id} dicts from a DataSource.
+
+ This is analogous to SimpleShardList but works with any DataSource.
+ Used as the first stage before split_by_worker.
+ """
+
+ def __init__(self, source: DataSource):
+ self.source = source
+
+ def run(self):
+ """Yield {url: shard_id} dicts for each shard."""
+ for shard_id in self.source.list_shards():
+ yield {"url": shard_id}
+
+
+ class _StreamOpenerStage(wds.utils.PipelineStage):
+ """Pipeline stage that opens streams from a DataSource.
+
+ Takes {url: shard_id} dicts and adds a stream using source.open_shard().
+ This replaces WebDataset's url_opener stage.
+ """
+
+ def __init__(self, source: DataSource):
+ self.source = source
+
+ def run(self, src):
+ """Open streams for each shard dict."""
+ for sample in src:
+ shard_id = sample["url"]
+ stream = self.source.open_shard(shard_id)
+ sample["stream"] = stream
+ yield sample
+
+
  class Dataset( Generic[ST] ):
  """A typed dataset built on WebDataset with lens transformations.

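
The two pipeline stages added above only rely on two methods of the source: list_shards() and open_shard(). A minimal sketch of a custom source satisfying just those calls (the actual DataSource protocol in atdata._protocols may require more than this):

    # Sketch only: a local-directory source usable with the stages above.
    from pathlib import Path
    from typing import BinaryIO

    class DirectorySource:
        """Serve every .tar shard found in a local directory."""

        def __init__(self, root: str) -> None:
            self.root = Path(root)

        def list_shards(self) -> list[str]:
            # Shard identifiers are plain file paths here.
            return sorted(str(p) for p in self.root.glob("*.tar"))

        def open_shard(self, shard_id: str) -> BinaryIO:
            # The opener stage stores this stream under sample["stream"].
            return open(shard_id, "rb")

    # ds = Dataset[MySample](DirectorySource("/data/shards"))  # hypothetical usage
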
@@ -381,22 +551,31 @@ class Dataset( Generic[ST] ):
  - Type transformations via the lens system (``as_type()``)
  - Export to parquet format

- Type Parameters:
+ Parameters:
  ST: The sample type for this dataset, must derive from ``PackableSample``.

  Attributes:
  url: WebDataset brace-notation URL for the tar file(s).

  Example:
- >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
- >>> for sample in ds.ordered(batch_size=32):
- ... # sample is SampleBatch[MyData] with batch_size samples
- ... embeddings = sample.embeddings # shape: (32, ...)
- ...
- >>> # Transform to a different view
- >>> ds_view = ds.as_type(MyDataView)
-
+ ::
+
+ >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
+ >>> for sample in ds.ordered(batch_size=32):
+ ... # sample is SampleBatch[MyData] with batch_size samples
+ ... embeddings = sample.embeddings # shape: (32, ...)
+ ...
+ >>> # Transform to a different view
+ >>> ds_view = ds.as_type(MyDataView)
+
+ Note:
+ This class uses Python's ``__orig_class__`` mechanism to extract the
+ type parameter at runtime. Instances must be created using the
+ subscripted syntax ``Dataset[MyType](url)`` rather than calling the
+ constructor directly with an unsubscripted class.
  """
+ # Design note: The docstring uses "Parameters:" for type parameters because
+ # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.

  @property
  def sample_type( self ) -> Type:
@@ -404,12 +583,11 @@ class Dataset( Generic[ST] ):

  Returns:
  The type parameter ``ST`` used when creating this ``Dataset[ST]``.
-
- Note:
- Extracts the type parameter at runtime using ``__orig_class__``.
  """
- # NOTE: Linting may fail here due to __orig_class__ being a runtime attribute
- return typing.get_args( self.__orig_class__ )[0]
+ if self._sample_type_cache is None:
+ self._sample_type_cache = typing.get_args( self.__orig_class__ )[0]
+ assert self._sample_type_cache is not None
+ return self._sample_type_cache
  @property
  def batch_type( self ) -> Type:
  """The type of batches produced by this dataset.
@@ -419,29 +597,58 @@ class Dataset( Generic[ST] ):
  """
  return SampleBatch[self.sample_type]

- def __init__( self, url: str,
+ def __init__( self,
+ source: DataSource | str | None = None,
  metadata_url: str | None = None,
+ *,
+ url: str | None = None,
  ) -> None:
- """Create a dataset from a WebDataset URL.
+ """Create a dataset from a DataSource or URL.

  Args:
- url: WebDataset brace-notation URL pointing to tar files, e.g.,
- ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
- ``"path/to/file-000000.tar"`` for a single shard.
+ source: Either a DataSource implementation or a WebDataset-compatible
+ URL string. If a string is provided, it's wrapped in URLSource
+ for backward compatibility.
+
+ Examples:
+ - String URL: ``"path/to/file-{000000..000009}.tar"``
+ - URLSource: ``URLSource("https://example.com/data.tar")``
+ - S3Source: ``S3Source(bucket="my-bucket", keys=["data.tar"])``
+
+ metadata_url: Optional URL to msgpack-encoded metadata for this dataset.
+ url: Deprecated. Use ``source`` instead. Kept for backward compatibility.
  """
  super().__init__()
- self.url = url
- """WebDataset brace-notation URL pointing to tar files, e.g.,
- ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
- ``"path/to/file-000000.tar"`` for a single shard.
- """
+
+ # Handle backward compatibility: url= keyword argument
+ if source is None and url is not None:
+ source = url
+ elif source is None:
+ raise TypeError("Dataset() missing required argument: 'source' or 'url'")
+
+ # Normalize source: strings become URLSource for backward compatibility
+ if isinstance(source, str):
+ self._source: DataSource = URLSource(source)
+ self.url = source
+ else:
+ self._source = source
+ # For compatibility, expose URL if source has list_shards
+ shards = source.list_shards()
+ # Design note: Using first shard as url for legacy compatibility.
+ # Full shard list is available via list_shards() method.
+ self.url = shards[0] if shards else ""

  self._metadata: dict[str, Any] | None = None
  self.metadata_url: str | None = metadata_url
  """Optional URL to msgpack-encoded metadata for this dataset."""

- # Allow addition of automatic transformation of raw underlying data
  self._output_lens: Lens | None = None
+ self._sample_type_cache: Type | None = None
+
+ @property
+ def source(self) -> DataSource:
+ """The underlying data source for this dataset."""
+ return self._source

  def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
  """View this dataset through a different sample type using a registered lens.
@@ -459,25 +666,51 @@ class Dataset( Generic[ST] ):
  ValueError: If no registered lens exists between the current
  sample type and the target type.
  """
- ret = Dataset[other]( self.url )
+ ret = Dataset[other]( self._source )
  # Get the singleton lens registry
  lenses = LensNetwork()
  ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
  return ret

  @property
- def shard_list( self ) -> list[str]:
- """List of individual dataset shards
-
+ def shards(self) -> Iterator[str]:
+ """Lazily iterate over shard identifiers.
+
+ Yields:
+ Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
+
+ Example:
+ ::
+
+ >>> for shard in ds.shards:
+ ... print(f"Processing {shard}")
+ """
+ return iter(self._source.list_shards())
+
+ def list_shards(self) -> list[str]:
+ """Get list of individual dataset shards.
+
  Returns:
  A full (non-lazy) list of the individual ``tar`` files within the
  source WebDataset.
  """
- pipe = wds.pipeline.DataPipeline(
- wds.shardlists.SimpleShardList( self.url ),
- wds.filters.map( lambda x: x['url'] )
+ return self._source.list_shards()
+
+ # Legacy alias for backwards compatibility
+ @property
+ def shard_list(self) -> list[str]:
+ """List of individual dataset shards (deprecated, use list_shards()).
+
+ .. deprecated::
+ Use :meth:`list_shards` instead.
+ """
+ import warnings
+ warnings.warn(
+ "shard_list is deprecated, use list_shards() instead",
+ DeprecationWarning,
+ stacklevel=2,
  )
- return list( pipe )
+ return self.list_shards()

  @property
  def metadata( self ) -> dict[str, Any] | None:
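
A short sketch of the new shard accessors next to the deprecated property, assuming ds is an existing Dataset instance:

    # Sketch only.
    import warnings

    shard_ids = ds.list_shards()      # eager list of shard identifiers
    for shard in ds.shards:           # lazy iterator over the same identifiers
        print(shard)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy = ds.shard_list        # still works, but warns
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
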
@@ -501,33 +734,36 @@ class Dataset( Generic[ST] ):
  return self._metadata

  def ordered( self,
- batch_size: int | None = 1,
+ batch_size: int | None = None,
  ) -> Iterable[ST]:
  """Iterate over the dataset in order
-
+
  Args:
  batch_size (:obj:`int`, optional): The size of iterated batches.
- Default: 1. If ``None``, iterates over one sample at a time
- with no batch dimension.
-
+ Default: None (unbatched). If ``None``, iterates over one
+ sample at a time with no batch dimension.
+
  Returns:
  :obj:`webdataset.DataPipeline` A data pipeline that iterates over
  the dataset in its original sample order
-
- """

+ """
  if batch_size is None:
  return wds.pipeline.DataPipeline(
- wds.shardlists.SimpleShardList( self.url ),
+ _ShardListStage(self._source),
  wds.shardlists.split_by_worker,
- wds.tariterators.tarfile_to_samples(),
+ _StreamOpenerStage(self._source),
+ wds.tariterators.tar_file_expander,
+ wds.tariterators.group_by_keys,
  wds.filters.map( self.wrap ),
  )

  return wds.pipeline.DataPipeline(
- wds.shardlists.SimpleShardList( self.url ),
+ _ShardListStage(self._source),
  wds.shardlists.split_by_worker,
- wds.tariterators.tarfile_to_samples(),
+ _StreamOpenerStage(self._source),
+ wds.tariterators.tar_file_expander,
+ wds.tariterators.group_by_keys,
  wds.filters.batched( batch_size ),
  wds.filters.map( self.wrap_batch ),
  )
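
Since the default batch_size changed from 1 to None, a sketch of the two iteration modes; ds is assumed to be a Dataset[MySample] built as in the earlier examples, and the field names are illustrative:

    # Sketch only.

    # Default is now unbatched: each item is a single MySample instance.
    for sample in ds.ordered():
        break  # sample.embeddings is one array, sample.name one string

    # Pass batch_size explicitly to get SampleBatch objects instead.
    for batch in ds.ordered(batch_size=32):
        embeddings = batch.embeddings   # stacked array with leading dim 32
        break
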
@@ -535,7 +771,7 @@ class Dataset( Generic[ST] ):
  def shuffled( self,
  buffer_shards: int = 100,
  buffer_samples: int = 10_000,
- batch_size: int | None = 1,
+ batch_size: int | None = None,
  ) -> Iterable[ST]:
  """Iterate over the dataset in random order.

@@ -546,8 +782,9 @@ class Dataset( Generic[ST] ):
  buffer_samples: Number of samples to buffer for shuffling within
  shards. Larger values increase randomness but use more memory.
  Default: 10,000.
- batch_size: The size of iterated batches. Default: 1. If ``None``,
- iterates over one sample at a time with no batch dimension.
+ batch_size: The size of iterated batches. Default: None (unbatched).
+ If ``None``, iterates over one sample at a time with no batch
+ dimension.

  Returns:
  A WebDataset data pipeline that iterates over the dataset in
@@ -557,34 +794,72 @@ class Dataset( Generic[ST] ):
  """
  if batch_size is None:
  return wds.pipeline.DataPipeline(
- wds.shardlists.SimpleShardList( self.url ),
+ _ShardListStage(self._source),
  wds.filters.shuffle( buffer_shards ),
  wds.shardlists.split_by_worker,
- wds.tariterators.tarfile_to_samples(),
+ _StreamOpenerStage(self._source),
+ wds.tariterators.tar_file_expander,
+ wds.tariterators.group_by_keys,
  wds.filters.shuffle( buffer_samples ),
  wds.filters.map( self.wrap ),
  )

  return wds.pipeline.DataPipeline(
- wds.shardlists.SimpleShardList( self.url ),
+ _ShardListStage(self._source),
  wds.filters.shuffle( buffer_shards ),
  wds.shardlists.split_by_worker,
- wds.tariterators.tarfile_to_samples(),
+ _StreamOpenerStage(self._source),
+ wds.tariterators.tar_file_expander,
+ wds.tariterators.group_by_keys,
  wds.filters.shuffle( buffer_samples ),
  wds.filters.batched( batch_size ),
  wds.filters.map( self.wrap_batch ),
  )

- # TODO Rewrite to eliminate `pandas` dependency directly calling
- # `fastparquet`
+ # Design note: Uses pandas for parquet export. Could be replaced with
+ # direct fastparquet calls to reduce dependencies if needed.
  def to_parquet( self, path: Pathlike,
  sample_map: Optional[SampleExportMap] = None,
  maxcount: Optional[int] = None,
  **kwargs,
  ):
- """Save dataset contents to a `parquet` file at `path`
+ """Export dataset contents to parquet format.
+
+ Converts all samples to a pandas DataFrame and saves to parquet file(s).
+ Useful for interoperability with data analysis tools.

- `kwargs` sent to `pandas.to_parquet`
+ Args:
+ path: Output path for the parquet file. If ``maxcount`` is specified,
+ files are named ``{stem}-{segment:06d}.parquet``.
+ sample_map: Optional function to convert samples to dictionaries.
+ Defaults to ``dataclasses.asdict``.
+ maxcount: If specified, split output into multiple files with at most
+ this many samples each. Recommended for large datasets.
+ **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
+ Common options include ``compression``, ``index``, ``engine``.
+
+ Warning:
+ **Memory Usage**: When ``maxcount=None`` (default), this method loads
+ the **entire dataset into memory** as a pandas DataFrame before writing.
+ For large datasets, this can cause memory exhaustion.
+
+ For datasets larger than available RAM, always specify ``maxcount``::
+
+ # Safe for large datasets - processes in chunks
+ ds.to_parquet("output.parquet", maxcount=10000)
+
+ This creates multiple parquet files: ``output-000000.parquet``,
+ ``output-000001.parquet``, etc.
+
+ Example:
+ ::
+
+ >>> ds = Dataset[MySample]("data.tar")
+ >>> # Small dataset - load all at once
+ >>> ds.to_parquet("output.parquet")
+ >>>
+ >>> # Large dataset - process in chunks
+ >>> ds.to_parquet("output.parquet", maxcount=50000)
  """
  ##

@@ -632,7 +907,7 @@ class Dataset( Generic[ST] ):
  df = pd.DataFrame( cur_buffer )
  df.to_parquet( cur_path, **kwargs )

- def wrap( self, sample: MsgpackRawSample ) -> ST:
+ def wrap( self, sample: WDSRawSample ) -> ST:
  """Wrap a raw msgpack sample into the appropriate dataset-specific type.

  Args:
@@ -643,9 +918,11 @@ class Dataset( Generic[ST] ):
  A deserialized sample of type ``ST``, optionally transformed through
  a lens if ``as_type()`` was called.
  """
- assert 'msgpack' in sample
- assert type( sample['msgpack'] ) == bytes
-
+ if 'msgpack' not in sample:
+ raise ValueError(f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}")
+ if not isinstance(sample['msgpack'], bytes):
+ raise ValueError(f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}")
+
  if self._output_lens is None:
  return self.sample_type.from_bytes( sample['msgpack'] )

@@ -668,7 +945,8 @@ class Dataset( Generic[ST] ):
  aggregates them into a batch.
  """

- assert 'msgpack' in batch
+ if 'msgpack' not in batch:
+ raise ValueError(f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}")

  if self._output_lens is None:
  batch_unpacked = [ self.sample_type.from_bytes( bs )
@@ -682,29 +960,43 @@ class Dataset( Generic[ST] ):
  return SampleBatch[self.sample_type]( batch_view )


- def packable( cls ):
+ _T = TypeVar('_T')
+
+
+ @dataclass_transform()
+ def packable( cls: type[_T] ) -> type[_T]:
  """Decorator to convert a regular class into a ``PackableSample``.

  This decorator transforms a class into a dataclass that inherits from
  ``PackableSample``, enabling automatic msgpack serialization/deserialization
  with special handling for NDArray fields.

+ The resulting class satisfies the ``Packable`` protocol, making it compatible
+ with all atdata APIs that accept packable types (e.g., ``publish_schema``,
+ lens transformations, etc.).
+
  Args:
  cls: The class to convert. Should have type annotations for its fields.

  Returns:
  A new dataclass that inherits from ``PackableSample`` with the same
- name and annotations as the original class.
-
- Example:
- >>> @packable
- ... class MyData:
- ... name: str
- ... values: NDArray
- ...
- >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
- >>> bytes_data = sample.packed
- >>> restored = MyData.from_bytes(bytes_data)
+ name and annotations as the original class. The class satisfies the
+ ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
+
+ Examples:
+ This is a test of the functionality::
+
+ @packable
+ class MyData:
+ name: str
+ values: NDArray
+
+ sample = MyData(name="test", values=np.array([1, 2, 3]))
+ bytes_data = sample.packed
+ restored = MyData.from_bytes(bytes_data)
+
+ # Works with Packable-typed APIs
+ index.publish_schema(MyData, version="1.0.0") # Type-safe
  """

  ##
@@ -721,9 +1013,32 @@ def packable( cls ):
  def __post_init__( self ):
  return PackableSample.__post_init__( self )

- # TODO This doesn't properly carry over the original
+ # Restore original class identity for better repr/debugging
  as_packable.__name__ = class_name
+ as_packable.__qualname__ = class_name
+ as_packable.__module__ = cls.__module__
  as_packable.__annotations__ = class_annotations
+ if cls.__doc__:
+ as_packable.__doc__ = cls.__doc__
+
+ # Fix qualnames of dataclass-generated methods so they don't show
+ # 'packable.<locals>.as_packable' in help() and IDE hints
+ old_qualname_prefix = 'packable.<locals>.as_packable'
+ for attr_name in ('__init__', '__repr__', '__eq__', '__post_init__'):
+ attr = getattr(as_packable, attr_name, None)
+ if attr is not None and hasattr(attr, '__qualname__'):
+ if attr.__qualname__.startswith(old_qualname_prefix):
+ attr.__qualname__ = attr.__qualname__.replace(
+ old_qualname_prefix, class_name, 1
+ )
+
+ # Auto-register lens from DictSample to this type
+ # This enables ds.as_type(MyType) when ds is Dataset[DictSample]
+ def _dict_to_typed(ds: DictSample) -> as_packable:
+ return as_packable.from_data(ds._data)
+
+ _dict_lens = Lens(_dict_to_typed)
+ LensNetwork().register(_dict_lens)

  ##
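
A closing sketch of the lens auto-registered by @packable in the hunk above, which lets an untyped Dataset[DictSample] be re-viewed as the typed class; import paths and field names are illustrative:

    # Sketch only.
    import numpy as np
    from numpy.typing import NDArray
    from atdata.dataset import Dataset, DictSample, packable

    @packable
    class ImageSample:
        image: NDArray
        label: str

    untyped = Dataset[DictSample]("data-000000.tar")   # raw dict-style samples
    typed = untyped.as_type(ImageSample)               # uses the DictSample lens

    for sample in typed.ordered():
        assert isinstance(sample.image, np.ndarray)    # bytes decoded on access
        break
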