atdata: 0.2.2b1-py3-none-any.whl → 0.3.0b1-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (56)
  1. atdata/.gitignore +1 -0
  2. atdata/__init__.py +31 -1
  3. atdata/_cid.py +29 -35
  4. atdata/_exceptions.py +168 -0
  5. atdata/_helpers.py +33 -17
  6. atdata/_hf_api.py +109 -59
  7. atdata/_logging.py +70 -0
  8. atdata/_protocols.py +74 -132
  9. atdata/_schema_codec.py +38 -41
  10. atdata/_sources.py +57 -64
  11. atdata/_stub_manager.py +31 -26
  12. atdata/_type_utils.py +47 -7
  13. atdata/atmosphere/__init__.py +31 -24
  14. atdata/atmosphere/_types.py +11 -11
  15. atdata/atmosphere/client.py +11 -8
  16. atdata/atmosphere/lens.py +27 -30
  17. atdata/atmosphere/records.py +34 -39
  18. atdata/atmosphere/schema.py +35 -31
  19. atdata/atmosphere/store.py +16 -20
  20. atdata/cli/__init__.py +163 -168
  21. atdata/cli/diagnose.py +12 -8
  22. atdata/cli/inspect.py +69 -0
  23. atdata/cli/local.py +5 -2
  24. atdata/cli/preview.py +63 -0
  25. atdata/cli/schema.py +109 -0
  26. atdata/dataset.py +678 -533
  27. atdata/lens.py +85 -83
  28. atdata/local/__init__.py +71 -0
  29. atdata/local/_entry.py +157 -0
  30. atdata/local/_index.py +940 -0
  31. atdata/local/_repo_legacy.py +218 -0
  32. atdata/local/_s3.py +349 -0
  33. atdata/local/_schema.py +380 -0
  34. atdata/manifest/__init__.py +28 -0
  35. atdata/manifest/_aggregates.py +156 -0
  36. atdata/manifest/_builder.py +163 -0
  37. atdata/manifest/_fields.py +154 -0
  38. atdata/manifest/_manifest.py +146 -0
  39. atdata/manifest/_query.py +150 -0
  40. atdata/manifest/_writer.py +74 -0
  41. atdata/promote.py +20 -24
  42. atdata/providers/__init__.py +25 -0
  43. atdata/providers/_base.py +140 -0
  44. atdata/providers/_factory.py +69 -0
  45. atdata/providers/_postgres.py +214 -0
  46. atdata/providers/_redis.py +171 -0
  47. atdata/providers/_sqlite.py +191 -0
  48. atdata/repository.py +323 -0
  49. atdata/testing.py +337 -0
  50. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/METADATA +5 -1
  51. atdata-0.3.0b1.dist-info/RECORD +54 -0
  52. atdata/local.py +0 -1707
  53. atdata-0.2.2b1.dist-info/RECORD +0 -28
  54. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/WHEEL +0 -0
  55. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/entry_points.txt +0 -0
  56. {atdata-0.2.2b1.dist-info → atdata-0.3.0b1.dist-info}/licenses/LICENSE +0 -0
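The bulk of the user-facing change is in atdata/dataset.py, whose diff follows. As orientation, here is a minimal usage sketch of the convenience methods added in 0.3.0 (filter, map, head, to_pandas), using only names that appear in the hunks below; the top-level import path and the example field names are assumptions, not facts taken from the diff.

    # Sketch only: the import path and the ImageSample fields are assumed.
    from atdata import Dataset, packable
    from numpy.typing import NDArray

    @packable
    class ImageSample:
        image: NDArray
        label: str

    ds = Dataset[ImageSample]("data-{000000..000009}.tar")

    # filter()/map() are lazy and chainable; nothing is read until iteration.
    cat_images = ds.filter(lambda s: s.label == "cat").map(lambda s: s.image)
    for image in cat_images:
        ...

    first_five = ds.head(5)        # first 5 samples in shard order
    df = ds.to_pandas(limit=100)   # small slice as a pandas DataFrame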
atdata/dataset.py CHANGED
@@ -13,18 +13,16 @@ The implementation handles automatic conversion between numpy arrays and bytes
13
13
  during serialization, enabling efficient storage of numerical data in WebDataset
14
14
  archives.
15
15
 
16
- Example:
17
- ::
18
-
19
- >>> @packable
20
- ... class ImageSample:
21
- ... image: NDArray
22
- ... label: str
23
- ...
24
- >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
25
- >>> for batch in ds.shuffled(batch_size=32):
26
- ... images = batch.image # Stacked numpy array (32, H, W, C)
27
- ... labels = batch.label # List of 32 strings
16
+ Examples:
17
+ >>> @packable
18
+ ... class ImageSample:
19
+ ... image: NDArray
20
+ ... label: str
21
+ ...
22
+ >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
23
+ >>> for batch in ds.shuffled(batch_size=32):
24
+ ... images = batch.image # Stacked numpy array (32, H, W, C)
25
+ ... labels = batch.label # List of 32 strings
28
26
  """
29
27
 
30
28
  ##
@@ -33,6 +31,7 @@ Example:
33
31
  import webdataset as wds
34
32
 
35
33
  from pathlib import Path
34
+ import itertools
36
35
  import uuid
37
36
 
38
37
  import dataclasses
@@ -43,16 +42,17 @@ from dataclasses import (
43
42
  )
44
43
  from abc import ABC
45
44
 
46
- from ._sources import URLSource, S3Source
47
- from ._protocols import DataSource
45
+ from ._sources import URLSource
46
+ from ._protocols import DataSource, Packable
47
+ from ._exceptions import SampleKeyError, PartialFailureError
48
48
 
49
- from tqdm import tqdm
50
49
  import numpy as np
51
50
  import pandas as pd
52
51
  import requests
53
52
 
54
53
  import typing
55
54
  from typing import (
55
+ TYPE_CHECKING,
56
56
  Any,
57
57
  Optional,
58
58
  Dict,
@@ -66,7 +66,11 @@ from typing import (
66
66
  TypeVar,
67
67
  TypeAlias,
68
68
  dataclass_transform,
69
+ overload,
69
70
  )
71
+
72
+ if TYPE_CHECKING:
73
+ from .manifest._query import SampleLocation
70
74
  from numpy.typing import NDArray
71
75
 
72
76
  import msgpack
@@ -85,30 +89,31 @@ WDSRawSample: TypeAlias = Dict[str, Any]
85
89
  WDSRawBatch: TypeAlias = Dict[str, Any]
86
90
 
87
91
  SampleExportRow: TypeAlias = Dict[str, Any]
88
- SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
92
+ SampleExportMap: TypeAlias = Callable[["PackableSample"], SampleExportRow]
89
93
 
90
94
 
91
95
  ##
92
96
  # Main base classes
93
97
 
94
- DT = TypeVar( 'DT' )
98
+ DT = TypeVar("DT")
95
99
 
96
100
 
97
- def _make_packable( x ):
101
+ def _make_packable(x):
98
102
  """Convert numpy arrays to bytes; pass through other values unchanged."""
99
- if isinstance( x, np.ndarray ):
100
- return eh.array_to_bytes( x )
103
+ if isinstance(x, np.ndarray):
104
+ return eh.array_to_bytes(x)
101
105
  return x
102
106
 
103
107
 
104
- def _is_possibly_ndarray_type( t ):
108
+ def _is_possibly_ndarray_type(t):
105
109
  """Return True if type annotation is NDArray or Optional[NDArray]."""
106
110
  if t == NDArray:
107
111
  return True
108
- if isinstance( t, types.UnionType ):
109
- return any( x == NDArray for x in t.__args__ )
112
+ if isinstance(t, types.UnionType):
113
+ return any(x == NDArray for x in t.__args__)
110
114
  return False
111
115
 
116
+
112
117
  class DictSample:
113
118
  """Dynamic sample type providing dict-like access to raw msgpack data.
114
119
 
@@ -126,24 +131,22 @@ class DictSample:
126
131
  ``@packable``-decorated class. Every ``@packable`` class automatically
127
132
  registers a lens from ``DictSample``, making this conversion seamless.
128
133
 
129
- Example:
130
- ::
131
-
132
- >>> ds = load_dataset("path/to/data.tar") # Returns Dataset[DictSample]
133
- >>> for sample in ds.ordered():
134
- ... print(sample.some_field) # Attribute access
135
- ... print(sample["other_field"]) # Dict access
136
- ... print(sample.keys()) # Inspect available fields
137
- ...
138
- >>> # Convert to typed schema
139
- >>> typed_ds = ds.as_type(MyTypedSample)
134
+ Examples:
135
+ >>> ds = load_dataset("path/to/data.tar") # Returns Dataset[DictSample]
136
+ >>> for sample in ds.ordered():
137
+ ... print(sample.some_field) # Attribute access
138
+ ... print(sample["other_field"]) # Dict access
139
+ ... print(sample.keys()) # Inspect available fields
140
+ ...
141
+ >>> # Convert to typed schema
142
+ >>> typed_ds = ds.as_type(MyTypedSample)
140
143
 
141
144
  Note:
142
145
  NDArray fields are stored as raw bytes in DictSample. They are only
143
146
  converted to numpy arrays when accessed through a typed sample class.
144
147
  """
145
148
 
146
- __slots__ = ('_data',)
149
+ __slots__ = ("_data",)
147
150
 
148
151
  def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None:
149
152
  """Create a DictSample from a dictionary or keyword arguments.
@@ -153,48 +156,28 @@ class DictSample:
153
156
  **kwargs: Field values if _data is not provided.
154
157
  """
155
158
  if _data is not None:
156
- object.__setattr__(self, '_data', _data)
159
+ object.__setattr__(self, "_data", _data)
157
160
  else:
158
- object.__setattr__(self, '_data', kwargs)
161
+ object.__setattr__(self, "_data", kwargs)
159
162
 
160
163
  @classmethod
161
- def from_data(cls, data: dict[str, Any]) -> 'DictSample':
162
- """Create a DictSample from unpacked msgpack data.
163
-
164
- Args:
165
- data: Dictionary with field names as keys.
166
-
167
- Returns:
168
- New DictSample instance wrapping the data.
169
- """
164
+ def from_data(cls, data: dict[str, Any]) -> "DictSample":
165
+ """Create a DictSample from unpacked msgpack data."""
170
166
  return cls(_data=data)
171
167
 
172
168
  @classmethod
173
- def from_bytes(cls, bs: bytes) -> 'DictSample':
174
- """Create a DictSample from raw msgpack bytes.
175
-
176
- Args:
177
- bs: Raw bytes from a msgpack-serialized sample.
178
-
179
- Returns:
180
- New DictSample instance with the unpacked data.
181
- """
169
+ def from_bytes(cls, bs: bytes) -> "DictSample":
170
+ """Create a DictSample from raw msgpack bytes."""
182
171
  return cls.from_data(ormsgpack.unpackb(bs))
183
172
 
184
173
  def __getattr__(self, name: str) -> Any:
185
174
  """Access a field by attribute name.
186
175
 
187
- Args:
188
- name: Field name to access.
189
-
190
- Returns:
191
- The field value.
192
-
193
176
  Raises:
194
177
  AttributeError: If the field doesn't exist.
195
178
  """
196
179
  # Avoid infinite recursion for _data lookup
197
- if name == '_data':
180
+ if name == "_data":
198
181
  raise AttributeError(name)
199
182
  try:
200
183
  return self._data[name]
@@ -205,21 +188,9 @@ class DictSample:
205
188
  ) from None
206
189
 
207
190
  def __getitem__(self, key: str) -> Any:
208
- """Access a field by dict key.
209
-
210
- Args:
211
- key: Field name to access.
212
-
213
- Returns:
214
- The field value.
215
-
216
- Raises:
217
- KeyError: If the field doesn't exist.
218
- """
219
191
  return self._data[key]
220
192
 
221
193
  def __contains__(self, key: str) -> bool:
222
- """Check if a field exists."""
223
194
  return key in self._data
224
195
 
225
196
  def keys(self) -> list[str]:
@@ -227,23 +198,13 @@ class DictSample:
227
198
  return list(self._data.keys())
228
199
 
229
200
  def values(self) -> list[Any]:
230
- """Return list of field values."""
231
201
  return list(self._data.values())
232
202
 
233
203
  def items(self) -> list[tuple[str, Any]]:
234
- """Return list of (field_name, value) tuples."""
235
204
  return list(self._data.items())
236
205
 
237
206
  def get(self, key: str, default: Any = None) -> Any:
238
- """Get a field value with optional default.
239
-
240
- Args:
241
- key: Field name to access.
242
- default: Value to return if field doesn't exist.
243
-
244
- Returns:
245
- The field value or default.
246
- """
207
+ """Get a field value, returning *default* if missing."""
247
208
  return self._data.get(key, default)
248
209
 
249
210
  def to_dict(self) -> dict[str, Any]:
@@ -252,32 +213,24 @@ class DictSample:
252
213
 
253
214
  @property
254
215
  def packed(self) -> bytes:
255
- """Pack this sample's data into msgpack bytes.
256
-
257
- Returns:
258
- Raw msgpack bytes representing this sample's data.
259
- """
216
+ """Serialize to msgpack bytes."""
260
217
  return msgpack.packb(self._data)
261
218
 
262
219
  @property
263
- def as_wds(self) -> 'WDSRawSample':
264
- """Pack this sample's data for writing to WebDataset.
265
-
266
- Returns:
267
- A dictionary with ``__key__`` and ``msgpack`` fields.
268
- """
220
+ def as_wds(self) -> "WDSRawSample":
221
+ """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
269
222
  return {
270
- '__key__': str(uuid.uuid1(0, 0)),
271
- 'msgpack': self.packed,
223
+ "__key__": str(uuid.uuid1(0, 0)),
224
+ "msgpack": self.packed,
272
225
  }
273
226
 
274
227
  def __repr__(self) -> str:
275
- fields = ', '.join(f'{k}=...' for k in self._data.keys())
276
- return f'DictSample({fields})'
228
+ fields = ", ".join(f"{k}=..." for k in self._data.keys())
229
+ return f"DictSample({fields})"
277
230
 
278
231
 
279
232
  @dataclass
280
- class PackableSample( ABC ):
233
+ class PackableSample(ABC):
281
234
  """Base class for samples that can be serialized with msgpack.
282
235
 
283
236
  This abstract base class provides automatic serialization/deserialization
@@ -289,218 +242,122 @@ class PackableSample( ABC ):
289
242
  1. Direct inheritance with the ``@dataclass`` decorator
290
243
  2. Using the ``@packable`` decorator (recommended)
291
244
 
292
- Example:
293
- ::
294
-
295
- >>> @packable
296
- ... class MyData:
297
- ... name: str
298
- ... embeddings: NDArray
299
- ...
300
- >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
301
- >>> packed = sample.packed # Serialize to bytes
302
- >>> restored = MyData.from_bytes(packed) # Deserialize
245
+ Examples:
246
+ >>> @packable
247
+ ... class MyData:
248
+ ... name: str
249
+ ... embeddings: NDArray
250
+ ...
251
+ >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
252
+ >>> packed = sample.packed # Serialize to bytes
253
+ >>> restored = MyData.from_bytes(packed) # Deserialize
303
254
  """
304
255
 
305
- def _ensure_good( self ):
256
+ def _ensure_good(self):
306
257
  """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
307
-
308
- # Auto-convert known types when annotated
309
- # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
310
- for field in dataclasses.fields( self ):
311
- var_name = field.name
312
- var_type = field.type
313
-
314
- # Annotation for this variable is to be an NDArray
315
- if _is_possibly_ndarray_type( var_type ):
316
- # ... so, we'll always auto-convert to numpy
317
-
318
- var_cur_value = getattr( self, var_name )
319
-
320
- # Execute the appropriate conversion for intermediate data
321
- # based on what is provided
322
-
323
- if isinstance( var_cur_value, np.ndarray ):
324
- # Already the correct type, no conversion needed
258
+ for field in dataclasses.fields(self):
259
+ if _is_possibly_ndarray_type(field.type):
260
+ value = getattr(self, field.name)
261
+ if isinstance(value, np.ndarray):
325
262
  continue
263
+ elif isinstance(value, bytes):
264
+ setattr(self, field.name, eh.bytes_to_array(value))
326
265
 
327
- elif isinstance( var_cur_value, bytes ):
328
- # Design note: bytes in NDArray-typed fields are always interpreted
329
- # as serialized arrays. This means raw bytes fields must not be
330
- # annotated as NDArray.
331
- setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
332
-
333
- def __post_init__( self ):
266
+ def __post_init__(self):
334
267
  self._ensure_good()
335
268
 
336
269
  ##
337
270
 
338
271
  @classmethod
339
- def from_data( cls, data: WDSRawSample ) -> Self:
340
- """Create a sample instance from unpacked msgpack data.
272
+ def from_data(cls, data: WDSRawSample) -> Self:
273
+ """Create an instance from unpacked msgpack data."""
274
+ return cls(**data)
341
275
 
342
- Args:
343
- data: Dictionary with keys matching the sample's field names.
344
-
345
- Returns:
346
- New instance with NDArray fields auto-converted from bytes.
347
- """
348
- return cls( **data )
349
-
350
276
  @classmethod
351
- def from_bytes( cls, bs: bytes ) -> Self:
352
- """Create a sample instance from raw msgpack bytes.
353
-
354
- Args:
355
- bs: Raw bytes from a msgpack-serialized sample.
356
-
357
- Returns:
358
- A new instance of this sample class deserialized from the bytes.
359
- """
360
- return cls.from_data( ormsgpack.unpackb( bs ) )
277
+ def from_bytes(cls, bs: bytes) -> Self:
278
+ """Create an instance from raw msgpack bytes."""
279
+ return cls.from_data(ormsgpack.unpackb(bs))
361
280
 
362
281
  @property
363
- def packed( self ) -> bytes:
364
- """Pack this sample's data into msgpack bytes.
365
-
366
- NDArray fields are automatically converted to bytes before packing.
367
- All other fields are packed as-is if they're msgpack-compatible.
368
-
369
- Returns:
370
- Raw msgpack bytes representing this sample's data.
282
+ def packed(self) -> bytes:
283
+ """Serialize to msgpack bytes. NDArray fields are auto-converted.
371
284
 
372
285
  Raises:
373
286
  RuntimeError: If msgpack serialization fails.
374
287
  """
375
-
376
- # Make sure that all of our (possibly unpackable) data is in a packable
377
- # format
378
- o = {
379
- k: _make_packable( v )
380
- for k, v in vars( self ).items()
381
- }
382
-
383
- ret = msgpack.packb( o )
384
-
288
+ o = {k: _make_packable(v) for k, v in vars(self).items()}
289
+ ret = msgpack.packb(o)
385
290
  if ret is None:
386
- raise RuntimeError( f'Failed to pack sample to bytes: {o}' )
387
-
291
+ raise RuntimeError(f"Failed to pack sample to bytes: {o}")
388
292
  return ret
389
-
390
- @property
391
- def as_wds( self ) -> WDSRawSample:
392
- """Pack this sample's data for writing to WebDataset.
393
-
394
- Returns:
395
- A dictionary with ``__key__`` (UUID v1 for sortable keys) and
396
- ``msgpack`` (packed sample data) fields suitable for WebDataset.
397
293
 
398
- Note:
399
- Keys are auto-generated as UUID v1 for time-sortable ordering.
400
- Custom key specification is not currently supported.
401
- """
294
+ @property
295
+ def as_wds(self) -> WDSRawSample:
296
+ """Serialize for writing to WebDataset (``__key__`` + ``msgpack``)."""
402
297
  return {
403
- # Generates a UUID that is timelike-sortable
404
- '__key__': str( uuid.uuid1( 0, 0 ) ),
405
- 'msgpack': self.packed,
298
+ "__key__": str(uuid.uuid1(0, 0)),
299
+ "msgpack": self.packed,
406
300
  }
407
301
 
408
- def _batch_aggregate( xs: Sequence ):
302
+
303
+ def _batch_aggregate(xs: Sequence):
409
304
  """Stack arrays into numpy array with batch dim; otherwise return list."""
410
305
  if not xs:
411
306
  return []
412
- if isinstance( xs[0], np.ndarray ):
413
- return np.array( list( xs ) )
414
- return list( xs )
307
+ if isinstance(xs[0], np.ndarray):
308
+ return np.array(list(xs))
309
+ return list(xs)
415
310
 
416
- class SampleBatch( Generic[DT] ):
417
- """A batch of samples with automatic attribute aggregation.
418
311
 
419
- This class wraps a sequence of samples and provides magic ``__getattr__``
420
- access to aggregate sample attributes. When you access an attribute that
421
- exists on the sample type, it automatically aggregates values across all
422
- samples in the batch.
312
+ class SampleBatch(Generic[DT]):
313
+ """A batch of samples with automatic attribute aggregation.
423
314
 
424
- NDArray fields are stacked into a numpy array with a batch dimension.
425
- Other fields are aggregated into a list.
315
+ Accessing an attribute aggregates that field across all samples:
316
+ NDArray fields are stacked into a numpy array with a batch dimension;
317
+ other fields are collected into a list. Results are cached.
426
318
 
427
319
  Parameters:
428
320
  DT: The sample type, must derive from ``PackableSample``.
429
321
 
430
- Attributes:
431
- samples: The list of sample instances in this batch.
432
-
433
- Example:
434
- ::
435
-
436
- >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
437
- >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
438
- >>> batch.names # Returns list of names
439
-
440
- Note:
441
- This class uses Python's ``__orig_class__`` mechanism to extract the
442
- type parameter at runtime. Instances must be created using the
443
- subscripted syntax ``SampleBatch[MyType](samples)`` rather than
444
- calling the constructor directly with an unsubscripted class.
322
+ Examples:
323
+ >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
324
+ >>> batch.embeddings # Stacked numpy array of shape (3, ...)
325
+ >>> batch.names # List of names
445
326
  """
446
- # Design note: The docstring uses "Parameters:" for type parameters because
447
- # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
448
-
449
- def __init__( self, samples: Sequence[DT] ):
450
- """Create a batch from a sequence of samples.
451
327
 
452
- Args:
453
- samples: A sequence of sample instances to aggregate into a batch.
454
- Each sample must be an instance of a type derived from
455
- ``PackableSample``.
456
- """
457
- self.samples = list( samples )
328
+ def __init__(self, samples: Sequence[DT]):
329
+ """Create a batch from a sequence of samples."""
330
+ self.samples = list(samples)
458
331
  self._aggregate_cache = dict()
459
332
  self._sample_type_cache: Type | None = None
460
333
 
461
334
  @property
462
- def sample_type( self ) -> Type:
463
- """The type of each sample in this batch.
464
-
465
- Returns:
466
- The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
467
- """
335
+ def sample_type(self) -> Type:
336
+ """The type parameter ``DT`` used when creating this batch."""
468
337
  if self._sample_type_cache is None:
469
- self._sample_type_cache = typing.get_args( self.__orig_class__)[0]
470
- assert self._sample_type_cache is not None
338
+ self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
339
+ if self._sample_type_cache is None:
340
+ raise TypeError(
341
+ "SampleBatch requires a type parameter, e.g. SampleBatch[MySample]"
342
+ )
471
343
  return self._sample_type_cache
472
344
 
473
- def __getattr__( self, name ):
474
- """Aggregate an attribute across all samples in the batch.
475
-
476
- This magic method enables attribute-style access to aggregated sample
477
- fields. Results are cached for efficiency.
478
-
479
- Args:
480
- name: The attribute name to aggregate across samples.
481
-
482
- Returns:
483
- For NDArray fields: a stacked numpy array with batch dimension.
484
- For other fields: a list of values from each sample.
485
-
486
- Raises:
487
- AttributeError: If the attribute doesn't exist on the sample type.
488
- """
345
+ def __getattr__(self, name):
346
+ """Aggregate a field across all samples (cached)."""
489
347
  # Aggregate named params of sample type
490
- if name in vars( self.sample_type )['__annotations__']:
348
+ if name in vars(self.sample_type)["__annotations__"]:
491
349
  if name not in self._aggregate_cache:
492
350
  self._aggregate_cache[name] = _batch_aggregate(
493
- [ getattr( x, name )
494
- for x in self.samples ]
351
+ [getattr(x, name) for x in self.samples]
495
352
  )
496
353
 
497
354
  return self._aggregate_cache[name]
498
355
 
499
- raise AttributeError( f'No sample attribute named {name}' )
356
+ raise AttributeError(f"No sample attribute named {name}")
500
357
 
501
358
 
502
- ST = TypeVar( 'ST', bound = PackableSample )
503
- RT = TypeVar( 'RT', bound = PackableSample )
359
+ ST = TypeVar("ST", bound=Packable)
360
+ RT = TypeVar("RT", bound=Packable)
504
361
 
505
362
 
506
363
  class _ShardListStage(wds.utils.PipelineStage):
@@ -538,7 +395,7 @@ class _StreamOpenerStage(wds.utils.PipelineStage):
538
395
  yield sample
539
396
 
540
397
 
541
- class Dataset( Generic[ST] ):
398
+ class Dataset(Generic[ST]):
542
399
  """A typed dataset built on WebDataset with lens transformations.
543
400
 
544
401
  This class wraps WebDataset tar archives and provides type-safe iteration
@@ -557,16 +414,14 @@ class Dataset( Generic[ST] ):
557
414
  Attributes:
558
415
  url: WebDataset brace-notation URL for the tar file(s).
559
416
 
560
- Example:
561
- ::
562
-
563
- >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
564
- >>> for sample in ds.ordered(batch_size=32):
565
- ... # sample is SampleBatch[MyData] with batch_size samples
566
- ... embeddings = sample.embeddings # shape: (32, ...)
567
- ...
568
- >>> # Transform to a different view
569
- >>> ds_view = ds.as_type(MyDataView)
417
+ Examples:
418
+ >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
419
+ >>> for sample in ds.ordered(batch_size=32):
420
+ ... # sample is SampleBatch[MyData] with batch_size samples
421
+ ... embeddings = sample.embeddings # shape: (32, ...)
422
+ ...
423
+ >>> # Transform to a different view
424
+ >>> ds_view = ds.as_type(MyDataView)
570
425
 
571
426
  Note:
572
427
  This class uses Python's ``__orig_class__`` mechanism to extract the
@@ -574,35 +429,33 @@ class Dataset( Generic[ST] ):
574
429
  subscripted syntax ``Dataset[MyType](url)`` rather than calling the
575
430
  constructor directly with an unsubscripted class.
576
431
  """
432
+
577
433
  # Design note: The docstring uses "Parameters:" for type parameters because
578
434
  # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
579
435
 
580
436
  @property
581
- def sample_type( self ) -> Type:
582
- """The type of each returned sample from this dataset's iterator.
583
-
584
- Returns:
585
- The type parameter ``ST`` used when creating this ``Dataset[ST]``.
586
- """
437
+ def sample_type(self) -> Type:
438
+ """The type parameter ``ST`` used when creating this dataset."""
587
439
  if self._sample_type_cache is None:
588
- self._sample_type_cache = typing.get_args( self.__orig_class__ )[0]
589
- assert self._sample_type_cache is not None
440
+ self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
441
+ if self._sample_type_cache is None:
442
+ raise TypeError(
443
+ "Dataset requires a type parameter, e.g. Dataset[MySample]"
444
+ )
590
445
  return self._sample_type_cache
591
- @property
592
- def batch_type( self ) -> Type:
593
- """The type of batches produced by this dataset.
594
446
 
595
- Returns:
596
- ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
597
- """
447
+ @property
448
+ def batch_type(self) -> Type:
449
+ """``SampleBatch[ST]`` where ``ST`` is this dataset's sample type."""
598
450
  return SampleBatch[self.sample_type]
599
451
 
600
- def __init__( self,
601
- source: DataSource | str | None = None,
602
- metadata_url: str | None = None,
603
- *,
604
- url: str | None = None,
605
- ) -> None:
452
+ def __init__(
453
+ self,
454
+ source: DataSource | str | None = None,
455
+ metadata_url: str | None = None,
456
+ *,
457
+ url: str | None = None,
458
+ ) -> None:
606
459
  """Create a dataset from a DataSource or URL.
607
460
 
608
461
  Args:
@@ -620,28 +473,21 @@ class Dataset( Generic[ST] ):
620
473
  """
621
474
  super().__init__()
622
475
 
623
- # Handle backward compatibility: url= keyword argument
624
476
  if source is None and url is not None:
625
477
  source = url
626
478
  elif source is None:
627
479
  raise TypeError("Dataset() missing required argument: 'source' or 'url'")
628
480
 
629
- # Normalize source: strings become URLSource for backward compatibility
630
481
  if isinstance(source, str):
631
482
  self._source: DataSource = URLSource(source)
632
483
  self.url = source
633
484
  else:
634
485
  self._source = source
635
- # For compatibility, expose URL if source has list_shards
636
486
  shards = source.list_shards()
637
- # Design note: Using first shard as url for legacy compatibility.
638
- # Full shard list is available via list_shards() method.
639
487
  self.url = shards[0] if shards else ""
640
488
 
641
489
  self._metadata: dict[str, Any] | None = None
642
490
  self.metadata_url: str | None = metadata_url
643
- """Optional URL to msgpack-encoded metadata for this dataset."""
644
-
645
491
  self._output_lens: Lens | None = None
646
492
  self._sample_type_cache: Type | None = None
647
493
 
@@ -650,50 +496,24 @@ class Dataset( Generic[ST] ):
650
496
  """The underlying data source for this dataset."""
651
497
  return self._source
652
498
 
653
- def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
654
- """View this dataset through a different sample type using a registered lens.
655
-
656
- Args:
657
- other: The target sample type to transform into. Must be a type
658
- derived from ``PackableSample``.
659
-
660
- Returns:
661
- A new ``Dataset`` instance that yields samples of type ``other``
662
- by applying the appropriate lens transformation from the global
663
- ``LensNetwork`` registry.
499
+ def as_type(self, other: Type[RT]) -> "Dataset[RT]":
500
+ """View this dataset through a different sample type via a registered lens.
664
501
 
665
502
  Raises:
666
- ValueError: If no registered lens exists between the current
667
- sample type and the target type.
503
+ ValueError: If no lens exists between the current and target types.
668
504
  """
669
- ret = Dataset[other]( self._source )
670
- # Get the singleton lens registry
505
+ ret = Dataset[other](self._source)
671
506
  lenses = LensNetwork()
672
- ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
507
+ ret._output_lens = lenses.transform(self.sample_type, ret.sample_type)
673
508
  return ret
674
509
 
675
510
  @property
676
511
  def shards(self) -> Iterator[str]:
677
- """Lazily iterate over shard identifiers.
678
-
679
- Yields:
680
- Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
681
-
682
- Example:
683
- ::
684
-
685
- >>> for shard in ds.shards:
686
- ... print(f"Processing {shard}")
687
- """
512
+ """Lazily iterate over shard identifiers."""
688
513
  return iter(self._source.list_shards())
689
514
 
690
515
  def list_shards(self) -> list[str]:
691
- """Get list of individual dataset shards.
692
-
693
- Returns:
694
- A full (non-lazy) list of the individual ``tar`` files within the
695
- source WebDataset.
696
- """
516
+ """Return all shard paths/URLs as a list."""
697
517
  return self._source.list_shards()
698
518
 
699
519
  # Legacy alias for backwards compatibility
@@ -705,6 +525,7 @@ class Dataset( Generic[ST] ):
705
525
  Use :meth:`list_shards` instead.
706
526
  """
707
527
  import warnings
528
+
708
529
  warnings.warn(
709
530
  "shard_list is deprecated, use list_shards() instead",
710
531
  DeprecationWarning,
@@ -713,40 +534,414 @@ class Dataset( Generic[ST] ):
713
534
  return self.list_shards()
714
535
 
715
536
  @property
716
- def metadata( self ) -> dict[str, Any] | None:
717
- """Fetch and cache metadata from metadata_url.
718
-
719
- Returns:
720
- Deserialized metadata dictionary, or None if no metadata_url is set.
721
-
722
- Raises:
723
- requests.HTTPError: If metadata fetch fails.
724
- """
537
+ def metadata(self) -> dict[str, Any] | None:
538
+ """Fetch and cache metadata from metadata_url, or ``None`` if unset."""
725
539
  if self.metadata_url is None:
726
540
  return None
727
541
 
728
542
  if self._metadata is None:
729
- with requests.get( self.metadata_url, stream = True ) as response:
543
+ with requests.get(self.metadata_url, stream=True) as response:
730
544
  response.raise_for_status()
731
- self._metadata = msgpack.unpackb( response.content, raw = False )
732
-
545
+ self._metadata = msgpack.unpackb(response.content, raw=False)
546
+
733
547
  # Use our cached values
734
548
  return self._metadata
735
-
736
- def ordered( self,
737
- batch_size: int | None = None,
738
- ) -> Iterable[ST]:
739
- """Iterate over the dataset in order
549
+
550
+ ##
551
+ # Convenience methods (GH#38 developer experience)
552
+
553
+ @property
554
+ def schema(self) -> dict[str, type]:
555
+ """Field names and types for this dataset's sample type.
556
+
557
+ Examples:
558
+ >>> ds = Dataset[MyData]("data.tar")
559
+ >>> ds.schema
560
+ {'name': <class 'str'>, 'embedding': numpy.ndarray}
561
+ """
562
+ st = self.sample_type
563
+ if st is DictSample:
564
+ return {"_data": dict}
565
+ if dataclasses.is_dataclass(st):
566
+ return {f.name: f.type for f in dataclasses.fields(st)}
567
+ return {}
568
+
569
+ @property
570
+ def column_names(self) -> list[str]:
571
+ """List of field names for this dataset's sample type."""
572
+ st = self.sample_type
573
+ if dataclasses.is_dataclass(st):
574
+ return [f.name for f in dataclasses.fields(st)]
575
+ return []
576
+
577
+ def __iter__(self) -> Iterator[ST]:
578
+ """Shorthand for ``ds.ordered()``."""
579
+ return iter(self.ordered())
580
+
581
+ def __len__(self) -> int:
582
+ """Total sample count (iterates all shards on first call, then cached)."""
583
+ if not hasattr(self, "_len_cache"):
584
+ self._len_cache: int = sum(1 for _ in self.ordered())
585
+ return self._len_cache
586
+
587
+ def head(self, n: int = 5) -> list[ST]:
588
+ """Return the first *n* samples from the dataset.
589
+
590
+ Args:
591
+ n: Number of samples to return. Default: 5.
592
+
593
+ Returns:
594
+ List of up to *n* samples in shard order.
595
+
596
+ Examples:
597
+ >>> samples = ds.head(3)
598
+ >>> len(samples)
599
+ 3
600
+ """
601
+ return list(itertools.islice(self.ordered(), n))
602
+
603
+ def get(self, key: str) -> ST:
604
+ """Retrieve a single sample by its ``__key__``.
605
+
606
+ Scans shards sequentially until a sample with a matching key is found.
607
+ This is O(n) for streaming datasets.
740
608
 
741
609
  Args:
742
- batch_size (:obj:`int`, optional): The size of iterated batches.
743
- Default: None (unbatched). If ``None``, iterates over one
744
- sample at a time with no batch dimension.
610
+ key: The WebDataset ``__key__`` string to search for.
745
611
 
746
612
  Returns:
747
- :obj:`webdataset.DataPipeline` A data pipeline that iterates over
748
- the dataset in its original sample order
613
+ The matching sample.
614
+
615
+ Raises:
616
+ SampleKeyError: If no sample with the given key exists.
749
617
 
618
+ Examples:
619
+ >>> sample = ds.get("00000001-0001-1000-8000-010000000000")
620
+ """
621
+ pipeline = wds.pipeline.DataPipeline(
622
+ _ShardListStage(self._source),
623
+ wds.shardlists.split_by_worker,
624
+ _StreamOpenerStage(self._source),
625
+ wds.tariterators.tar_file_expander,
626
+ wds.tariterators.group_by_keys,
627
+ )
628
+ for raw_sample in pipeline:
629
+ if raw_sample.get("__key__") == key:
630
+ return self.wrap(raw_sample)
631
+ raise SampleKeyError(key)
632
+
633
+ def describe(self) -> dict[str, Any]:
634
+ """Summary statistics: sample_type, fields, num_shards, shards, url, metadata."""
635
+ shards = self.list_shards()
636
+ return {
637
+ "sample_type": self.sample_type.__name__,
638
+ "fields": self.schema,
639
+ "num_shards": len(shards),
640
+ "shards": shards,
641
+ "url": self.url,
642
+ "metadata": self.metadata,
643
+ }
644
+
645
+ def filter(self, predicate: Callable[[ST], bool]) -> "Dataset[ST]":
646
+ """Return a new dataset that yields only samples matching *predicate*.
647
+
648
+ The filter is applied lazily during iteration — no data is copied.
649
+
650
+ Args:
651
+ predicate: A function that takes a sample and returns ``True``
652
+ to keep it or ``False`` to discard it.
653
+
654
+ Returns:
655
+ A new ``Dataset`` whose iterators apply the filter.
656
+
657
+ Examples:
658
+ >>> long_names = ds.filter(lambda s: len(s.name) > 10)
659
+ >>> for sample in long_names:
660
+ ... assert len(sample.name) > 10
661
+ """
662
+ filtered = Dataset[self.sample_type](self._source, self.metadata_url)
663
+ filtered._sample_type_cache = self._sample_type_cache
664
+ filtered._output_lens = self._output_lens
665
+ filtered._filter_fn = predicate
666
+ # Preserve any existing filters
667
+ parent_filters = getattr(self, "_filter_fn", None)
668
+ if parent_filters is not None:
669
+ outer = parent_filters
670
+ filtered._filter_fn = lambda s: outer(s) and predicate(s)
671
+ # Preserve any existing map
672
+ if hasattr(self, "_map_fn"):
673
+ filtered._map_fn = self._map_fn
674
+ return filtered
675
+
676
+ def map(self, fn: Callable[[ST], Any]) -> "Dataset":
677
+ """Return a new dataset that applies *fn* to each sample during iteration.
678
+
679
+ The mapping is applied lazily during iteration — no data is copied.
680
+
681
+ Args:
682
+ fn: A function that takes a sample of type ``ST`` and returns
683
+ a transformed value.
684
+
685
+ Returns:
686
+ A new ``Dataset`` whose iterators apply the mapping.
687
+
688
+ Examples:
689
+ >>> names = ds.map(lambda s: s.name)
690
+ >>> for name in names:
691
+ ... print(name)
692
+ """
693
+ mapped = Dataset[self.sample_type](self._source, self.metadata_url)
694
+ mapped._sample_type_cache = self._sample_type_cache
695
+ mapped._output_lens = self._output_lens
696
+ mapped._map_fn = fn
697
+ # Preserve any existing map
698
+ if hasattr(self, "_map_fn"):
699
+ outer = self._map_fn
700
+ mapped._map_fn = lambda s: fn(outer(s))
701
+ # Preserve any existing filter
702
+ if hasattr(self, "_filter_fn"):
703
+ mapped._filter_fn = self._filter_fn
704
+ return mapped
705
+
706
+ def process_shards(
707
+ self,
708
+ fn: Callable[[list[ST]], Any],
709
+ *,
710
+ shards: list[str] | None = None,
711
+ ) -> dict[str, Any]:
712
+ """Process each shard independently, collecting per-shard results.
713
+
714
+ Unlike :meth:`map` (which is lazy and per-sample), this method eagerly
715
+ processes each shard in turn, calling *fn* with the full list of samples
716
+ from that shard. If some shards fail, raises
717
+ :class:`~atdata._exceptions.PartialFailureError` containing both the
718
+ successful results and the per-shard errors.
719
+
720
+ Args:
721
+ fn: Function receiving a list of samples from one shard and
722
+ returning an arbitrary result.
723
+ shards: Optional list of shard identifiers to process. If ``None``,
724
+ processes all shards in the dataset. Useful for retrying only
725
+ the failed shards from a previous ``PartialFailureError``.
726
+
727
+ Returns:
728
+ Dict mapping shard identifier to *fn*'s return value for each shard.
729
+
730
+ Raises:
731
+ PartialFailureError: If at least one shard fails. The exception
732
+ carries ``.succeeded_shards``, ``.failed_shards``, ``.errors``,
733
+ and ``.results`` for inspection and retry.
734
+
735
+ Examples:
736
+ >>> results = ds.process_shards(lambda samples: len(samples))
737
+ >>> # On partial failure, retry just the failed shards:
738
+ >>> try:
739
+ ... results = ds.process_shards(expensive_fn)
740
+ ... except PartialFailureError as e:
741
+ ... retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
742
+ """
743
+ from ._logging import get_logger
744
+
745
+ log = get_logger()
746
+ shard_ids = shards or self.list_shards()
747
+ log.info("process_shards: starting %d shards", len(shard_ids))
748
+
749
+ succeeded: list[str] = []
750
+ failed: list[str] = []
751
+ errors: dict[str, Exception] = {}
752
+ results: dict[str, Any] = {}
753
+
754
+ for shard_id in shard_ids:
755
+ try:
756
+ shard_ds = Dataset[self.sample_type](shard_id)
757
+ shard_ds._sample_type_cache = self._sample_type_cache
758
+ samples = list(shard_ds.ordered())
759
+ results[shard_id] = fn(samples)
760
+ succeeded.append(shard_id)
761
+ log.debug("process_shards: shard ok %s", shard_id)
762
+ except Exception as exc:
763
+ failed.append(shard_id)
764
+ errors[shard_id] = exc
765
+ log.warning("process_shards: shard failed %s: %s", shard_id, exc)
766
+
767
+ if failed:
768
+ log.error(
769
+ "process_shards: %d/%d shards failed",
770
+ len(failed),
771
+ len(shard_ids),
772
+ )
773
+ raise PartialFailureError(
774
+ succeeded_shards=succeeded,
775
+ failed_shards=failed,
776
+ errors=errors,
777
+ results=results,
778
+ )
779
+
780
+ log.info("process_shards: all %d shards succeeded", len(shard_ids))
781
+ return results
782
+
783
+ def select(self, indices: Sequence[int]) -> list[ST]:
784
+ """Return samples at the given integer indices.
785
+
786
+ Iterates through the dataset in order and collects samples whose
787
+ positional index matches. This is O(n) for streaming datasets.
788
+
789
+ Args:
790
+ indices: Sequence of zero-based indices to select.
791
+
792
+ Returns:
793
+ List of samples at the requested positions, in index order.
794
+
795
+ Examples:
796
+ >>> samples = ds.select([0, 5, 10])
797
+ >>> len(samples)
798
+ 3
799
+ """
800
+ if not indices:
801
+ return []
802
+ target = set(indices)
803
+ max_idx = max(indices)
804
+ result: dict[int, ST] = {}
805
+ for i, sample in enumerate(self.ordered()):
806
+ if i in target:
807
+ result[i] = sample
808
+ if i >= max_idx:
809
+ break
810
+ return [result[i] for i in indices if i in result]
811
+
812
+ def query(
813
+ self,
814
+ where: "Callable[[pd.DataFrame], pd.Series]",
815
+ ) -> "list[SampleLocation]":
816
+ """Query this dataset using per-shard manifest metadata.
817
+
818
+ Requires manifests to have been generated during shard writing.
819
+ Discovers manifest files alongside the tar shards, loads them,
820
+ and executes a two-phase query (shard-level aggregate pruning,
821
+ then sample-level parquet filtering).
822
+
823
+ Args:
824
+ where: Predicate function that receives a pandas DataFrame
825
+ of manifest fields and returns a boolean Series selecting
826
+ matching rows.
827
+
828
+ Returns:
829
+ List of ``SampleLocation`` for matching samples.
830
+
831
+ Raises:
832
+ FileNotFoundError: If no manifest files are found alongside shards.
833
+
834
+ Examples:
835
+ >>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
836
+ >>> len(locs)
837
+ 42
838
+ """
839
+ from .manifest import QueryExecutor
840
+
841
+ shard_urls = self.list_shards()
842
+ executor = QueryExecutor.from_shard_urls(shard_urls)
843
+ return executor.query(where=where)
844
+
845
+ def to_pandas(self, limit: int | None = None) -> "pd.DataFrame":
846
+ """Materialize the dataset (or first *limit* samples) as a DataFrame.
847
+
848
+ Args:
849
+ limit: Maximum number of samples to include. ``None`` means all
850
+ samples (may use significant memory for large datasets).
851
+
852
+ Returns:
853
+ A pandas DataFrame with one row per sample and columns matching
854
+ the sample fields.
855
+
856
+ Warning:
857
+ With ``limit=None`` this loads the entire dataset into memory.
858
+
859
+ Examples:
860
+ >>> df = ds.to_pandas(limit=100)
861
+ >>> df.columns.tolist()
862
+ ['name', 'embedding']
863
+ """
864
+ samples = self.head(limit) if limit is not None else list(self.ordered())
865
+ rows = [
866
+ asdict(s) if dataclasses.is_dataclass(s) else s.to_dict() for s in samples
867
+ ]
868
+ return pd.DataFrame(rows)
869
+
870
+ def to_dict(self, limit: int | None = None) -> dict[str, list[Any]]:
871
+ """Materialize the dataset as a column-oriented dictionary.
872
+
873
+ Args:
874
+ limit: Maximum number of samples to include. ``None`` means all.
875
+
876
+ Returns:
877
+ Dictionary mapping field names to lists of values (one entry
878
+ per sample).
879
+
880
+ Warning:
881
+ With ``limit=None`` this loads the entire dataset into memory.
882
+
883
+ Examples:
884
+ >>> d = ds.to_dict(limit=10)
885
+ >>> d.keys()
886
+ dict_keys(['name', 'embedding'])
887
+ >>> len(d['name'])
888
+ 10
889
+ """
890
+ samples = self.head(limit) if limit is not None else list(self.ordered())
891
+ if not samples:
892
+ return {}
893
+ if dataclasses.is_dataclass(samples[0]):
894
+ fields = [f.name for f in dataclasses.fields(samples[0])]
895
+ return {f: [getattr(s, f) for s in samples] for f in fields}
896
+ # DictSample path
897
+ keys = samples[0].keys()
898
+ return {k: [s[k] for s in samples] for k in keys}
899
+
900
+ def _post_wrap_stages(self) -> list:
901
+ """Build extra pipeline stages for filter/map set via .filter()/.map()."""
902
+ stages: list = []
903
+ filter_fn = getattr(self, "_filter_fn", None)
904
+ if filter_fn is not None:
905
+ stages.append(wds.filters.select(filter_fn))
906
+ map_fn = getattr(self, "_map_fn", None)
907
+ if map_fn is not None:
908
+ stages.append(wds.filters.map(map_fn))
909
+ return stages
910
+
911
+ @overload
912
+ def ordered(
913
+ self,
914
+ batch_size: None = None,
915
+ ) -> Iterable[ST]: ...
916
+
917
+ @overload
918
+ def ordered(
919
+ self,
920
+ batch_size: int,
921
+ ) -> Iterable[SampleBatch[ST]]: ...
922
+
923
+ def ordered(
924
+ self,
925
+ batch_size: int | None = None,
926
+ ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
927
+ """Iterate over the dataset in order.
928
+
929
+ Args:
930
+ batch_size: The size of iterated batches. Default: None (unbatched).
931
+ If ``None``, iterates over one sample at a time with no batch
932
+ dimension.
933
+
934
+ Returns:
935
+ A data pipeline that iterates over the dataset in its original
936
+ sample order. When ``batch_size`` is ``None``, yields individual
937
+ samples of type ``ST``. When ``batch_size`` is an integer, yields
938
+ ``SampleBatch[ST]`` instances containing that many samples.
939
+
940
+ Examples:
941
+ >>> for sample in ds.ordered():
942
+ ... process(sample) # sample is ST
943
+ >>> for batch in ds.ordered(batch_size=32):
944
+ ... process(batch) # batch is SampleBatch[ST]
750
945
  """
751
946
  if batch_size is None:
752
947
  return wds.pipeline.DataPipeline(
@@ -755,7 +950,8 @@ class Dataset( Generic[ST] ):
755
950
  _StreamOpenerStage(self._source),
756
951
  wds.tariterators.tar_file_expander,
757
952
  wds.tariterators.group_by_keys,
758
- wds.filters.map( self.wrap ),
953
+ wds.filters.map(self.wrap),
954
+ *self._post_wrap_stages(),
759
955
  )
760
956
 
761
957
  return wds.pipeline.DataPipeline(
@@ -764,15 +960,33 @@ class Dataset( Generic[ST] ):
764
960
  _StreamOpenerStage(self._source),
765
961
  wds.tariterators.tar_file_expander,
766
962
  wds.tariterators.group_by_keys,
767
- wds.filters.batched( batch_size ),
768
- wds.filters.map( self.wrap_batch ),
963
+ wds.filters.batched(batch_size),
964
+ wds.filters.map(self.wrap_batch),
769
965
  )
770
966
 
771
- def shuffled( self,
772
- buffer_shards: int = 100,
773
- buffer_samples: int = 10_000,
774
- batch_size: int | None = None,
775
- ) -> Iterable[ST]:
967
+ @overload
968
+ def shuffled(
969
+ self,
970
+ buffer_shards: int = 100,
971
+ buffer_samples: int = 10_000,
972
+ batch_size: None = None,
973
+ ) -> Iterable[ST]: ...
974
+
975
+ @overload
976
+ def shuffled(
977
+ self,
978
+ buffer_shards: int = 100,
979
+ buffer_samples: int = 10_000,
980
+ *,
981
+ batch_size: int,
982
+ ) -> Iterable[SampleBatch[ST]]: ...
983
+
984
+ def shuffled(
985
+ self,
986
+ buffer_shards: int = 100,
987
+ buffer_samples: int = 10_000,
988
+ batch_size: int | None = None,
989
+ ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
776
990
  """Iterate over the dataset in random order.
777
991
 
778
992
  Args:
@@ -787,216 +1001,147 @@ class Dataset( Generic[ST] ):
787
1001
  dimension.
788
1002
 
789
1003
  Returns:
790
- A WebDataset data pipeline that iterates over the dataset in
791
- randomized order. If ``batch_size`` is not ``None``, yields
792
- ``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
793
- samples.
1004
+ A data pipeline that iterates over the dataset in randomized order.
1005
+ When ``batch_size`` is ``None``, yields individual samples of type
1006
+ ``ST``. When ``batch_size`` is an integer, yields ``SampleBatch[ST]``
1007
+ instances containing that many samples.
1008
+
1009
+ Examples:
1010
+ >>> for sample in ds.shuffled():
1011
+ ... process(sample) # sample is ST
1012
+ >>> for batch in ds.shuffled(batch_size=32):
1013
+ ... process(batch) # batch is SampleBatch[ST]
794
1014
  """
795
1015
  if batch_size is None:
796
1016
  return wds.pipeline.DataPipeline(
797
1017
  _ShardListStage(self._source),
798
- wds.filters.shuffle( buffer_shards ),
1018
+ wds.filters.shuffle(buffer_shards),
799
1019
  wds.shardlists.split_by_worker,
800
1020
  _StreamOpenerStage(self._source),
801
1021
  wds.tariterators.tar_file_expander,
802
1022
  wds.tariterators.group_by_keys,
803
- wds.filters.shuffle( buffer_samples ),
804
- wds.filters.map( self.wrap ),
1023
+ wds.filters.shuffle(buffer_samples),
1024
+ wds.filters.map(self.wrap),
1025
+ *self._post_wrap_stages(),
805
1026
  )
806
1027
 
807
1028
  return wds.pipeline.DataPipeline(
808
1029
  _ShardListStage(self._source),
809
- wds.filters.shuffle( buffer_shards ),
1030
+ wds.filters.shuffle(buffer_shards),
810
1031
  wds.shardlists.split_by_worker,
811
1032
  _StreamOpenerStage(self._source),
812
1033
  wds.tariterators.tar_file_expander,
813
1034
  wds.tariterators.group_by_keys,
814
- wds.filters.shuffle( buffer_samples ),
815
- wds.filters.batched( batch_size ),
816
- wds.filters.map( self.wrap_batch ),
1035
+ wds.filters.shuffle(buffer_samples),
1036
+ wds.filters.batched(batch_size),
1037
+ wds.filters.map(self.wrap_batch),
817
1038
  )
818
-
1039
+
819
1040
  # Design note: Uses pandas for parquet export. Could be replaced with
820
1041
  # direct fastparquet calls to reduce dependencies if needed.
821
- def to_parquet( self, path: Pathlike,
822
- sample_map: Optional[SampleExportMap] = None,
823
- maxcount: Optional[int] = None,
824
- **kwargs,
825
- ):
826
- """Export dataset contents to parquet format.
827
-
828
- Converts all samples to a pandas DataFrame and saves to parquet file(s).
829
- Useful for interoperability with data analysis tools.
1042
+ def to_parquet(
1043
+ self,
1044
+ path: Pathlike,
1045
+ sample_map: Optional[SampleExportMap] = None,
1046
+ maxcount: Optional[int] = None,
1047
+ **kwargs,
1048
+ ):
1049
+ """Export dataset to parquet file(s).
830
1050
 
831
1051
  Args:
832
- path: Output path for the parquet file. If ``maxcount`` is specified,
833
- files are named ``{stem}-{segment:06d}.parquet``.
834
- sample_map: Optional function to convert samples to dictionaries.
835
- Defaults to ``dataclasses.asdict``.
836
- maxcount: If specified, split output into multiple files with at most
837
- this many samples each. Recommended for large datasets.
838
- **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
839
- Common options include ``compression``, ``index``, ``engine``.
840
-
841
- Warning:
842
- **Memory Usage**: When ``maxcount=None`` (default), this method loads
843
- the **entire dataset into memory** as a pandas DataFrame before writing.
844
- For large datasets, this can cause memory exhaustion.
845
-
846
- For datasets larger than available RAM, always specify ``maxcount``::
847
-
848
- # Safe for large datasets - processes in chunks
849
- ds.to_parquet("output.parquet", maxcount=10000)
850
-
851
- This creates multiple parquet files: ``output-000000.parquet``,
852
- ``output-000001.parquet``, etc.
853
-
854
- Example:
855
- ::
856
-
857
- >>> ds = Dataset[MySample]("data.tar")
858
- >>> # Small dataset - load all at once
859
- >>> ds.to_parquet("output.parquet")
860
- >>>
861
- >>> # Large dataset - process in chunks
862
- >>> ds.to_parquet("output.parquet", maxcount=50000)
1052
+ path: Output path. With *maxcount*, files are named
1053
+ ``{stem}-{segment:06d}.parquet``.
1054
+ sample_map: Convert sample to dict. Defaults to ``dataclasses.asdict``.
1055
+ maxcount: Split into files of at most this many samples.
1056
+ Without it, the entire dataset is loaded into memory.
1057
+ **kwargs: Passed to ``pandas.DataFrame.to_parquet()``.
1058
+
1059
+ Examples:
1060
+ >>> ds.to_parquet("output.parquet", maxcount=50000)
863
1061
  """
864
- ##
865
-
866
- # Normalize args
867
- path = Path( path )
1062
+ path = Path(path)
868
1063
  if sample_map is None:
869
1064
  sample_map = asdict
870
-
871
- verbose = kwargs.get( 'verbose', False )
872
-
873
- it = self.ordered( batch_size = None )
874
- if verbose:
875
- it = tqdm( it )
876
-
877
- #
878
1065
 
879
1066
  if maxcount is None:
880
- # Load and save full dataset
881
- df = pd.DataFrame( [ sample_map( x )
882
- for x in self.ordered( batch_size = None ) ] )
883
- df.to_parquet( path, **kwargs )
884
-
1067
+ df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)])
1068
+ df.to_parquet(path, **kwargs)
885
1069
  else:
886
- # Load and save dataset in segments of size `maxcount`
887
-
888
1070
  cur_segment = 0
889
- cur_buffer = []
890
- path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix()
891
-
892
- for x in self.ordered( batch_size = None ):
893
- cur_buffer.append( sample_map( x ) )
894
-
895
- if len( cur_buffer ) >= maxcount:
896
- # Write current segment
897
- cur_path = path_template.format( cur_segment )
898
- df = pd.DataFrame( cur_buffer )
899
- df.to_parquet( cur_path, **kwargs )
900
-
1071
+ cur_buffer: list = []
1072
+ path_template = (
1073
+ path.parent / f"{path.stem}-{{:06d}}{path.suffix}"
1074
+ ).as_posix()
1075
+
1076
+ for x in self.ordered(batch_size=None):
1077
+ cur_buffer.append(sample_map(x))
1078
+ if len(cur_buffer) >= maxcount:
1079
+ cur_path = path_template.format(cur_segment)
1080
+ pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)
901
1081
  cur_segment += 1
902
1082
  cur_buffer = []
903
-
904
- if len( cur_buffer ) > 0:
905
- # Write one last segment with remainder
906
- cur_path = path_template.format( cur_segment )
907
- df = pd.DataFrame( cur_buffer )
908
- df.to_parquet( cur_path, **kwargs )
909
-
910
- def wrap( self, sample: WDSRawSample ) -> ST:
911
- """Wrap a raw msgpack sample into the appropriate dataset-specific type.
912
1083
 
913
- Args:
914
- sample: A dictionary containing at minimum a ``'msgpack'`` key with
915
- serialized sample bytes.
1084
+ if cur_buffer:
1085
+ cur_path = path_template.format(cur_segment)
1086
+ pd.DataFrame(cur_buffer).to_parquet(cur_path, **kwargs)
916
1087
 
917
- Returns:
918
- A deserialized sample of type ``ST``, optionally transformed through
919
- a lens if ``as_type()`` was called.
920
- """
921
- if 'msgpack' not in sample:
922
- raise ValueError(f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}")
923
- if not isinstance(sample['msgpack'], bytes):
924
- raise ValueError(f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}")
1088
+ def wrap(self, sample: WDSRawSample) -> ST:
1089
+ """Deserialize a raw WDS sample dict into type ``ST``."""
1090
+ if "msgpack" not in sample:
1091
+ raise ValueError(
1092
+ f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}"
1093
+ )
1094
+ if not isinstance(sample["msgpack"], bytes):
1095
+ raise ValueError(
1096
+ f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}"
1097
+ )
925
1098
 
926
1099
  if self._output_lens is None:
927
- return self.sample_type.from_bytes( sample['msgpack'] )
928
-
929
- source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
930
- return self._output_lens( source_sample )
1100
+ return self.sample_type.from_bytes(sample["msgpack"])
931
1101
 
932
- def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
933
- """Wrap a batch of raw msgpack samples into a typed SampleBatch.
934
-
935
- Args:
936
- batch: A dictionary containing a ``'msgpack'`` key with a list of
937
- serialized sample bytes.
938
-
939
- Returns:
940
- A ``SampleBatch[ST]`` containing deserialized samples, optionally
941
- transformed through a lens if ``as_type()`` was called.
1102
+ source_sample = self._output_lens.source_type.from_bytes(sample["msgpack"])
1103
+ return self._output_lens(source_sample)
942
1104
 
943
- Note:
944
- This implementation deserializes samples one at a time, then
945
- aggregates them into a batch.
946
- """
1105
+ def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]:
1106
+ """Deserialize a raw WDS batch dict into ``SampleBatch[ST]``."""
947
1107
 
948
- if 'msgpack' not in batch:
949
- raise ValueError(f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}")
1108
+ if "msgpack" not in batch:
1109
+ raise ValueError(
1110
+ f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}"
1111
+ )
950
1112
 
951
1113
  if self._output_lens is None:
952
- batch_unpacked = [ self.sample_type.from_bytes( bs )
953
- for bs in batch['msgpack'] ]
954
- return SampleBatch[self.sample_type]( batch_unpacked )
1114
+ batch_unpacked = [
1115
+ self.sample_type.from_bytes(bs) for bs in batch["msgpack"]
1116
+ ]
1117
+ return SampleBatch[self.sample_type](batch_unpacked)
955
1118
 
956
- batch_source = [ self._output_lens.source_type.from_bytes( bs )
957
- for bs in batch['msgpack'] ]
958
- batch_view = [ self._output_lens( s )
959
- for s in batch_source ]
960
- return SampleBatch[self.sample_type]( batch_view )
1119
+ batch_source = [
1120
+ self._output_lens.source_type.from_bytes(bs) for bs in batch["msgpack"]
1121
+ ]
1122
+ batch_view = [self._output_lens(s) for s in batch_source]
1123
+ return SampleBatch[self.sample_type](batch_view)
961
1124
 
962
1125
 
963
- _T = TypeVar('_T')
1126
+ _T = TypeVar("_T")
964
1127
 
965
1128
 
966
1129
  @dataclass_transform()
967
- def packable( cls: type[_T] ) -> type[_T]:
968
- """Decorator to convert a regular class into a ``PackableSample``.
969
-
970
- This decorator transforms a class into a dataclass that inherits from
971
- ``PackableSample``, enabling automatic msgpack serialization/deserialization
972
- with special handling for NDArray fields.
973
-
974
- The resulting class satisfies the ``Packable`` protocol, making it compatible
975
- with all atdata APIs that accept packable types (e.g., ``publish_schema``,
976
- lens transformations, etc.).
1130
+ def packable(cls: type[_T]) -> type[Packable]:
1131
+ """Convert a class into a ``PackableSample`` dataclass with msgpack serialization.
977
1132
 
978
- Args:
979
- cls: The class to convert. Should have type annotations for its fields.
980
-
981
- Returns:
982
- A new dataclass that inherits from ``PackableSample`` with the same
983
- name and annotations as the original class. The class satisfies the
984
- ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
1133
+ The resulting class gains ``packed``, ``as_wds``, ``from_bytes``, and
1134
+ ``from_data`` methods, and satisfies the ``Packable`` protocol.
1135
+ NDArray fields are automatically handled during serialization.
985
1136
 
986
1137
  Examples:
987
- This is a test of the functionality::
988
-
989
- @packable
990
- class MyData:
991
- name: str
992
- values: NDArray
993
-
994
- sample = MyData(name="test", values=np.array([1, 2, 3]))
995
- bytes_data = sample.packed
996
- restored = MyData.from_bytes(bytes_data)
997
-
998
- # Works with Packable-typed APIs
999
- index.publish_schema(MyData, version="1.0.0") # Type-safe
1138
+ >>> @packable
1139
+ ... class MyData:
1140
+ ... name: str
1141
+ ... values: NDArray
1142
+ ...
1143
+ >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
1144
+ >>> restored = MyData.from_bytes(sample.packed)
1000
1145
  """
1001
1146
 
1002
1147
  ##
@@ -1005,14 +1150,14 @@ def packable( cls: type[_T] ) -> type[_T]:
1005
1150
  class_annotations = cls.__annotations__
1006
1151
 
1007
1152
  # Add in dataclass niceness to original class
1008
- as_dataclass = dataclass( cls )
1153
+ as_dataclass = dataclass(cls)
1009
1154
 
1010
1155
  # This triggers a bunch of behind-the-scenes stuff for the newly annotated class
1011
1156
  @dataclass
1012
- class as_packable( as_dataclass, PackableSample ):
1013
- def __post_init__( self ):
1014
- return PackableSample.__post_init__( self )
1015
-
1157
+ class as_packable(as_dataclass, PackableSample):
1158
+ def __post_init__(self):
1159
+ return PackableSample.__post_init__(self)
1160
+
1016
1161
  # Restore original class identity for better repr/debugging
1017
1162
  as_packable.__name__ = class_name
1018
1163
  as_packable.__qualname__ = class_name
@@ -1023,10 +1168,10 @@ def packable( cls: type[_T] ) -> type[_T]:
1023
1168
 
1024
1169
  # Fix qualnames of dataclass-generated methods so they don't show
1025
1170
  # 'packable.<locals>.as_packable' in help() and IDE hints
1026
- old_qualname_prefix = 'packable.<locals>.as_packable'
1027
- for attr_name in ('__init__', '__repr__', '__eq__', '__post_init__'):
1171
+ old_qualname_prefix = "packable.<locals>.as_packable"
1172
+ for attr_name in ("__init__", "__repr__", "__eq__", "__post_init__"):
1028
1173
  attr = getattr(as_packable, attr_name, None)
1029
- if attr is not None and hasattr(attr, '__qualname__'):
1174
+ if attr is not None and hasattr(attr, "__qualname__"):
1030
1175
  if attr.__qualname__.startswith(old_qualname_prefix):
1031
1176
  attr.__qualname__ = attr.__qualname__.replace(
1032
1177
  old_qualname_prefix, class_name, 1
@@ -1042,4 +1187,4 @@ def packable( cls: type[_T] ) -> type[_T]:
1042
1187
 
1043
1188
  ##
1044
1189
 
1045
- return as_packable
1190
+ return as_packable
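The diff ends with the reworked packable decorator. One more sketch ties together the eager per-shard path: process_shards() with the partial-failure retry pattern described in its docstring above. PartialFailureError and its failed_shards / errors / results attributes are named in the hunks; the counting function, the import path for Dataset, and the ImageSample class (as in the sketch before the diff) are placeholders.

    # Sketch of the process_shards() retry flow; see the docstring in the diff above.
    from atdata import Dataset                          # assumed public import path
    from atdata._exceptions import PartialFailureError

    ds = Dataset[ImageSample]("data-{000000..000009}.tar")  # ImageSample as in the earlier sketch

    def count_samples(samples):
        # Called once per shard with that shard's full list of samples.
        return len(samples)

    try:
        results = ds.process_shards(count_samples)       # {shard_id: count}
    except PartialFailureError as e:
        # Retry only the shards that failed, then merge with the partial results.
        retry = ds.process_shards(count_samples, shards=e.failed_shards)
        results = {**e.results, **retry}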