atdata-0.1.3b4-py3-none-any.whl → atdata-0.2.2b1-py3-none-any.whl

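For orientation before the file diff: the sketch below pulls together the pieces this release adds to atdata/dataset.py — the DataSource-based constructor, the DictSample default sample type, the unbatched-by-default iterators, and the lens-based as_type() conversion. The class and method names come from the diff itself; the import paths, bucket name, and shard URLs are placeholders for illustration, not guaranteed public API.

    from numpy.typing import NDArray

    from atdata.dataset import Dataset, DictSample, packable   # assumed import path
    from atdata._sources import S3Source                        # assumed import path

    @packable
    class ImageSample:
        image: NDArray
        label: str

    # Strings still work; they are wrapped in URLSource internally
    ds = Dataset[ImageSample]("data-{000000..000009}.tar")

    # Any DataSource implementation can be passed instead of a URL
    remote = Dataset[DictSample](S3Source(bucket="my-bucket", keys=["data.tar"]))

    # batch_size now defaults to None (unbatched); pass a size to get SampleBatch
    for batch in ds.shuffled(batch_size=32):
        images = batch.image   # stacked numpy array, shape (32, H, W, C)
        labels = batch.label   # list of 32 strings

    # An untyped DictSample dataset can be re-viewed through a @packable schema,
    # using the lens that @packable auto-registers
    typed = remote.as_type(ImageSample)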
atdata/dataset.py CHANGED
@@ -14,15 +14,17 @@ during serialization, enabling efficient storage of numerical data in WebDataset
14
14
  archives.
15
15
 
16
16
  Example:
17
- >>> @packable
18
- ... class ImageSample:
19
- ... image: NDArray
20
- ... label: str
21
- ...
22
- >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
23
- >>> for batch in ds.shuffled(batch_size=32):
24
- ... images = batch.image # Stacked numpy array (32, H, W, C)
25
- ... labels = batch.label # List of 32 strings
17
+ ::
18
+
19
+ >>> @packable
20
+ ... class ImageSample:
21
+ ... image: NDArray
22
+ ... label: str
23
+ ...
24
+ >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
25
+ >>> for batch in ds.shuffled(batch_size=32):
26
+ ... images = batch.image # Stacked numpy array (32, H, W, C)
27
+ ... labels = batch.label # List of 32 strings
26
28
  """
27
29
 
28
30
  ##
@@ -32,7 +34,6 @@ import webdataset as wds
32
34
 
33
35
  from pathlib import Path
34
36
  import uuid
35
- import functools
36
37
 
37
38
  import dataclasses
38
39
  import types
@@ -40,40 +41,33 @@ from dataclasses import (
40
41
  dataclass,
41
42
  asdict,
42
43
  )
43
- from abc import (
44
- ABC,
45
- abstractmethod,
46
- )
44
+ from abc import ABC
45
+
46
+ from ._sources import URLSource, S3Source
47
+ from ._protocols import DataSource
47
48
 
48
49
  from tqdm import tqdm
49
50
  import numpy as np
50
51
  import pandas as pd
52
+ import requests
51
53
 
52
54
  import typing
53
55
  from typing import (
54
56
  Any,
55
57
  Optional,
56
58
  Dict,
59
+ Iterator,
57
60
  Sequence,
58
61
  Iterable,
59
62
  Callable,
60
- Union,
61
- #
62
63
  Self,
63
64
  Generic,
64
65
  Type,
65
66
  TypeVar,
66
67
  TypeAlias,
68
+ dataclass_transform,
67
69
  )
68
- # from typing_inspect import get_bound, get_parameters
69
- from numpy.typing import (
70
- NDArray,
71
- ArrayLike,
72
- )
73
-
74
- #
75
-
76
- # import ekumen.atmosphere as eat
70
+ from numpy.typing import NDArray
77
71
 
78
72
  import msgpack
79
73
  import ormsgpack
@@ -86,6 +80,7 @@ from .lens import Lens, LensNetwork
86
80
 
87
81
  Pathlike = str | Path
88
82
 
83
+ # WebDataset sample/batch dictionaries (contain __key__, msgpack, etc.)
89
84
  WDSRawSample: TypeAlias = Dict[str, Any]
90
85
  WDSRawBatch: TypeAlias = Dict[str, Any]
91
86
 
@@ -96,83 +91,191 @@ SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
96
91
  ##
97
92
  # Main base classes
98
93
 
99
- # TODO Check for best way to ensure this typevar is used as a dataclass type
100
- # DT = TypeVar( 'DT', bound = dataclass.__class__ )
101
94
  DT = TypeVar( 'DT' )
102
95
 
103
- MsgpackRawSample: TypeAlias = Dict[str, Any]
104
-
105
- # @dataclass
106
- # class ArrayBytes:
107
- # """Annotates bytes that should be interpreted as the raw contents of a
108
- # numpy NDArray"""
109
-
110
- # raw_bytes: bytes
111
- # """The raw bytes of the corresponding NDArray"""
112
-
113
- # def __init__( self,
114
- # array: Optional[ArrayLike] = None,
115
- # raw: Optional[bytes] = None,
116
- # ):
117
- # """TODO"""
118
-
119
- # if array is not None:
120
- # array = np.array( array )
121
- # self.raw_bytes = eh.array_to_bytes( array )
122
-
123
- # elif raw is not None:
124
- # self.raw_bytes = raw
125
-
126
- # else:
127
- # raise ValueError( 'Must provide either `array` or `raw` bytes' )
128
-
129
- # @property
130
- # def to_numpy( self ) -> NDArray:
131
- # """Return the `raw_bytes` data as an NDArray"""
132
- # return eh.bytes_to_array( self.raw_bytes )
133
96
 
134
97
  def _make_packable( x ):
135
- """Convert a value to a msgpack-compatible format.
136
-
137
- Args:
138
- x: A value to convert. If it's a numpy array, converts to bytes.
139
- Otherwise returns the value unchanged.
140
-
141
- Returns:
142
- The value in a format suitable for msgpack serialization.
143
- """
144
- # if isinstance( x, ArrayBytes ):
145
- # return x.raw_bytes
98
+ """Convert numpy arrays to bytes; pass through other values unchanged."""
146
99
  if isinstance( x, np.ndarray ):
147
100
  return eh.array_to_bytes( x )
148
101
  return x
149
102
 
150
- def _is_possibly_ndarray_type( t ):
151
- """Check if a type annotation is or contains NDArray.
152
-
153
- Args:
154
- t: A type annotation to check.
155
-
156
- Returns:
157
- ``True`` if the type is ``NDArray`` or a union containing ``NDArray``
158
- (e.g., ``NDArray | None``), ``False`` otherwise.
159
- """
160
103
 
161
- # Directly an NDArray
104
+ def _is_possibly_ndarray_type( t ):
105
+ """Return True if type annotation is NDArray or Optional[NDArray]."""
162
106
  if t == NDArray:
163
- # print( 'is an NDArray' )
164
107
  return True
165
-
166
- # Check for Optionals (i.e., NDArray | None)
167
108
  if isinstance( t, types.UnionType ):
168
- t_parts = t.__args__
169
- if any( x == NDArray
170
- for x in t_parts ):
171
- return True
172
-
173
- # Not an NDArray
109
+ return any( x == NDArray for x in t.__args__ )
174
110
  return False
175
111
 
112
+ class DictSample:
113
+ """Dynamic sample type providing dict-like access to raw msgpack data.
114
+
115
+ This class is the default sample type for datasets when no explicit type is
116
+ specified. It stores the raw unpacked msgpack data and provides both
117
+ attribute-style (``sample.field``) and dict-style (``sample["field"]``)
118
+ access to fields.
119
+
120
+ ``DictSample`` is useful for:
121
+ - Exploring datasets without defining a schema first
122
+ - Working with datasets that have variable schemas
123
+ - Prototyping before committing to a typed schema
124
+
125
+ To convert to a typed schema, use ``Dataset.as_type()`` with a
126
+ ``@packable``-decorated class. Every ``@packable`` class automatically
127
+ registers a lens from ``DictSample``, making this conversion seamless.
128
+
129
+ Example:
130
+ ::
131
+
132
+ >>> ds = load_dataset("path/to/data.tar") # Returns Dataset[DictSample]
133
+ >>> for sample in ds.ordered():
134
+ ... print(sample.some_field) # Attribute access
135
+ ... print(sample["other_field"]) # Dict access
136
+ ... print(sample.keys()) # Inspect available fields
137
+ ...
138
+ >>> # Convert to typed schema
139
+ >>> typed_ds = ds.as_type(MyTypedSample)
140
+
141
+ Note:
142
+ NDArray fields are stored as raw bytes in DictSample. They are only
143
+ converted to numpy arrays when accessed through a typed sample class.
144
+ """
145
+
146
+ __slots__ = ('_data',)
147
+
148
+ def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None:
149
+ """Create a DictSample from a dictionary or keyword arguments.
150
+
151
+ Args:
152
+ _data: Raw data dictionary. If provided, kwargs are ignored.
153
+ **kwargs: Field values if _data is not provided.
154
+ """
155
+ if _data is not None:
156
+ object.__setattr__(self, '_data', _data)
157
+ else:
158
+ object.__setattr__(self, '_data', kwargs)
159
+
160
+ @classmethod
161
+ def from_data(cls, data: dict[str, Any]) -> 'DictSample':
162
+ """Create a DictSample from unpacked msgpack data.
163
+
164
+ Args:
165
+ data: Dictionary with field names as keys.
166
+
167
+ Returns:
168
+ New DictSample instance wrapping the data.
169
+ """
170
+ return cls(_data=data)
171
+
172
+ @classmethod
173
+ def from_bytes(cls, bs: bytes) -> 'DictSample':
174
+ """Create a DictSample from raw msgpack bytes.
175
+
176
+ Args:
177
+ bs: Raw bytes from a msgpack-serialized sample.
178
+
179
+ Returns:
180
+ New DictSample instance with the unpacked data.
181
+ """
182
+ return cls.from_data(ormsgpack.unpackb(bs))
183
+
184
+ def __getattr__(self, name: str) -> Any:
185
+ """Access a field by attribute name.
186
+
187
+ Args:
188
+ name: Field name to access.
189
+
190
+ Returns:
191
+ The field value.
192
+
193
+ Raises:
194
+ AttributeError: If the field doesn't exist.
195
+ """
196
+ # Avoid infinite recursion for _data lookup
197
+ if name == '_data':
198
+ raise AttributeError(name)
199
+ try:
200
+ return self._data[name]
201
+ except KeyError:
202
+ raise AttributeError(
203
+ f"'{type(self).__name__}' has no field '{name}'. "
204
+ f"Available fields: {list(self._data.keys())}"
205
+ ) from None
206
+
207
+ def __getitem__(self, key: str) -> Any:
208
+ """Access a field by dict key.
209
+
210
+ Args:
211
+ key: Field name to access.
212
+
213
+ Returns:
214
+ The field value.
215
+
216
+ Raises:
217
+ KeyError: If the field doesn't exist.
218
+ """
219
+ return self._data[key]
220
+
221
+ def __contains__(self, key: str) -> bool:
222
+ """Check if a field exists."""
223
+ return key in self._data
224
+
225
+ def keys(self) -> list[str]:
226
+ """Return list of field names."""
227
+ return list(self._data.keys())
228
+
229
+ def values(self) -> list[Any]:
230
+ """Return list of field values."""
231
+ return list(self._data.values())
232
+
233
+ def items(self) -> list[tuple[str, Any]]:
234
+ """Return list of (field_name, value) tuples."""
235
+ return list(self._data.items())
236
+
237
+ def get(self, key: str, default: Any = None) -> Any:
238
+ """Get a field value with optional default.
239
+
240
+ Args:
241
+ key: Field name to access.
242
+ default: Value to return if field doesn't exist.
243
+
244
+ Returns:
245
+ The field value or default.
246
+ """
247
+ return self._data.get(key, default)
248
+
249
+ def to_dict(self) -> dict[str, Any]:
250
+ """Return a copy of the underlying data dictionary."""
251
+ return dict(self._data)
252
+
253
+ @property
254
+ def packed(self) -> bytes:
255
+ """Pack this sample's data into msgpack bytes.
256
+
257
+ Returns:
258
+ Raw msgpack bytes representing this sample's data.
259
+ """
260
+ return msgpack.packb(self._data)
261
+
262
+ @property
263
+ def as_wds(self) -> 'WDSRawSample':
264
+ """Pack this sample's data for writing to WebDataset.
265
+
266
+ Returns:
267
+ A dictionary with ``__key__`` and ``msgpack`` fields.
268
+ """
269
+ return {
270
+ '__key__': str(uuid.uuid1(0, 0)),
271
+ 'msgpack': self.packed,
272
+ }
273
+
274
+ def __repr__(self) -> str:
275
+ fields = ', '.join(f'{k}=...' for k in self._data.keys())
276
+ return f'DictSample({fields})'
277
+
278
+
176
279
  @dataclass
177
280
  class PackableSample( ABC ):
178
281
  """Base class for samples that can be serialized with msgpack.
@@ -187,28 +290,20 @@ class PackableSample( ABC ):
187
290
  2. Using the ``@packable`` decorator (recommended)
188
291
 
189
292
  Example:
190
- >>> @packable
191
- ... class MyData:
192
- ... name: str
193
- ... embeddings: NDArray
194
- ...
195
- >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
196
- >>> packed = sample.packed # Serialize to bytes
197
- >>> restored = MyData.from_bytes(packed) # Deserialize
293
+ ::
294
+
295
+ >>> @packable
296
+ ... class MyData:
297
+ ... name: str
298
+ ... embeddings: NDArray
299
+ ...
300
+ >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
301
+ >>> packed = sample.packed # Serialize to bytes
302
+ >>> restored = MyData.from_bytes(packed) # Deserialize
198
303
  """
199
304
 
200
305
  def _ensure_good( self ):
201
- """Auto-convert annotated NDArray fields from bytes to numpy arrays.
202
-
203
- This method scans all dataclass fields and for any field annotated as
204
- ``NDArray`` or ``NDArray | None``, automatically converts bytes values
205
- to numpy arrays using the helper deserialization function. This enables
206
- transparent handling of array serialization in msgpack data.
207
-
208
- Note:
209
- This is called during ``__post_init__`` to ensure proper type
210
- conversion after deserialization.
211
- """
306
+ """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""
212
307
 
213
308
  # Auto-convert known types when annotated
214
309
  # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
@@ -226,16 +321,13 @@ class PackableSample( ABC ):
226
321
  # based on what is provided
227
322
 
228
323
  if isinstance( var_cur_value, np.ndarray ):
229
- # we're good!
230
- pass
231
-
232
- # elif isinstance( var_cur_value, ArrayBytes ):
233
- # setattr( self, var_name, var_cur_value.to_numpy )
324
+ # Already the correct type, no conversion needed
325
+ continue
234
326
 
235
327
  elif isinstance( var_cur_value, bytes ):
236
- # TODO This does create a constraint that serialized bytes
237
- # in a field that might be an NDArray are always interpreted
238
- # as being the NDArray interpretation
328
+ # Design note: bytes in NDArray-typed fields are always interpreted
329
+ # as serialized arrays. This means raw bytes fields must not be
330
+ # annotated as NDArray.
239
331
  setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
240
332
 
241
333
  def __post_init__( self ):
@@ -244,20 +336,16 @@ class PackableSample( ABC ):
244
336
  ##
245
337
 
246
338
  @classmethod
247
- def from_data( cls, data: MsgpackRawSample ) -> Self:
339
+ def from_data( cls, data: WDSRawSample ) -> Self:
248
340
  """Create a sample instance from unpacked msgpack data.
249
341
 
250
342
  Args:
251
- data: A dictionary of unpacked msgpack data with keys matching
252
- the sample's field names.
343
+ data: Dictionary with keys matching the sample's field names.
253
344
 
254
345
  Returns:
255
- A new instance of this sample class with fields populated from
256
- the data dictionary and NDArray fields auto-converted from bytes.
346
+ New instance with NDArray fields auto-converted from bytes.
257
347
  """
258
- ret = cls( **data )
259
- ret._ensure_good()
260
- return ret
348
+ return cls( **data )
261
349
 
262
350
  @classmethod
263
351
  def from_bytes( cls, bs: bytes ) -> Self:
@@ -299,7 +387,6 @@ class PackableSample( ABC ):
299
387
 
300
388
  return ret
301
389
 
302
- # TODO Expand to allow for specifying explicit __key__
303
390
  @property
304
391
  def as_wds( self ) -> WDSRawSample:
305
392
  """Pack this sample's data for writing to WebDataset.
@@ -309,7 +396,8 @@ class PackableSample( ABC ):
309
396
  ``msgpack`` (packed sample data) fields suitable for WebDataset.
310
397
 
311
398
  Note:
312
- TODO: Expand to allow specifying explicit ``__key__`` values.
399
+ Keys are auto-generated as UUID v1 for time-sortable ordering.
400
+ Custom key specification is not currently supported.
313
401
  """
314
402
  return {
315
403
  # Generates a UUID that is timelike-sortable
@@ -318,25 +406,11 @@ class PackableSample( ABC ):
318
406
  }
319
407
 
320
408
  def _batch_aggregate( xs: Sequence ):
321
- """Aggregate a sequence of values into a batch-appropriate format.
322
-
323
- Args:
324
- xs: A sequence of values to aggregate. If the first element is a numpy
325
- array, all elements are stacked into a single array. Otherwise,
326
- returns a list.
327
-
328
- Returns:
329
- A numpy array (if elements are arrays) or a list (otherwise).
330
- """
331
-
409
+ """Stack arrays into numpy array with batch dim; otherwise return list."""
332
410
  if not xs:
333
- # Empty sequence
334
411
  return []
335
-
336
- # Aggregate
337
412
  if isinstance( xs[0], np.ndarray ):
338
413
  return np.array( list( xs ) )
339
-
340
414
  return list( xs )
341
415
 
342
416
  class SampleBatch( Generic[DT] ):
@@ -350,17 +424,27 @@ class SampleBatch( Generic[DT] ):
350
424
  NDArray fields are stacked into a numpy array with a batch dimension.
351
425
  Other fields are aggregated into a list.
352
426
 
353
- Type Parameters:
427
+ Parameters:
354
428
  DT: The sample type, must derive from ``PackableSample``.
355
429
 
356
430
  Attributes:
357
431
  samples: The list of sample instances in this batch.
358
432
 
359
433
  Example:
360
- >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
361
- >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
362
- >>> batch.names # Returns list of names
434
+ ::
435
+
436
+ >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
437
+ >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
438
+ >>> batch.names # Returns list of names
439
+
440
+ Note:
441
+ This class uses Python's ``__orig_class__`` mechanism to extract the
442
+ type parameter at runtime. Instances must be created using the
443
+ subscripted syntax ``SampleBatch[MyType](samples)`` rather than
444
+ calling the constructor directly with an unsubscripted class.
363
445
  """
446
+ # Design note: The docstring uses "Parameters:" for type parameters because
447
+ # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
364
448
 
365
449
  def __init__( self, samples: Sequence[DT] ):
366
450
  """Create a batch from a sequence of samples.
@@ -372,6 +456,7 @@ class SampleBatch( Generic[DT] ):
372
456
  """
373
457
  self.samples = list( samples )
374
458
  self._aggregate_cache = dict()
459
+ self._sample_type_cache: Type | None = None
375
460
 
376
461
  @property
377
462
  def sample_type( self ) -> Type:
@@ -380,7 +465,10 @@ class SampleBatch( Generic[DT] ):
380
465
  Returns:
381
466
  The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
382
467
  """
383
- return typing.get_args( self.__orig_class__)[0]
468
+ if self._sample_type_cache is None:
469
+ self._sample_type_cache = typing.get_args( self.__orig_class__)[0]
470
+ assert self._sample_type_cache is not None
471
+ return self._sample_type_cache
384
472
 
385
473
  def __getattr__( self, name ):
386
474
  """Aggregate an attribute across all samples in the batch.
@@ -411,23 +499,44 @@ class SampleBatch( Generic[DT] ):
411
499
  raise AttributeError( f'No sample attribute named {name}' )
412
500
 
413
501
 
414
- # class AnySample( BaseModel ):
415
- # """A sample that can hold anything"""
416
- # value: Any
502
+ ST = TypeVar( 'ST', bound = PackableSample )
503
+ RT = TypeVar( 'RT', bound = PackableSample )
504
+
505
+
506
+ class _ShardListStage(wds.utils.PipelineStage):
507
+ """Pipeline stage that yields {url: shard_id} dicts from a DataSource.
417
508
 
418
- # class AnyBatch( BaseModel ):
419
- # """A batch of `AnySample`s"""
420
- # values: list[AnySample]
509
+ This is analogous to SimpleShardList but works with any DataSource.
510
+ Used as the first stage before split_by_worker.
511
+ """
421
512
 
513
+ def __init__(self, source: DataSource):
514
+ self.source = source
422
515
 
423
- ST = TypeVar( 'ST', bound = PackableSample )
424
- # BT = TypeVar( 'BT' )
516
+ def run(self):
517
+ """Yield {url: shard_id} dicts for each shard."""
518
+ for shard_id in self.source.list_shards():
519
+ yield {"url": shard_id}
425
520
 
426
- RT = TypeVar( 'RT', bound = PackableSample )
427
521
 
428
- # TODO For python 3.13
429
- # BT = TypeVar( 'BT', default = None )
430
- # IT = TypeVar( 'IT', default = Any )
522
+ class _StreamOpenerStage(wds.utils.PipelineStage):
523
+ """Pipeline stage that opens streams from a DataSource.
524
+
525
+ Takes {url: shard_id} dicts and adds a stream using source.open_shard().
526
+ This replaces WebDataset's url_opener stage.
527
+ """
528
+
529
+ def __init__(self, source: DataSource):
530
+ self.source = source
531
+
532
+ def run(self, src):
533
+ """Open streams for each shard dict."""
534
+ for sample in src:
535
+ shard_id = sample["url"]
536
+ stream = self.source.open_shard(shard_id)
537
+ sample["stream"] = stream
538
+ yield sample
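As used by these two stages, a source only needs list_shards() returning shard identifiers and open_shard() returning a readable stream; the full DataSource protocol in atdata._protocols may require more. A minimal sketch under that assumption — DirectorySource is hypothetical and not part of the package:

    from pathlib import Path
    from typing import IO, List

    class DirectorySource:
        """Hypothetical DataSource serving every .tar shard in a local directory."""

        def __init__(self, root: str) -> None:
            self.root = Path(root)

        def list_shards(self) -> List[str]:
            # Shard identifiers; here, plain filesystem paths
            return sorted(str(p) for p in self.root.glob("*.tar"))

        def open_shard(self, shard_id: str) -> IO[bytes]:
            # Stream handed to _StreamOpenerStage and then tar_file_expander
            return open(shard_id, "rb")

    # Usage sketch: ds = Dataset[DictSample](DirectorySource("/data/shards"))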
539
+
431
540
 
432
541
  class Dataset( Generic[ST] ):
433
542
  """A typed dataset built on WebDataset with lens transformations.
@@ -442,26 +551,31 @@ class Dataset( Generic[ST] ):
442
551
  - Type transformations via the lens system (``as_type()``)
443
552
  - Export to parquet format
444
553
 
445
- Type Parameters:
554
+ Parameters:
446
555
  ST: The sample type for this dataset, must derive from ``PackableSample``.
447
556
 
448
557
  Attributes:
449
558
  url: WebDataset brace-notation URL for the tar file(s).
450
559
 
451
560
  Example:
452
- >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
453
- >>> for sample in ds.ordered(batch_size=32):
454
- ... # sample is SampleBatch[MyData] with batch_size samples
455
- ... embeddings = sample.embeddings # shape: (32, ...)
456
- ...
457
- >>> # Transform to a different view
458
- >>> ds_view = ds.as_type(MyDataView)
561
+ ::
562
+
563
+ >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
564
+ >>> for sample in ds.ordered(batch_size=32):
565
+ ... # sample is SampleBatch[MyData] with batch_size samples
566
+ ... embeddings = sample.embeddings # shape: (32, ...)
567
+ ...
568
+ >>> # Transform to a different view
569
+ >>> ds_view = ds.as_type(MyDataView)
570
+
571
+ Note:
572
+ This class uses Python's ``__orig_class__`` mechanism to extract the
573
+ type parameter at runtime. Instances must be created using the
574
+ subscripted syntax ``Dataset[MyType](url)`` rather than calling the
575
+ constructor directly with an unsubscripted class.
459
576
  """
460
-
461
- # sample_class: Type = get_parameters( )
462
- # """The type of each returned sample from this `Dataset`'s iterator"""
463
- # batch_class: Type = get_bound( BT )
464
- # """The type of a batch built from `sample_class`"""
577
+ # Design note: The docstring uses "Parameters:" for type parameters because
578
+ # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
465
579
 
466
580
  @property
467
581
  def sample_type( self ) -> Type:
@@ -469,12 +583,11 @@ class Dataset( Generic[ST] ):
469
583
 
470
584
  Returns:
471
585
  The type parameter ``ST`` used when creating this ``Dataset[ST]``.
472
-
473
- Note:
474
- Extracts the type parameter at runtime using ``__orig_class__``.
475
586
  """
476
- # NOTE: Linting may fail here due to __orig_class__ being a runtime attribute
477
- return typing.get_args( self.__orig_class__ )[0]
587
+ if self._sample_type_cache is None:
588
+ self._sample_type_cache = typing.get_args( self.__orig_class__ )[0]
589
+ assert self._sample_type_cache is not None
590
+ return self._sample_type_cache
478
591
  @property
479
592
  def batch_type( self ) -> Type:
480
593
  """The type of batches produced by this dataset.
@@ -482,28 +595,60 @@ class Dataset( Generic[ST] ):
482
595
  Returns:
483
596
  ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
484
597
  """
485
- # return self.__orig_class__.__args__[1]
486
598
  return SampleBatch[self.sample_type]
487
599
 
600
+ def __init__( self,
601
+ source: DataSource | str | None = None,
602
+ metadata_url: str | None = None,
603
+ *,
604
+ url: str | None = None,
605
+ ) -> None:
606
+ """Create a dataset from a DataSource or URL.
488
607
 
489
- # _schema_registry_sample: dict[str, Type]
490
- # _schema_registry_batch: dict[str, Type | None]
491
-
492
- #
608
+ Args:
609
+ source: Either a DataSource implementation or a WebDataset-compatible
610
+ URL string. If a string is provided, it's wrapped in URLSource
611
+ for backward compatibility.
493
612
 
494
- def __init__( self, url: str ) -> None:
495
- """Create a dataset from a WebDataset URL.
613
+ Examples:
614
+ - String URL: ``"path/to/file-{000000..000009}.tar"``
615
+ - URLSource: ``URLSource("https://example.com/data.tar")``
616
+ - S3Source: ``S3Source(bucket="my-bucket", keys=["data.tar"])``
496
617
 
497
- Args:
498
- url: WebDataset brace-notation URL pointing to tar files, e.g.,
499
- ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
500
- ``"path/to/file-000000.tar"`` for a single shard.
618
+ metadata_url: Optional URL to msgpack-encoded metadata for this dataset.
619
+ url: Deprecated. Use ``source`` instead. Kept for backward compatibility.
501
620
  """
502
621
  super().__init__()
503
- self.url = url
504
622
 
505
- # Allow addition of automatic transformation of raw underlying data
623
+ # Handle backward compatibility: url= keyword argument
624
+ if source is None and url is not None:
625
+ source = url
626
+ elif source is None:
627
+ raise TypeError("Dataset() missing required argument: 'source' or 'url'")
628
+
629
+ # Normalize source: strings become URLSource for backward compatibility
630
+ if isinstance(source, str):
631
+ self._source: DataSource = URLSource(source)
632
+ self.url = source
633
+ else:
634
+ self._source = source
635
+ # For compatibility, expose URL if source has list_shards
636
+ shards = source.list_shards()
637
+ # Design note: Using first shard as url for legacy compatibility.
638
+ # Full shard list is available via list_shards() method.
639
+ self.url = shards[0] if shards else ""
640
+
641
+ self._metadata: dict[str, Any] | None = None
642
+ self.metadata_url: str | None = metadata_url
643
+ """Optional URL to msgpack-encoded metadata for this dataset."""
644
+
506
645
  self._output_lens: Lens | None = None
646
+ self._sample_type_cache: Type | None = None
647
+
648
+ @property
649
+ def source(self) -> DataSource:
650
+ """The underlying data source for this dataset."""
651
+ return self._source
507
652
 
508
653
  def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
509
654
  """View this dataset through a different sample type using a registered lens.
@@ -521,76 +666,104 @@ class Dataset( Generic[ST] ):
521
666
  ValueError: If no registered lens exists between the current
522
667
  sample type and the target type.
523
668
  """
524
- ret = Dataset[other]( self.url )
669
+ ret = Dataset[other]( self._source )
525
670
  # Get the singleton lens registry
526
671
  lenses = LensNetwork()
527
672
  ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
528
673
  return ret
529
674
 
530
- # @classmethod
531
- # def register( cls, uri: str,
532
- # sample_class: Type,
533
- # batch_class: Optional[Type] = None,
534
- # ):
535
- # """Register an `ekumen` schema to use a particular dataset sample class"""
536
- # cls._schema_registry_sample[uri] = sample_class
537
- # cls._schema_registry_batch[uri] = batch_class
538
-
539
- # @classmethod
540
- # def at( cls, uri: str ) -> 'Dataset':
541
- # """Create a Dataset for the `ekumen` index entry at `uri`"""
542
- # client = eat.Client()
543
- # return cls( )
544
-
545
- # Common functionality
546
-
547
675
  @property
548
- def shard_list( self ) -> list[str]:
549
- """List of individual dataset shards
550
-
676
+ def shards(self) -> Iterator[str]:
677
+ """Lazily iterate over shard identifiers.
678
+
679
+ Yields:
680
+ Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
681
+
682
+ Example:
683
+ ::
684
+
685
+ >>> for shard in ds.shards:
686
+ ... print(f"Processing {shard}")
687
+ """
688
+ return iter(self._source.list_shards())
689
+
690
+ def list_shards(self) -> list[str]:
691
+ """Get list of individual dataset shards.
692
+
551
693
  Returns:
552
694
  A full (non-lazy) list of the individual ``tar`` files within the
553
695
  source WebDataset.
554
696
  """
555
- pipe = wds.pipeline.DataPipeline(
556
- wds.shardlists.SimpleShardList( self.url ),
557
- wds.filters.map( lambda x: x['url'] )
697
+ return self._source.list_shards()
698
+
699
+ # Legacy alias for backwards compatibility
700
+ @property
701
+ def shard_list(self) -> list[str]:
702
+ """List of individual dataset shards (deprecated, use list_shards()).
703
+
704
+ .. deprecated::
705
+ Use :meth:`list_shards` instead.
706
+ """
707
+ import warnings
708
+ warnings.warn(
709
+ "shard_list is deprecated, use list_shards() instead",
710
+ DeprecationWarning,
711
+ stacklevel=2,
558
712
  )
559
- return list( pipe )
713
+ return self.list_shards()
714
+
715
+ @property
716
+ def metadata( self ) -> dict[str, Any] | None:
717
+ """Fetch and cache metadata from metadata_url.
718
+
719
+ Returns:
720
+ Deserialized metadata dictionary, or None if no metadata_url is set.
721
+
722
+ Raises:
723
+ requests.HTTPError: If metadata fetch fails.
724
+ """
725
+ if self.metadata_url is None:
726
+ return None
727
+
728
+ if self._metadata is None:
729
+ with requests.get( self.metadata_url, stream = True ) as response:
730
+ response.raise_for_status()
731
+ self._metadata = msgpack.unpackb( response.content, raw = False )
732
+
733
+ # Use our cached values
734
+ return self._metadata
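A short sketch of how the lazy metadata property is meant to be used, assuming a reachable URL that serves msgpack-encoded bytes; both URLs below are placeholders:

    from atdata.dataset import Dataset, DictSample   # assumed import path

    ds = Dataset[DictSample](
        "https://example.com/shards/data-{000000..000009}.tar",
        metadata_url="https://example.com/shards/meta.msgpack",
    )

    info = ds.metadata    # fetched once via requests, then cached on the instance
    again = ds.metadata   # reuses the cached dictionary, no second request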
560
735
 
561
736
  def ordered( self,
562
- batch_size: int | None = 1,
737
+ batch_size: int | None = None,
563
738
  ) -> Iterable[ST]:
564
739
  """Iterate over the dataset in order
565
-
740
+
566
741
  Args:
567
742
  batch_size (:obj:`int`, optional): The size of iterated batches.
568
- Default: 1. If ``None``, iterates over one sample at a time
569
- with no batch dimension.
570
-
743
+ Default: None (unbatched). If ``None``, iterates over one
744
+ sample at a time with no batch dimension.
745
+
571
746
  Returns:
572
747
  :obj:`webdataset.DataPipeline` A data pipeline that iterates over
573
748
  the dataset in its original sample order
574
-
575
- """
576
749
 
750
+ """
577
751
  if batch_size is None:
578
- # TODO Duplication here
579
752
  return wds.pipeline.DataPipeline(
580
- wds.shardlists.SimpleShardList( self.url ),
753
+ _ShardListStage(self._source),
581
754
  wds.shardlists.split_by_worker,
582
- #
583
- wds.tariterators.tarfile_to_samples(),
584
- # wds.map( self.preprocess ),
755
+ _StreamOpenerStage(self._source),
756
+ wds.tariterators.tar_file_expander,
757
+ wds.tariterators.group_by_keys,
585
758
  wds.filters.map( self.wrap ),
586
759
  )
587
760
 
588
761
  return wds.pipeline.DataPipeline(
589
- wds.shardlists.SimpleShardList( self.url ),
762
+ _ShardListStage(self._source),
590
763
  wds.shardlists.split_by_worker,
591
- #
592
- wds.tariterators.tarfile_to_samples(),
593
- # wds.map( self.preprocess ),
764
+ _StreamOpenerStage(self._source),
765
+ wds.tariterators.tar_file_expander,
766
+ wds.tariterators.group_by_keys,
594
767
  wds.filters.batched( batch_size ),
595
768
  wds.filters.map( self.wrap_batch ),
596
769
  )
@@ -598,7 +771,7 @@ class Dataset( Generic[ST] ):
598
771
  def shuffled( self,
599
772
  buffer_shards: int = 100,
600
773
  buffer_samples: int = 10_000,
601
- batch_size: int | None = 1,
774
+ batch_size: int | None = None,
602
775
  ) -> Iterable[ST]:
603
776
  """Iterate over the dataset in random order.
604
777
 
@@ -609,8 +782,9 @@ class Dataset( Generic[ST] ):
609
782
  buffer_samples: Number of samples to buffer for shuffling within
610
783
  shards. Larger values increase randomness but use more memory.
611
784
  Default: 10,000.
612
- batch_size: The size of iterated batches. Default: 1. If ``None``,
613
- iterates over one sample at a time with no batch dimension.
785
+ batch_size: The size of iterated batches. Default: None (unbatched).
786
+ If ``None``, iterates over one sample at a time with no batch
787
+ dimension.
614
788
 
615
789
  Returns:
616
790
  A WebDataset data pipeline that iterates over the dataset in
@@ -618,44 +792,74 @@ class Dataset( Generic[ST] ):
618
792
  ``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
619
793
  samples.
620
794
  """
621
-
622
795
  if batch_size is None:
623
- # TODO Duplication here
624
796
  return wds.pipeline.DataPipeline(
625
- wds.shardlists.SimpleShardList( self.url ),
797
+ _ShardListStage(self._source),
626
798
  wds.filters.shuffle( buffer_shards ),
627
799
  wds.shardlists.split_by_worker,
628
- #
629
- wds.tariterators.tarfile_to_samples(),
630
- # wds.shuffle( buffer_samples ),
631
- # wds.map( self.preprocess ),
800
+ _StreamOpenerStage(self._source),
801
+ wds.tariterators.tar_file_expander,
802
+ wds.tariterators.group_by_keys,
632
803
  wds.filters.shuffle( buffer_samples ),
633
804
  wds.filters.map( self.wrap ),
634
805
  )
635
806
 
636
807
  return wds.pipeline.DataPipeline(
637
- wds.shardlists.SimpleShardList( self.url ),
808
+ _ShardListStage(self._source),
638
809
  wds.filters.shuffle( buffer_shards ),
639
810
  wds.shardlists.split_by_worker,
640
- #
641
- wds.tariterators.tarfile_to_samples(),
642
- # wds.shuffle( buffer_samples ),
643
- # wds.map( self.preprocess ),
811
+ _StreamOpenerStage(self._source),
812
+ wds.tariterators.tar_file_expander,
813
+ wds.tariterators.group_by_keys,
644
814
  wds.filters.shuffle( buffer_samples ),
645
815
  wds.filters.batched( batch_size ),
646
816
  wds.filters.map( self.wrap_batch ),
647
817
  )
648
818
 
649
- # TODO Rewrite to eliminate `pandas` dependency directly calling
650
- # `fastparquet`
819
+ # Design note: Uses pandas for parquet export. Could be replaced with
820
+ # direct fastparquet calls to reduce dependencies if needed.
651
821
  def to_parquet( self, path: Pathlike,
652
822
  sample_map: Optional[SampleExportMap] = None,
653
823
  maxcount: Optional[int] = None,
654
824
  **kwargs,
655
825
  ):
656
- """Save dataset contents to a `parquet` file at `path`
826
+ """Export dataset contents to parquet format.
827
+
828
+ Converts all samples to a pandas DataFrame and saves to parquet file(s).
829
+ Useful for interoperability with data analysis tools.
657
830
 
658
- `kwargs` sent to `pandas.to_parquet`
831
+ Args:
832
+ path: Output path for the parquet file. If ``maxcount`` is specified,
833
+ files are named ``{stem}-{segment:06d}.parquet``.
834
+ sample_map: Optional function to convert samples to dictionaries.
835
+ Defaults to ``dataclasses.asdict``.
836
+ maxcount: If specified, split output into multiple files with at most
837
+ this many samples each. Recommended for large datasets.
838
+ **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
839
+ Common options include ``compression``, ``index``, ``engine``.
840
+
841
+ Warning:
842
+ **Memory Usage**: When ``maxcount=None`` (default), this method loads
843
+ the **entire dataset into memory** as a pandas DataFrame before writing.
844
+ For large datasets, this can cause memory exhaustion.
845
+
846
+ For datasets larger than available RAM, always specify ``maxcount``::
847
+
848
+ # Safe for large datasets - processes in chunks
849
+ ds.to_parquet("output.parquet", maxcount=10000)
850
+
851
+ This creates multiple parquet files: ``output-000000.parquet``,
852
+ ``output-000001.parquet``, etc.
853
+
854
+ Example:
855
+ ::
856
+
857
+ >>> ds = Dataset[MySample]("data.tar")
858
+ >>> # Small dataset - load all at once
859
+ >>> ds.to_parquet("output.parquet")
860
+ >>>
861
+ >>> # Large dataset - process in chunks
862
+ >>> ds.to_parquet("output.parquet", maxcount=50000)
659
863
  """
660
864
  ##
661
865
 
@@ -683,11 +887,11 @@ class Dataset( Generic[ST] ):
683
887
 
684
888
  cur_segment = 0
685
889
  cur_buffer = []
686
- path_template = (path.parent / f'{path.stem}-%06d.{path.suffix}').as_posix()
890
+ path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix()
687
891
 
688
892
  for x in self.ordered( batch_size = None ):
689
893
  cur_buffer.append( sample_map( x ) )
690
-
894
+
691
895
  if len( cur_buffer ) >= maxcount:
692
896
  # Write current segment
693
897
  cur_path = path_template.format( cur_segment )
@@ -703,24 +907,7 @@ class Dataset( Generic[ST] ):
703
907
  df = pd.DataFrame( cur_buffer )
704
908
  df.to_parquet( cur_path, **kwargs )
705
909
 
706
-
707
- # Implemented by specific subclasses
708
-
709
- # @property
710
- # @abstractmethod
711
- # def url( self ) -> str:
712
- # """str: Brace-notation URL of the underlying full WebDataset"""
713
- # pass
714
-
715
- # @classmethod
716
- # # TODO replace Any with IT
717
- # def preprocess( cls, sample: WDSRawSample ) -> Any:
718
- # """Pre-built preprocessor for a raw `sample` from the given dataset"""
719
- # return sample
720
-
721
- # @classmethod
722
- # TODO replace Any with IT
723
- def wrap( self, sample: MsgpackRawSample ) -> ST:
910
+ def wrap( self, sample: WDSRawSample ) -> ST:
724
911
  """Wrap a raw msgpack sample into the appropriate dataset-specific type.
725
912
 
726
913
  Args:
@@ -731,27 +918,16 @@ class Dataset( Generic[ST] ):
731
918
  A deserialized sample of type ``ST``, optionally transformed through
732
919
  a lens if ``as_type()`` was called.
733
920
  """
734
- assert 'msgpack' in sample
735
- assert type( sample['msgpack'] ) == bytes
736
-
921
+ if 'msgpack' not in sample:
922
+ raise ValueError(f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}")
923
+ if not isinstance(sample['msgpack'], bytes):
924
+ raise ValueError(f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}")
925
+
737
926
  if self._output_lens is None:
738
927
  return self.sample_type.from_bytes( sample['msgpack'] )
739
928
 
740
929
  source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
741
930
  return self._output_lens( source_sample )
742
-
743
- # try:
744
- # assert type( sample ) == dict
745
- # return cls.sample_class( **{
746
- # k: v
747
- # for k, v in sample.items() if k != '__key__'
748
- # } )
749
-
750
- # except Exception as e:
751
- # # Sample constructor failed -- revert to default
752
- # return AnySample(
753
- # value = sample,
754
- # )
755
931
 
756
932
  def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
757
933
  """Wrap a batch of raw msgpack samples into a typed SampleBatch.
@@ -769,7 +945,8 @@ class Dataset( Generic[ST] ):
769
945
  aggregates them into a batch.
770
946
  """
771
947
 
772
- assert 'msgpack' in batch
948
+ if 'msgpack' not in batch:
949
+ raise ValueError(f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}")
773
950
 
774
951
  if self._output_lens is None:
775
952
  batch_unpacked = [ self.sample_type.from_bytes( bs )
@@ -782,58 +959,44 @@ class Dataset( Generic[ST] ):
782
959
  for s in batch_source ]
783
960
  return SampleBatch[self.sample_type]( batch_view )
784
961
 
785
- # # @classmethod
786
- # def wrap_batch( self, batch: WDSRawBatch ) -> BT:
787
- # """Wrap a `batch` of samples into the appropriate dataset-specific type
788
-
789
- # This default implementation simply creates a list one sample at a time
790
- # """
791
- # assert cls.batch_class is not None, 'No batch class specified'
792
- # return cls.batch_class( **batch )
793
-
794
962
 
795
- ##
796
- # Shortcut decorators
797
-
798
- # def packable( cls ):
799
- # """TODO"""
800
-
801
- # def decorator( cls ):
802
- # # Create a new class dynamically
803
- # # The new class inherits from the new_parent_class first, then the original cls
804
- # new_bases = (PackableSample,) + cls.__bases__
805
- # new_cls = type(cls.__name__, new_bases, dict(cls.__dict__))
963
+ _T = TypeVar('_T')
806
964
 
807
- # # Optionally, update __module__ and __qualname__ for better introspection
808
- # new_cls.__module__ = cls.__module__
809
- # new_cls.__qualname__ = cls.__qualname__
810
965
 
811
- # return new_cls
812
- # return decorator
813
-
814
- def packable( cls ):
966
+ @dataclass_transform()
967
+ def packable( cls: type[_T] ) -> type[_T]:
815
968
  """Decorator to convert a regular class into a ``PackableSample``.
816
969
 
817
970
  This decorator transforms a class into a dataclass that inherits from
818
971
  ``PackableSample``, enabling automatic msgpack serialization/deserialization
819
972
  with special handling for NDArray fields.
820
973
 
974
+ The resulting class satisfies the ``Packable`` protocol, making it compatible
975
+ with all atdata APIs that accept packable types (e.g., ``publish_schema``,
976
+ lens transformations, etc.).
977
+
821
978
  Args:
822
979
  cls: The class to convert. Should have type annotations for its fields.
823
980
 
824
981
  Returns:
825
982
  A new dataclass that inherits from ``PackableSample`` with the same
826
- name and annotations as the original class.
827
-
828
- Example:
829
- >>> @packable
830
- ... class MyData:
831
- ... name: str
832
- ... values: NDArray
833
- ...
834
- >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
835
- >>> bytes_data = sample.packed
836
- >>> restored = MyData.from_bytes(bytes_data)
983
+ name and annotations as the original class. The class satisfies the
984
+ ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.
985
+
986
+ Examples:
987
+ Define a ``@packable`` class and round-trip a sample through msgpack bytes::
988
+
989
+ @packable
990
+ class MyData:
991
+ name: str
992
+ values: NDArray
993
+
994
+ sample = MyData(name="test", values=np.array([1, 2, 3]))
995
+ bytes_data = sample.packed
996
+ restored = MyData.from_bytes(bytes_data)
997
+
998
+ # Works with Packable-typed APIs
999
+ index.publish_schema(MyData, version="1.0.0") # Type-safe
837
1000
  """
838
1001
 
839
1002
  ##
@@ -850,9 +1013,32 @@ def packable( cls ):
850
1013
  def __post_init__( self ):
851
1014
  return PackableSample.__post_init__( self )
852
1015
 
853
- # TODO This doesn't properly carry over the original
1016
+ # Restore original class identity for better repr/debugging
854
1017
  as_packable.__name__ = class_name
1018
+ as_packable.__qualname__ = class_name
1019
+ as_packable.__module__ = cls.__module__
855
1020
  as_packable.__annotations__ = class_annotations
1021
+ if cls.__doc__:
1022
+ as_packable.__doc__ = cls.__doc__
1023
+
1024
+ # Fix qualnames of dataclass-generated methods so they don't show
1025
+ # 'packable.<locals>.as_packable' in help() and IDE hints
1026
+ old_qualname_prefix = 'packable.<locals>.as_packable'
1027
+ for attr_name in ('__init__', '__repr__', '__eq__', '__post_init__'):
1028
+ attr = getattr(as_packable, attr_name, None)
1029
+ if attr is not None and hasattr(attr, '__qualname__'):
1030
+ if attr.__qualname__.startswith(old_qualname_prefix):
1031
+ attr.__qualname__ = attr.__qualname__.replace(
1032
+ old_qualname_prefix, class_name, 1
1033
+ )
1034
+
1035
+ # Auto-register lens from DictSample to this type
1036
+ # This enables ds.as_type(MyType) when ds is Dataset[DictSample]
1037
+ def _dict_to_typed(ds: DictSample) -> as_packable:
1038
+ return as_packable.from_data(ds._data)
1039
+
1040
+ _dict_lens = Lens(_dict_to_typed)
1041
+ LensNetwork().register(_dict_lens)
856
1042
 
857
1043
  ##
858
1044
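Taken together, the pieces above support a small end-to-end round trip. The sketch below writes the as_wds dicts with webdataset's TarWriter and reads them back through the typed pipeline; the output path is a placeholder, and using TarWriter to produce shards is an assumption about the surrounding workflow rather than something this diff prescribes:

    import numpy as np
    import webdataset as wds
    from numpy.typing import NDArray

    from atdata.dataset import Dataset, packable   # assumed import path

    @packable
    class ToySample:
        values: NDArray
        name: str

    # Write a few samples; as_wds yields {'__key__': ..., 'msgpack': ...} dicts
    with wds.TarWriter("toy-000000.tar") as sink:
        for i in range(4):
            s = ToySample(values=np.arange(3, dtype=np.float32), name=f"item-{i}")
            sink.write(s.as_wds)

    # Read back through the typed pipeline; iteration is unbatched by default in 0.2.x
    ds = Dataset[ToySample]("toy-000000.tar")
    for sample in ds.ordered():
        assert isinstance(sample.values, np.ndarray)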