atdata 0.2.0a1__py3-none-any.whl → 0.2.3b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
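At a glance, the 0.2.3b1 changes to atdata/dataset.py below let Dataset accept either a plain WebDataset URL or a DataSource object, add a schema-free DictSample type, and replace the shard_list property with list_shards() and shards. A minimal sketch of the new construction paths, based only on the code shown in this diff (the module path atdata.dataset and the URLSource import location are taken from the file itself; whether these names are also re-exported at the package root is an assumption):

    from atdata.dataset import Dataset, DictSample
    from atdata._sources import URLSource  # relative import seen in the diff below

    # Plain URL strings still work and are wrapped in URLSource internally
    ds = Dataset[DictSample]("path/to/data-{000000..000009}.tar")

    # Or pass a DataSource explicitly
    ds = Dataset[DictSample](URLSource("https://example.com/data.tar"))

    print(ds.list_shards())  # replaces the now-deprecated ds.shard_list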
atdata/dataset.py CHANGED
@@ -13,7 +13,7 @@ The implementation handles automatic conversion between numpy arrays and bytes
  during serialization, enabling efficient storage of numerical data in WebDataset
  archives.

- Example:
+ Examples:
  >>> @packable
  ... class ImageSample:
  ... image: NDArray
@@ -41,6 +41,9 @@ from dataclasses import (
  )
  from abc import ABC

+ from ._sources import URLSource
+ from ._protocols import DataSource
+
  from tqdm import tqdm
  import numpy as np
  import pandas as pd
@@ -51,16 +54,17 @@ from typing import (
  Any,
  Optional,
  Dict,
+ Iterator,
  Sequence,
  Iterable,
  Callable,
- Union,
- #
  Self,
  Generic,
  Type,
  TypeVar,
  TypeAlias,
+ dataclass_transform,
+ overload,
  )
  from numpy.typing import NDArray

@@ -75,63 +79,203 @@ from .lens import Lens, LensNetwork

  Pathlike = str | Path

+ # WebDataset sample/batch dictionaries (contain __key__, msgpack, etc.)
  WDSRawSample: TypeAlias = Dict[str, Any]
  WDSRawBatch: TypeAlias = Dict[str, Any]

  SampleExportRow: TypeAlias = Dict[str, Any]
- SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
+ SampleExportMap: TypeAlias = Callable[["PackableSample"], SampleExportRow]


  ##
  # Main base classes

- DT = TypeVar( 'DT' )
+ DT = TypeVar("DT")

- MsgpackRawSample: TypeAlias = Dict[str, Any]

+ def _make_packable(x):
+ """Convert numpy arrays to bytes; pass through other values unchanged."""
+ if isinstance(x, np.ndarray):
+ return eh.array_to_bytes(x)
+ return x

- def _make_packable( x ):
- """Convert a value to a msgpack-compatible format.

- Args:
- x: A value to convert. If it's a numpy array, converts to bytes.
- Otherwise returns the value unchanged.
+ def _is_possibly_ndarray_type(t):
+ """Return True if type annotation is NDArray or Optional[NDArray]."""
+ if t == NDArray:
+ return True
+ if isinstance(t, types.UnionType):
+ return any(x == NDArray for x in t.__args__)
+ return False

- Returns:
- The value in a format suitable for msgpack serialization.
- """
- if isinstance( x, np.ndarray ):
- return eh.array_to_bytes( x )
- return x

- def _is_possibly_ndarray_type( t ):
- """Check if a type annotation is or contains NDArray.
+ class DictSample:
+ """Dynamic sample type providing dict-like access to raw msgpack data.

- Args:
- t: A type annotation to check.
+ This class is the default sample type for datasets when no explicit type is
+ specified. It stores the raw unpacked msgpack data and provides both
+ attribute-style (``sample.field``) and dict-style (``sample["field"]``)
+ access to fields.

- Returns:
- ``True`` if the type is ``NDArray`` or a union containing ``NDArray``
- (e.g., ``NDArray | None``), ``False`` otherwise.
+ ``DictSample`` is useful for:
+ - Exploring datasets without defining a schema first
+ - Working with datasets that have variable schemas
+ - Prototyping before committing to a typed schema
+
+ To convert to a typed schema, use ``Dataset.as_type()`` with a
+ ``@packable``-decorated class. Every ``@packable`` class automatically
+ registers a lens from ``DictSample``, making this conversion seamless.
+
+ Examples:
+ >>> ds = load_dataset("path/to/data.tar") # Returns Dataset[DictSample]
+ >>> for sample in ds.ordered():
+ ... print(sample.some_field) # Attribute access
+ ... print(sample["other_field"]) # Dict access
+ ... print(sample.keys()) # Inspect available fields
+ ...
+ >>> # Convert to typed schema
+ >>> typed_ds = ds.as_type(MyTypedSample)
+
+ Note:
+ NDArray fields are stored as raw bytes in DictSample. They are only
+ converted to numpy arrays when accessed through a typed sample class.
  """

- # Directly an NDArray
- if t == NDArray:
- # print( 'is an NDArray' )
- return True
-
- # Check for Optionals (i.e., NDArray | None)
- if isinstance( t, types.UnionType ):
- t_parts = t.__args__
- if any( x == NDArray
- for x in t_parts ):
- return True
-
- # Not an NDArray
- return False
+ __slots__ = ("_data",)
+
+ def __init__(self, _data: dict[str, Any] | None = None, **kwargs: Any) -> None:
+ """Create a DictSample from a dictionary or keyword arguments.
+
+ Args:
+ _data: Raw data dictionary. If provided, kwargs are ignored.
+ **kwargs: Field values if _data is not provided.
+ """
+ if _data is not None:
+ object.__setattr__(self, "_data", _data)
+ else:
+ object.__setattr__(self, "_data", kwargs)
+
+ @classmethod
+ def from_data(cls, data: dict[str, Any]) -> "DictSample":
+ """Create a DictSample from unpacked msgpack data.
+
+ Args:
+ data: Dictionary with field names as keys.
+
+ Returns:
+ New DictSample instance wrapping the data.
+ """
+ return cls(_data=data)
+
+ @classmethod
+ def from_bytes(cls, bs: bytes) -> "DictSample":
+ """Create a DictSample from raw msgpack bytes.
+
+ Args:
+ bs: Raw bytes from a msgpack-serialized sample.
+
+ Returns:
+ New DictSample instance with the unpacked data.
+ """
+ return cls.from_data(ormsgpack.unpackb(bs))
+
+ def __getattr__(self, name: str) -> Any:
+ """Access a field by attribute name.
+
+ Args:
+ name: Field name to access.
+
+ Returns:
+ The field value.
+
+ Raises:
+ AttributeError: If the field doesn't exist.
+ """
+ # Avoid infinite recursion for _data lookup
+ if name == "_data":
+ raise AttributeError(name)
+ try:
+ return self._data[name]
+ except KeyError:
+ raise AttributeError(
+ f"'{type(self).__name__}' has no field '{name}'. "
+ f"Available fields: {list(self._data.keys())}"
+ ) from None
+
+ def __getitem__(self, key: str) -> Any:
+ """Access a field by dict key.
+
+ Args:
+ key: Field name to access.
+
+ Returns:
+ The field value.
+
+ Raises:
+ KeyError: If the field doesn't exist.
+ """
+ return self._data[key]
+
+ def __contains__(self, key: str) -> bool:
+ """Check if a field exists."""
+ return key in self._data
+
+ def keys(self) -> list[str]:
+ """Return list of field names."""
+ return list(self._data.keys())
+
+ def values(self) -> list[Any]:
+ """Return list of field values."""
+ return list(self._data.values())
+
+ def items(self) -> list[tuple[str, Any]]:
+ """Return list of (field_name, value) tuples."""
+ return list(self._data.items())
+
+ def get(self, key: str, default: Any = None) -> Any:
+ """Get a field value with optional default.
+
+ Args:
+ key: Field name to access.
+ default: Value to return if field doesn't exist.
+
+ Returns:
+ The field value or default.
+ """
+ return self._data.get(key, default)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Return a copy of the underlying data dictionary."""
+ return dict(self._data)
+
+ @property
+ def packed(self) -> bytes:
+ """Pack this sample's data into msgpack bytes.
+
+ Returns:
+ Raw msgpack bytes representing this sample's data.
+ """
+ return msgpack.packb(self._data)
+
+ @property
+ def as_wds(self) -> "WDSRawSample":
+ """Pack this sample's data for writing to WebDataset.
+
+ Returns:
+ A dictionary with ``__key__`` and ``msgpack`` fields.
+ """
+ return {
+ "__key__": str(uuid.uuid1(0, 0)),
+ "msgpack": self.packed,
+ }
+
+ def __repr__(self) -> str:
+ fields = ", ".join(f"{k}=..." for k in self._data.keys())
+ return f"DictSample({fields})"
+

  @dataclass
- class PackableSample( ABC ):
+ class PackableSample(ABC):
  """Base class for samples that can be serialized with msgpack.

  This abstract base class provides automatic serialization/deserialization
@@ -143,7 +287,7 @@ class PackableSample( ABC ):
  1. Direct inheritance with the ``@dataclass`` decorator
  2. Using the ``@packable`` decorator (recommended)

- Example:
+ Examples:
  >>> @packable
  ... class MyData:
  ... name: str
@@ -154,67 +298,53 @@ class PackableSample( ABC ):
  >>> restored = MyData.from_bytes(packed) # Deserialize
  """

- def _ensure_good( self ):
- """Auto-convert annotated NDArray fields from bytes to numpy arrays.
-
- This method scans all dataclass fields and for any field annotated as
- ``NDArray`` or ``NDArray | None``, automatically converts bytes values
- to numpy arrays using the helper deserialization function. This enables
- transparent handling of array serialization in msgpack data.
-
- Note:
- This is called during ``__post_init__`` to ensure proper type
- conversion after deserialization.
- """
+ def _ensure_good(self):
+ """Convert bytes to NDArray for fields annotated as NDArray or NDArray | None."""

  # Auto-convert known types when annotated
  # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
- for field in dataclasses.fields( self ):
+ for field in dataclasses.fields(self):
  var_name = field.name
  var_type = field.type

  # Annotation for this variable is to be an NDArray
- if _is_possibly_ndarray_type( var_type ):
+ if _is_possibly_ndarray_type(var_type):
  # ... so, we'll always auto-convert to numpy

- var_cur_value = getattr( self, var_name )
+ var_cur_value = getattr(self, var_name)

  # Execute the appropriate conversion for intermediate data
  # based on what is provided

- if isinstance( var_cur_value, np.ndarray ):
+ if isinstance(var_cur_value, np.ndarray):
  # Already the correct type, no conversion needed
  continue

- elif isinstance( var_cur_value, bytes ):
- # TODO This does create a constraint that serialized bytes
- # in a field that might be an NDArray are always interpreted
- # as being the NDArray interpretation
- setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
+ elif isinstance(var_cur_value, bytes):
+ # Design note: bytes in NDArray-typed fields are always interpreted
+ # as serialized arrays. This means raw bytes fields must not be
+ # annotated as NDArray.
+ setattr(self, var_name, eh.bytes_to_array(var_cur_value))

- def __post_init__( self ):
+ def __post_init__(self):
  self._ensure_good()

  ##

  @classmethod
- def from_data( cls, data: MsgpackRawSample ) -> Self:
+ def from_data(cls, data: WDSRawSample) -> Self:
  """Create a sample instance from unpacked msgpack data.

  Args:
- data: A dictionary of unpacked msgpack data with keys matching
- the sample's field names.
+ data: Dictionary with keys matching the sample's field names.

  Returns:
- A new instance of this sample class with fields populated from
- the data dictionary and NDArray fields auto-converted from bytes.
+ New instance with NDArray fields auto-converted from bytes.
  """
- ret = cls( **data )
- ret._ensure_good()
- return ret
-
+ return cls(**data)
+
  @classmethod
- def from_bytes( cls, bs: bytes ) -> Self:
+ def from_bytes(cls, bs: bytes) -> Self:
  """Create a sample instance from raw msgpack bytes.

  Args:
@@ -223,10 +353,10 @@ class PackableSample( ABC ):
  Returns:
  A new instance of this sample class deserialized from the bytes.
  """
- return cls.from_data( ormsgpack.unpackb( bs ) )
+ return cls.from_data(ormsgpack.unpackb(bs))

  @property
- def packed( self ) -> bytes:
+ def packed(self) -> bytes:
  """Pack this sample's data into msgpack bytes.

  NDArray fields are automatically converted to bytes before packing.
@@ -241,21 +371,17 @@ class PackableSample( ABC ):

  # Make sure that all of our (possibly unpackable) data is in a packable
  # format
- o = {
- k: _make_packable( v )
- for k, v in vars( self ).items()
- }
+ o = {k: _make_packable(v) for k, v in vars(self).items()}

- ret = msgpack.packb( o )
+ ret = msgpack.packb(o)

  if ret is None:
- raise RuntimeError( f'Failed to pack sample to bytes: {o}' )
+ raise RuntimeError(f"Failed to pack sample to bytes: {o}")

  return ret
-
- # TODO Expand to allow for specifying explicit __key__
+
  @property
- def as_wds( self ) -> WDSRawSample:
+ def as_wds(self) -> WDSRawSample:
  """Pack this sample's data for writing to WebDataset.

  Returns:
@@ -263,37 +389,26 @@ class PackableSample( ABC ):
  ``msgpack`` (packed sample data) fields suitable for WebDataset.

  Note:
- TODO: Expand to allow specifying explicit ``__key__`` values.
+ Keys are auto-generated as UUID v1 for time-sortable ordering.
+ Custom key specification is not currently supported.
  """
  return {
  # Generates a UUID that is timelike-sortable
- '__key__': str( uuid.uuid1( 0, 0 ) ),
- 'msgpack': self.packed,
+ "__key__": str(uuid.uuid1(0, 0)),
+ "msgpack": self.packed,
  }

- def _batch_aggregate( xs: Sequence ):
- """Aggregate a sequence of values into a batch-appropriate format.
-
- Args:
- xs: A sequence of values to aggregate. If the first element is a numpy
- array, all elements are stacked into a single array. Otherwise,
- returns a list.
-
- Returns:
- A numpy array (if elements are arrays) or a list (otherwise).
- """

+ def _batch_aggregate(xs: Sequence):
+ """Stack arrays into numpy array with batch dim; otherwise return list."""
  if not xs:
- # Empty sequence
  return []
+ if isinstance(xs[0], np.ndarray):
+ return np.array(list(xs))
+ return list(xs)

- # Aggregate
- if isinstance( xs[0], np.ndarray ):
- return np.array( list( xs ) )

- return list( xs )
-
- class SampleBatch( Generic[DT] ):
+ class SampleBatch(Generic[DT]):
  """A batch of samples with automatic attribute aggregation.

  This class wraps a sequence of samples and provides magic ``__getattr__``
@@ -304,19 +419,28 @@ class SampleBatch( Generic[DT] ):
  NDArray fields are stacked into a numpy array with a batch dimension.
  Other fields are aggregated into a list.

- Type Parameters:
+ Parameters:
  DT: The sample type, must derive from ``PackableSample``.

  Attributes:
  samples: The list of sample instances in this batch.

- Example:
+ Examples:
  >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
  >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
  >>> batch.names # Returns list of names
+
+ Note:
+ This class uses Python's ``__orig_class__`` mechanism to extract the
+ type parameter at runtime. Instances must be created using the
+ subscripted syntax ``SampleBatch[MyType](samples)`` rather than
+ calling the constructor directly with an unsubscripted class.
  """

- def __init__( self, samples: Sequence[DT] ):
+ # Design note: The docstring uses "Parameters:" for type parameters because
+ # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
+
+ def __init__(self, samples: Sequence[DT]):
  """Create a batch from a sequence of samples.

  Args:
@@ -324,19 +448,23 @@ class SampleBatch( Generic[DT] ):
  Each sample must be an instance of a type derived from
  ``PackableSample``.
  """
- self.samples = list( samples )
+ self.samples = list(samples)
  self._aggregate_cache = dict()
+ self._sample_type_cache: Type | None = None

  @property
- def sample_type( self ) -> Type:
+ def sample_type(self) -> Type:
  """The type of each sample in this batch.

  Returns:
  The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
  """
- return typing.get_args( self.__orig_class__)[0]
+ if self._sample_type_cache is None:
+ self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
+ assert self._sample_type_cache is not None
+ return self._sample_type_cache

- def __getattr__( self, name ):
+ def __getattr__(self, name):
  """Aggregate an attribute across all samples in the batch.

  This magic method enables attribute-style access to aggregated sample
@@ -353,22 +481,57 @@ class SampleBatch( Generic[DT] ):
  AttributeError: If the attribute doesn't exist on the sample type.
  """
  # Aggregate named params of sample type
- if name in vars( self.sample_type )['__annotations__']:
+ if name in vars(self.sample_type)["__annotations__"]:
  if name not in self._aggregate_cache:
  self._aggregate_cache[name] = _batch_aggregate(
- [ getattr( x, name )
- for x in self.samples ]
+ [getattr(x, name) for x in self.samples]
  )

  return self._aggregate_cache[name]

- raise AttributeError( f'No sample attribute named {name}' )
+ raise AttributeError(f"No sample attribute named {name}")
+
+
+ ST = TypeVar("ST", bound=PackableSample)
+ RT = TypeVar("RT", bound=PackableSample)


- ST = TypeVar( 'ST', bound = PackableSample )
- RT = TypeVar( 'RT', bound = PackableSample )
+ class _ShardListStage(wds.utils.PipelineStage):
+ """Pipeline stage that yields {url: shard_id} dicts from a DataSource.

- class Dataset( Generic[ST] ):
+ This is analogous to SimpleShardList but works with any DataSource.
+ Used as the first stage before split_by_worker.
+ """
+
+ def __init__(self, source: DataSource):
+ self.source = source
+
+ def run(self):
+ """Yield {url: shard_id} dicts for each shard."""
+ for shard_id in self.source.list_shards():
+ yield {"url": shard_id}
+
+
+ class _StreamOpenerStage(wds.utils.PipelineStage):
+ """Pipeline stage that opens streams from a DataSource.
+
+ Takes {url: shard_id} dicts and adds a stream using source.open_shard().
+ This replaces WebDataset's url_opener stage.
+ """
+
+ def __init__(self, source: DataSource):
+ self.source = source
+
+ def run(self, src):
+ """Open streams for each shard dict."""
+ for sample in src:
+ shard_id = sample["url"]
+ stream = self.source.open_shard(shard_id)
+ sample["stream"] = stream
+ yield sample
+
+
+ class Dataset(Generic[ST]):
  """A typed dataset built on WebDataset with lens transformations.

  This class wraps WebDataset tar archives and provides type-safe iteration
@@ -381,13 +544,13 @@ class Dataset( Generic[ST] ):
  - Type transformations via the lens system (``as_type()``)
  - Export to parquet format

- Type Parameters:
+ Parameters:
  ST: The sample type for this dataset, must derive from ``PackableSample``.

  Attributes:
  url: WebDataset brace-notation URL for the tar file(s).

- Example:
+ Examples:
  >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
  >>> for sample in ds.ordered(batch_size=32):
  ... # sample is SampleBatch[MyData] with batch_size samples
@@ -395,23 +558,31 @@ class Dataset( Generic[ST] ):
  ...
  >>> # Transform to a different view
  >>> ds_view = ds.as_type(MyDataView)
-
+
+ Note:
+ This class uses Python's ``__orig_class__`` mechanism to extract the
+ type parameter at runtime. Instances must be created using the
+ subscripted syntax ``Dataset[MyType](url)`` rather than calling the
+ constructor directly with an unsubscripted class.
  """

+ # Design note: The docstring uses "Parameters:" for type parameters because
+ # quartodoc doesn't yet support "Type Parameters:" sections in generated docs.
+
  @property
- def sample_type( self ) -> Type:
+ def sample_type(self) -> Type:
  """The type of each returned sample from this dataset's iterator.

  Returns:
  The type parameter ``ST`` used when creating this ``Dataset[ST]``.
-
- Note:
- Extracts the type parameter at runtime using ``__orig_class__``.
  """
- # NOTE: Linting may fail here due to __orig_class__ being a runtime attribute
- return typing.get_args( self.__orig_class__ )[0]
+ if self._sample_type_cache is None:
+ self._sample_type_cache = typing.get_args(self.__orig_class__)[0]
+ assert self._sample_type_cache is not None
+ return self._sample_type_cache
+
  @property
- def batch_type( self ) -> Type:
+ def batch_type(self) -> Type:
  """The type of batches produced by this dataset.

  Returns:
@@ -419,31 +590,61 @@ class Dataset( Generic[ST] ):
  """
  return SampleBatch[self.sample_type]

- def __init__( self, url: str,
- metadata_url: str | None = None,
- ) -> None:
- """Create a dataset from a WebDataset URL.
+ def __init__(
+ self,
+ source: DataSource | str | None = None,
+ metadata_url: str | None = None,
+ *,
+ url: str | None = None,
+ ) -> None:
+ """Create a dataset from a DataSource or URL.

  Args:
- url: WebDataset brace-notation URL pointing to tar files, e.g.,
- ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
- ``"path/to/file-000000.tar"`` for a single shard.
+ source: Either a DataSource implementation or a WebDataset-compatible
+ URL string. If a string is provided, it's wrapped in URLSource
+ for backward compatibility.
+
+ Examples:
+ - String URL: ``"path/to/file-{000000..000009}.tar"``
+ - URLSource: ``URLSource("https://example.com/data.tar")``
+ - S3Source: ``S3Source(bucket="my-bucket", keys=["data.tar"])``
+
+ metadata_url: Optional URL to msgpack-encoded metadata for this dataset.
+ url: Deprecated. Use ``source`` instead. Kept for backward compatibility.
  """
  super().__init__()
- self.url = url
- """WebDataset brace-notation URL pointing to tar files, e.g.,
- ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
- ``"path/to/file-000000.tar"`` for a single shard.
- """
+
+ # Handle backward compatibility: url= keyword argument
+ if source is None and url is not None:
+ source = url
+ elif source is None:
+ raise TypeError("Dataset() missing required argument: 'source' or 'url'")
+
+ # Normalize source: strings become URLSource for backward compatibility
+ if isinstance(source, str):
+ self._source: DataSource = URLSource(source)
+ self.url = source
+ else:
+ self._source = source
+ # For compatibility, expose URL if source has list_shards
+ shards = source.list_shards()
+ # Design note: Using first shard as url for legacy compatibility.
+ # Full shard list is available via list_shards() method.
+ self.url = shards[0] if shards else ""

  self._metadata: dict[str, Any] | None = None
  self.metadata_url: str | None = metadata_url
  """Optional URL to msgpack-encoded metadata for this dataset."""

- # Allow addition of automatic transformation of raw underlying data
  self._output_lens: Lens | None = None
+ self._sample_type_cache: Type | None = None

- def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
+ @property
+ def source(self) -> DataSource:
+ """The underlying data source for this dataset."""
+ return self._source
+
+ def as_type(self, other: Type[RT]) -> "Dataset[RT]":
  """View this dataset through a different sample type using a registered lens.

  Args:
@@ -459,28 +660,53 @@ class Dataset( Generic[ST] ):
  ValueError: If no registered lens exists between the current
  sample type and the target type.
  """
- ret = Dataset[other]( self.url )
+ ret = Dataset[other](self._source)
  # Get the singleton lens registry
  lenses = LensNetwork()
- ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
+ ret._output_lens = lenses.transform(self.sample_type, ret.sample_type)
  return ret

  @property
- def shard_list( self ) -> list[str]:
- """List of individual dataset shards
-
+ def shards(self) -> Iterator[str]:
+ """Lazily iterate over shard identifiers.
+
+ Yields:
+ Shard identifiers (e.g., 'train-000000.tar', 'train-000001.tar').
+
+ Examples:
+ >>> for shard in ds.shards:
+ ... print(f"Processing {shard}")
+ """
+ return iter(self._source.list_shards())
+
+ def list_shards(self) -> list[str]:
+ """Get list of individual dataset shards.
+
  Returns:
  A full (non-lazy) list of the individual ``tar`` files within the
  source WebDataset.
  """
- pipe = wds.pipeline.DataPipeline(
- wds.shardlists.SimpleShardList( self.url ),
- wds.filters.map( lambda x: x['url'] )
+ return self._source.list_shards()
+
+ # Legacy alias for backwards compatibility
+ @property
+ def shard_list(self) -> list[str]:
+ """List of individual dataset shards (deprecated, use list_shards()).
+
+ .. deprecated::
+ Use :meth:`list_shards` instead.
+ """
+ import warnings
+
+ warnings.warn(
+ "shard_list is deprecated, use list_shards() instead",
+ DeprecationWarning,
+ stacklevel=2,
  )
- return list( pipe )
+ return self.list_shards()

  @property
- def metadata( self ) -> dict[str, Any] | None:
+ def metadata(self) -> dict[str, Any] | None:
  """Fetch and cache metadata from metadata_url.

  Returns:
@@ -493,50 +719,91 @@ class Dataset( Generic[ST] ):
  return None

  if self._metadata is None:
- with requests.get( self.metadata_url, stream = True ) as response:
+ with requests.get(self.metadata_url, stream=True) as response:
  response.raise_for_status()
- self._metadata = msgpack.unpackb( response.content, raw = False )
-
+ self._metadata = msgpack.unpackb(response.content, raw=False)
+
  # Use our cached values
  return self._metadata
-
- def ordered( self,
- batch_size: int | None = 1,
- ) -> Iterable[ST]:
- """Iterate over the dataset in order
-
+
+ @overload
+ def ordered(
+ self,
+ batch_size: None = None,
+ ) -> Iterable[ST]: ...
+
+ @overload
+ def ordered(
+ self,
+ batch_size: int,
+ ) -> Iterable[SampleBatch[ST]]: ...
+
+ def ordered(
+ self,
+ batch_size: int | None = None,
+ ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
+ """Iterate over the dataset in order.
+
  Args:
- batch_size (:obj:`int`, optional): The size of iterated batches.
- Default: 1. If ``None``, iterates over one sample at a time
- with no batch dimension.
-
+ batch_size: The size of iterated batches. Default: None (unbatched).
+ If ``None``, iterates over one sample at a time with no batch
+ dimension.
+
  Returns:
- :obj:`webdataset.DataPipeline` A data pipeline that iterates over
- the dataset in its original sample order
-
+ A data pipeline that iterates over the dataset in its original
+ sample order. When ``batch_size`` is ``None``, yields individual
+ samples of type ``ST``. When ``batch_size`` is an integer, yields
+ ``SampleBatch[ST]`` instances containing that many samples.
+
+ Examples:
+ >>> for sample in ds.ordered():
+ ... process(sample) # sample is ST
+ >>> for batch in ds.ordered(batch_size=32):
+ ... process(batch) # batch is SampleBatch[ST]
  """
-
  if batch_size is None:
  return wds.pipeline.DataPipeline(
- wds.shardlists.SimpleShardList( self.url ),
+ _ShardListStage(self._source),
  wds.shardlists.split_by_worker,
- wds.tariterators.tarfile_to_samples(),
- wds.filters.map( self.wrap ),
+ _StreamOpenerStage(self._source),
+ wds.tariterators.tar_file_expander,
+ wds.tariterators.group_by_keys,
+ wds.filters.map(self.wrap),
  )

  return wds.pipeline.DataPipeline(
- wds.shardlists.SimpleShardList( self.url ),
+ _ShardListStage(self._source),
  wds.shardlists.split_by_worker,
- wds.tariterators.tarfile_to_samples(),
- wds.filters.batched( batch_size ),
- wds.filters.map( self.wrap_batch ),
+ _StreamOpenerStage(self._source),
+ wds.tariterators.tar_file_expander,
+ wds.tariterators.group_by_keys,
+ wds.filters.batched(batch_size),
+ wds.filters.map(self.wrap_batch),
  )

- def shuffled( self,
- buffer_shards: int = 100,
- buffer_samples: int = 10_000,
- batch_size: int | None = 1,
- ) -> Iterable[ST]:
+ @overload
+ def shuffled(
+ self,
+ buffer_shards: int = 100,
+ buffer_samples: int = 10_000,
+ batch_size: None = None,
+ ) -> Iterable[ST]: ...
+
+ @overload
+ def shuffled(
+ self,
+ buffer_shards: int = 100,
+ buffer_samples: int = 10_000,
+ *,
+ batch_size: int,
+ ) -> Iterable[SampleBatch[ST]]: ...
+
+ def shuffled(
+ self,
+ buffer_shards: int = 100,
+ buffer_samples: int = 10_000,
+ batch_size: int | None = None,
+ ) -> Iterable[ST] | Iterable[SampleBatch[ST]]:
  """Iterate over the dataset in random order.

  Args:
@@ -546,93 +813,139 @@ class Dataset( Generic[ST] ):
  buffer_samples: Number of samples to buffer for shuffling within
  shards. Larger values increase randomness but use more memory.
  Default: 10,000.
- batch_size: The size of iterated batches. Default: 1. If ``None``,
- iterates over one sample at a time with no batch dimension.
+ batch_size: The size of iterated batches. Default: None (unbatched).
+ If ``None``, iterates over one sample at a time with no batch
+ dimension.

  Returns:
- A WebDataset data pipeline that iterates over the dataset in
- randomized order. If ``batch_size`` is not ``None``, yields
- ``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
- samples.
+ A data pipeline that iterates over the dataset in randomized order.
+ When ``batch_size`` is ``None``, yields individual samples of type
+ ``ST``. When ``batch_size`` is an integer, yields ``SampleBatch[ST]``
+ instances containing that many samples.
+
+ Examples:
+ >>> for sample in ds.shuffled():
+ ... process(sample) # sample is ST
+ >>> for batch in ds.shuffled(batch_size=32):
+ ... process(batch) # batch is SampleBatch[ST]
  """
  if batch_size is None:
  return wds.pipeline.DataPipeline(
- wds.shardlists.SimpleShardList( self.url ),
- wds.filters.shuffle( buffer_shards ),
+ _ShardListStage(self._source),
+ wds.filters.shuffle(buffer_shards),
  wds.shardlists.split_by_worker,
- wds.tariterators.tarfile_to_samples(),
- wds.filters.shuffle( buffer_samples ),
- wds.filters.map( self.wrap ),
+ _StreamOpenerStage(self._source),
+ wds.tariterators.tar_file_expander,
+ wds.tariterators.group_by_keys,
+ wds.filters.shuffle(buffer_samples),
+ wds.filters.map(self.wrap),
  )

  return wds.pipeline.DataPipeline(
- wds.shardlists.SimpleShardList( self.url ),
- wds.filters.shuffle( buffer_shards ),
+ _ShardListStage(self._source),
+ wds.filters.shuffle(buffer_shards),
  wds.shardlists.split_by_worker,
- wds.tariterators.tarfile_to_samples(),
- wds.filters.shuffle( buffer_samples ),
- wds.filters.batched( batch_size ),
- wds.filters.map( self.wrap_batch ),
+ _StreamOpenerStage(self._source),
+ wds.tariterators.tar_file_expander,
+ wds.tariterators.group_by_keys,
+ wds.filters.shuffle(buffer_samples),
+ wds.filters.batched(batch_size),
+ wds.filters.map(self.wrap_batch),
  )
-
- # TODO Rewrite to eliminate `pandas` dependency directly calling
- # `fastparquet`
- def to_parquet( self, path: Pathlike,
- sample_map: Optional[SampleExportMap] = None,
- maxcount: Optional[int] = None,
- **kwargs,
- ):
- """Save dataset contents to a `parquet` file at `path`
-
- `kwargs` sent to `pandas.to_parquet`
+
+ # Design note: Uses pandas for parquet export. Could be replaced with
+ # direct fastparquet calls to reduce dependencies if needed.
+ def to_parquet(
+ self,
+ path: Pathlike,
+ sample_map: Optional[SampleExportMap] = None,
+ maxcount: Optional[int] = None,
+ **kwargs,
+ ):
+ """Export dataset contents to parquet format.
+
+ Converts all samples to a pandas DataFrame and saves to parquet file(s).
+ Useful for interoperability with data analysis tools.
+
+ Args:
+ path: Output path for the parquet file. If ``maxcount`` is specified,
+ files are named ``{stem}-{segment:06d}.parquet``.
+ sample_map: Optional function to convert samples to dictionaries.
+ Defaults to ``dataclasses.asdict``.
+ maxcount: If specified, split output into multiple files with at most
+ this many samples each. Recommended for large datasets.
+ **kwargs: Additional arguments passed to ``pandas.DataFrame.to_parquet()``.
+ Common options include ``compression``, ``index``, ``engine``.
+
+ Warning:
+ **Memory Usage**: When ``maxcount=None`` (default), this method loads
+ the **entire dataset into memory** as a pandas DataFrame before writing.
+ For large datasets, this can cause memory exhaustion.
+
+ For datasets larger than available RAM, always specify ``maxcount``::
+
+ # Safe for large datasets - processes in chunks
+ ds.to_parquet("output.parquet", maxcount=10000)
+
+ This creates multiple parquet files: ``output-000000.parquet``,
+ ``output-000001.parquet``, etc.
+
+ Examples:
+ >>> ds = Dataset[MySample]("data.tar")
+ >>> # Small dataset - load all at once
+ >>> ds.to_parquet("output.parquet")
+ >>>
+ >>> # Large dataset - process in chunks
+ >>> ds.to_parquet("output.parquet", maxcount=50000)
  """
  ##

  # Normalize args
- path = Path( path )
+ path = Path(path)
  if sample_map is None:
  sample_map = asdict
-
- verbose = kwargs.get( 'verbose', False )

- it = self.ordered( batch_size = None )
+ verbose = kwargs.get("verbose", False)
+
+ it = self.ordered(batch_size=None)
  if verbose:
- it = tqdm( it )
+ it = tqdm(it)

  #

  if maxcount is None:
  # Load and save full dataset
- df = pd.DataFrame( [ sample_map( x )
- for x in self.ordered( batch_size = None ) ] )
- df.to_parquet( path, **kwargs )
-
+ df = pd.DataFrame([sample_map(x) for x in self.ordered(batch_size=None)])
+ df.to_parquet(path, **kwargs)
+
  else:
  # Load and save dataset in segments of size `maxcount`

  cur_segment = 0
  cur_buffer = []
- path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix()
+ path_template = (
+ path.parent / f"{path.stem}-{{:06d}}{path.suffix}"
+ ).as_posix()

- for x in self.ordered( batch_size = None ):
- cur_buffer.append( sample_map( x ) )
+ for x in self.ordered(batch_size=None):
+ cur_buffer.append(sample_map(x))

- if len( cur_buffer ) >= maxcount:
+ if len(cur_buffer) >= maxcount:
  # Write current segment
- cur_path = path_template.format( cur_segment )
- df = pd.DataFrame( cur_buffer )
- df.to_parquet( cur_path, **kwargs )
+ cur_path = path_template.format(cur_segment)
+ df = pd.DataFrame(cur_buffer)
+ df.to_parquet(cur_path, **kwargs)

  cur_segment += 1
  cur_buffer = []
-
- if len( cur_buffer ) > 0:
+
+ if len(cur_buffer) > 0:
  # Write one last segment with remainder
- cur_path = path_template.format( cur_segment )
- df = pd.DataFrame( cur_buffer )
- df.to_parquet( cur_path, **kwargs )
+ cur_path = path_template.format(cur_segment)
+ df = pd.DataFrame(cur_buffer)
+ df.to_parquet(cur_path, **kwargs)

- def wrap( self, sample: MsgpackRawSample ) -> ST:
+ def wrap(self, sample: WDSRawSample) -> ST:
  """Wrap a raw msgpack sample into the appropriate dataset-specific type.

  Args:
@@ -643,16 +956,22 @@ class Dataset( Generic[ST] ):
  A deserialized sample of type ``ST``, optionally transformed through
  a lens if ``as_type()`` was called.
  """
- assert 'msgpack' in sample
- assert type( sample['msgpack'] ) == bytes
-
+ if "msgpack" not in sample:
+ raise ValueError(
+ f"Sample missing 'msgpack' key, got keys: {list(sample.keys())}"
+ )
+ if not isinstance(sample["msgpack"], bytes):
+ raise ValueError(
+ f"Expected sample['msgpack'] to be bytes, got {type(sample['msgpack']).__name__}"
+ )
+
  if self._output_lens is None:
- return self.sample_type.from_bytes( sample['msgpack'] )
+ return self.sample_type.from_bytes(sample["msgpack"])

- source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
- return self._output_lens( source_sample )
+ source_sample = self._output_lens.source_type.from_bytes(sample["msgpack"])
+ return self._output_lens(source_sample)

- def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
+ def wrap_batch(self, batch: WDSRawBatch) -> SampleBatch[ST]:
  """Wrap a batch of raw msgpack samples into a typed SampleBatch.

  Args:
@@ -668,35 +987,48 @@ class Dataset( Generic[ST] ):
  aggregates them into a batch.
  """

- assert 'msgpack' in batch
+ if "msgpack" not in batch:
+ raise ValueError(
+ f"Batch missing 'msgpack' key, got keys: {list(batch.keys())}"
+ )

  if self._output_lens is None:
- batch_unpacked = [ self.sample_type.from_bytes( bs )
- for bs in batch['msgpack'] ]
- return SampleBatch[self.sample_type]( batch_unpacked )
+ batch_unpacked = [
+ self.sample_type.from_bytes(bs) for bs in batch["msgpack"]
+ ]
+ return SampleBatch[self.sample_type](batch_unpacked)
+
+ batch_source = [
+ self._output_lens.source_type.from_bytes(bs) for bs in batch["msgpack"]
+ ]
+ batch_view = [self._output_lens(s) for s in batch_source]
+ return SampleBatch[self.sample_type](batch_view)

- batch_source = [ self._output_lens.source_type.from_bytes( bs )
- for bs in batch['msgpack'] ]
- batch_view = [ self._output_lens( s )
- for s in batch_source ]
- return SampleBatch[self.sample_type]( batch_view )

+ _T = TypeVar("_T")

- def packable( cls ):
+
+ @dataclass_transform()
+ def packable(cls: type[_T]) -> type[_T]:
  """Decorator to convert a regular class into a ``PackableSample``.

  This decorator transforms a class into a dataclass that inherits from
  ``PackableSample``, enabling automatic msgpack serialization/deserialization
  with special handling for NDArray fields.

+ The resulting class satisfies the ``Packable`` protocol, making it compatible
+ with all atdata APIs that accept packable types (e.g., ``publish_schema``,
+ lens transformations, etc.).
+
  Args:
  cls: The class to convert. Should have type annotations for its fields.

  Returns:
  A new dataclass that inherits from ``PackableSample`` with the same
- name and annotations as the original class.
+ name and annotations as the original class. The class satisfies the
+ ``Packable`` protocol and can be used with ``Type[Packable]`` signatures.

- Example:
+ Examples:
  >>> @packable
  ... class MyData:
  ... name: str
@@ -705,6 +1037,9 @@ def packable( cls ):
  >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
  >>> bytes_data = sample.packed
  >>> restored = MyData.from_bytes(bytes_data)
+ >>>
+ >>> # Works with Packable-typed APIs
+ >>> index.publish_schema(MyData, version="1.0.0") # Type-safe
  """

  ##
@@ -713,18 +1048,41 @@ def packable( cls ):
  class_annotations = cls.__annotations__

  # Add in dataclass niceness to original class
- as_dataclass = dataclass( cls )
+ as_dataclass = dataclass(cls)

  # This triggers a bunch of behind-the-scenes stuff for the newly annotated class
  @dataclass
- class as_packable( as_dataclass, PackableSample ):
- def __post_init__( self ):
- return PackableSample.__post_init__( self )
-
- # TODO This doesn't properly carry over the original
+ class as_packable(as_dataclass, PackableSample):
+ def __post_init__(self):
+ return PackableSample.__post_init__(self)
+
+ # Restore original class identity for better repr/debugging
  as_packable.__name__ = class_name
+ as_packable.__qualname__ = class_name
+ as_packable.__module__ = cls.__module__
  as_packable.__annotations__ = class_annotations
+ if cls.__doc__:
+ as_packable.__doc__ = cls.__doc__
+
+ # Fix qualnames of dataclass-generated methods so they don't show
+ # 'packable.<locals>.as_packable' in help() and IDE hints
+ old_qualname_prefix = "packable.<locals>.as_packable"
+ for attr_name in ("__init__", "__repr__", "__eq__", "__post_init__"):
+ attr = getattr(as_packable, attr_name, None)
+ if attr is not None and hasattr(attr, "__qualname__"):
+ if attr.__qualname__.startswith(old_qualname_prefix):
+ attr.__qualname__ = attr.__qualname__.replace(
+ old_qualname_prefix, class_name, 1
+ )
+
+ # Auto-register lens from DictSample to this type
+ # This enables ds.as_type(MyType) when ds is Dataset[DictSample]
+ def _dict_to_typed(ds: DictSample) -> as_packable:
+ return as_packable.from_data(ds._data)
+
+ _dict_lens = Lens(_dict_to_typed)
+ LensNetwork().register(_dict_lens)

  ##

- return as_packable
+ return as_packable
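
Taken together, the diff above introduces a schema-free DictSample view, DataSource-backed shard handling with overloaded ordered()/shuffled() return types, and a DictSample-to-typed lens that @packable now registers automatically. The following end-to-end sketch strings those pieces together; it relies only on behavior shown in the diff, assumes the module path atdata.dataset (the only file shown here), and uses hypothetical field names and file paths:

    import numpy as np
    from numpy.typing import NDArray

    from atdata.dataset import Dataset, DictSample, packable

    @packable
    class ImageSample:
        image: NDArray  # serialized to bytes on pack, converted back to numpy on load
        label: str

    # Explore an archive without committing to a schema first.
    raw = Dataset[DictSample]("path/to/data-{000000..000009}.tar")
    first = next(iter(raw.ordered()))
    print(first.keys())  # inspect the available fields

    # Switch to the typed view via the lens that @packable auto-registered.
    typed = raw.as_type(ImageSample)

    # Batched iteration: NDArray fields stack with a batch dimension, others become lists.
    for batch in typed.ordered(batch_size=32):
        images = batch.image  # numpy array with a leading batch dimension
        labels = batch.label  # list of str
        break

    # Chunked parquet export so the whole dataset never sits in memory at once.
    typed.to_parquet("export.parquet", maxcount=10_000)

Note that, per the Dataset and SampleBatch docstrings, the subscripted forms Dataset[...] and SampleBatch[...] are required for the runtime type-parameter extraction to work.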