atdata 0.1.3b3__py3-none-any.whl → 0.1.3b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atdata/__init__.py CHANGED
@@ -1,4 +1,39 @@
1
- """A loose federation of distributed, typed datasets"""
1
+ """A loose federation of distributed, typed datasets.
2
+
3
+ ``atdata`` provides a typed dataset abstraction built on WebDataset, with support
4
+ for:
5
+
6
+ - **Typed samples** with automatic msgpack serialization
7
+ - **NDArray handling** with transparent bytes conversion
8
+ - **Lens transformations** for viewing datasets through different type schemas
9
+ - **Batch aggregation** with automatic numpy array stacking
10
+ - **WebDataset integration** for efficient large-scale dataset storage
11
+
12
+ Quick Start:
13
+ >>> import atdata
14
+ >>> from numpy.typing import NDArray
15
+ >>>
16
+ >>> @atdata.packable
17
+ ... class MyData:
18
+ ... features: NDArray
19
+ ... label: str
20
+ >>>
21
+ >>> # Create dataset from WebDataset tar files
22
+ >>> ds = atdata.Dataset[MyData]("path/to/data-{000000..000009}.tar")
23
+ >>>
24
+ >>> # Iterate with automatic batching
25
+ >>> for batch in ds.shuffled(batch_size=32):
26
+ ... features = batch.features # numpy array (32, ...)
27
+ ... labels = batch.label # list of 32 strings
28
+
29
+ Main Components:
30
+ - ``PackableSample``: Base class for msgpack-serializable samples
31
+ - ``Dataset``: Typed dataset wrapper for WebDataset
32
+ - ``SampleBatch``: Automatic batch aggregation
33
+ - ``Lens``: Bidirectional type transformations
34
+ - ``@packable``: Decorator for creating PackableSample classes
35
+ - ``@lens``: Decorator for creating lens transformations
36
+ """
2
37
 
3
38
  ##
4
39
  # Expose components
atdata/_helpers.py CHANGED
@@ -1,4 +1,16 @@
1
- """Assorted helper methods for `atdata`"""
1
+ """Helper utilities for numpy array serialization.
2
+
3
+ This module provides utility functions for converting numpy arrays to and from
4
+ bytes for msgpack serialization. The functions use numpy's native save/load
5
+ format to preserve array dtype and shape information.
6
+
7
+ Functions:
8
+ - ``array_to_bytes()``: Serialize numpy array to bytes
9
+ - ``bytes_to_array()``: Deserialize bytes to numpy array
10
+
11
+ These helpers are used internally by ``PackableSample`` to enable transparent
12
+ handling of NDArray fields during msgpack packing/unpacking.
13
+ """
2
14
 
3
15
  ##
4
16
  # Imports
@@ -11,12 +23,36 @@ import numpy as np
11
23
  ##
12
24
 
13
25
  def array_to_bytes( x: np.ndarray ) -> bytes:
14
- """Convert `numpy` array to a format suitable for packing"""
26
+ """Convert a numpy array to bytes for msgpack serialization.
27
+
28
+ Uses numpy's native ``save()`` format to preserve array dtype and shape.
29
+
30
+ Args:
31
+ x: A numpy array to serialize.
32
+
33
+ Returns:
34
+ Raw bytes representing the serialized array.
35
+
36
+ Note:
37
+ Uses ``allow_pickle=True`` to support object dtypes.
38
+ """
15
39
  np_bytes = BytesIO()
16
40
  np.save( np_bytes, x, allow_pickle = True )
17
41
  return np_bytes.getvalue()
18
42
 
19
43
  def bytes_to_array( b: bytes ) -> np.ndarray:
20
- """Convert packed bytes back to a `numpy` array"""
44
+ """Convert serialized bytes back to a numpy array.
45
+
46
+ Reverses the serialization performed by ``array_to_bytes()``.
47
+
48
+ Args:
49
+ b: Raw bytes from a serialized numpy array.
50
+
51
+ Returns:
52
+ The deserialized numpy array with original dtype and shape.
53
+
54
+ Note:
55
+ Uses ``allow_pickle=True`` to support object dtypes. Never call this on untrusted bytes: unpickling can execute arbitrary code.
56
+ """
21
57
  np_bytes = BytesIO( b )
22
58
  return np.load( np_bytes, allow_pickle = True )
atdata/dataset.py CHANGED
@@ -1,4 +1,29 @@
1
- """Schematized WebDatasets"""
1
+ """Core dataset and sample infrastructure for typed WebDatasets.
2
+
3
+ This module provides the core components for working with typed, msgpack-serialized
4
+ samples in WebDataset format:
5
+
6
+ - ``PackableSample``: Base class for msgpack-serializable samples with automatic
7
+ NDArray handling
8
+ - ``SampleBatch``: Automatic batching with attribute aggregation
9
+ - ``Dataset``: Generic typed dataset wrapper for WebDataset tar files
10
+ - ``@packable``: Decorator to convert regular classes into PackableSample subclasses
11
+
12
+ The implementation handles automatic conversion between numpy arrays and bytes
13
+ during serialization, enabling efficient storage of numerical data in WebDataset
14
+ archives.
15
+
16
+ Example:
17
+ >>> @packable
18
+ ... class ImageSample:
19
+ ... image: NDArray
20
+ ... label: str
21
+ ...
22
+ >>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
23
+ >>> for batch in ds.shuffled(batch_size=32):
24
+ ... images = batch.image # Stacked numpy array (32, H, W, C)
25
+ ... labels = batch.label # List of 32 strings
26
+ """
2
27
 
3
28
  ##
4
29
  # Imports
@@ -107,6 +132,15 @@ MsgpackRawSample: TypeAlias = Dict[str, Any]
107
132
  # return eh.bytes_to_array( self.raw_bytes )
108
133
 
109
134
  def _make_packable( x ):
135
+ """Convert a value to a msgpack-compatible format.
136
+
137
+ Args:
138
+ x: A value to convert. If it's a numpy array, converts to bytes.
139
+ Otherwise returns the value unchanged.
140
+
141
+ Returns:
142
+ The value in a format suitable for msgpack serialization.
143
+ """
110
144
  # if isinstance( x, ArrayBytes ):
111
145
  # return x.raw_bytes
112
146
  if isinstance( x, np.ndarray ):
@@ -114,7 +148,15 @@ def _make_packable( x ):
114
148
  return x
115
149
 
116
150
  def _is_possibly_ndarray_type( t ):
117
- """Checks if a type annotation is possibly an NDArray."""
151
+ """Check if a type annotation is or contains NDArray.
152
+
153
+ Args:
154
+ t: A type annotation to check.
155
+
156
+ Returns:
157
+ ``True`` if the type is ``NDArray`` or a union containing ``NDArray``
158
+ (e.g., ``NDArray | None``), ``False`` otherwise.
159
+ """
118
160
 
119
161
  # Directly an NDArray
120
162
  if t == NDArray:
@@ -133,10 +175,40 @@ def _is_possibly_ndarray_type( t ):
133
175
 
134
176
  @dataclass
135
177
  class PackableSample( ABC ):
136
- """A sample that can be packed and unpacked with msgpack"""
178
+ """Base class for samples that can be serialized with msgpack.
179
+
180
+ This abstract base class provides automatic serialization/deserialization
181
+ for dataclass-based samples. Fields annotated as ``NDArray`` or
182
+ ``NDArray | None`` are automatically converted between numpy arrays and
183
+ bytes during packing/unpacking.
184
+
185
+ Subclasses should be defined either by:
186
+ 1. Direct inheritance with the ``@dataclass`` decorator
187
+ 2. Using the ``@packable`` decorator (recommended)
188
+
189
+ Example:
190
+ >>> @packable
191
+ ... class MyData:
192
+ ... name: str
193
+ ... embeddings: NDArray
194
+ ...
195
+ >>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
196
+ >>> packed = sample.packed # Serialize to bytes
197
+ >>> restored = MyData.from_bytes(packed) # Deserialize
198
+ """
137
199
 
138
200
  def _ensure_good( self ):
139
- """TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""
201
+ """Auto-convert annotated NDArray fields from bytes to numpy arrays.
202
+
203
+ This method scans all dataclass fields and for any field annotated as
204
+ ``NDArray`` or ``NDArray | None``, automatically converts bytes values
205
+ to numpy arrays using the helper deserialization function. This enables
206
+ transparent handling of array serialization in msgpack data.
207
+
208
+ Note:
209
+ This is called by ``from_data()`` (and during ``__post_init__``) to ensure proper type
210
+ conversion after deserialization.
211
+ """
140
212
 
141
213
  # Auto-convert known types when annotated
142
214
  # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
@@ -173,19 +245,45 @@ class PackableSample( ABC ):
173
245
 
174
246
  @classmethod
175
247
  def from_data( cls, data: MsgpackRawSample ) -> Self:
176
- """Create a sample instance from unpacked msgpack data"""
248
+ """Create a sample instance from unpacked msgpack data.
249
+
250
+ Args:
251
+ data: A dictionary of unpacked msgpack data with keys matching
252
+ the sample's field names.
253
+
254
+ Returns:
255
+ A new instance of this sample class with fields populated from
256
+ the data dictionary and NDArray fields auto-converted from bytes.
257
+ """
177
258
  ret = cls( **data )
178
259
  ret._ensure_good()
179
260
  return ret
180
261
 
181
262
  @classmethod
182
263
  def from_bytes( cls, bs: bytes ) -> Self:
183
- """Create a sample instance from raw msgpack bytes"""
264
+ """Create a sample instance from raw msgpack bytes.
265
+
266
+ Args:
267
+ bs: Raw bytes from a msgpack-serialized sample.
268
+
269
+ Returns:
270
+ A new instance of this sample class deserialized from the bytes.
271
+ """
184
272
  return cls.from_data( ormsgpack.unpackb( bs ) )
185
273
 
186
274
  @property
187
275
  def packed( self ) -> bytes:
188
- """Pack this sample's data into msgpack bytes"""
276
+ """Pack this sample's data into msgpack bytes.
277
+
278
+ NDArray fields are automatically converted to bytes before packing.
279
+ All other fields are packed as-is if they're msgpack-compatible.
280
+
281
+ Returns:
282
+ Raw msgpack bytes representing this sample's data.
283
+
284
+ Raises:
285
+ RuntimeError: If msgpack serialization fails.
286
+ """
189
287
 
190
288
  # Make sure that all of our (possibly unpackable) data is in a packable
191
289
  # format
@@ -204,7 +302,15 @@ class PackableSample( ABC ):
204
302
  # TODO Expand to allow for specifying explicit __key__
205
303
  @property
206
304
  def as_wds( self ) -> WDSRawSample:
207
- """Pack this sample's data for writing to webdataset"""
305
+ """Pack this sample's data for writing to WebDataset.
306
+
307
+ Returns:
308
+ A dictionary with ``__key__`` (UUID v1 for sortable keys) and
309
+ ``msgpack`` (packed sample data) fields suitable for WebDataset.
310
+
311
+ Note:
312
+ TODO: Expand to allow specifying explicit ``__key__`` values.
313
+ """
208
314
  return {
209
315
  # Generates a UUID that is timelike-sortable
210
316
  '__key__': str( uuid.uuid1( 0, 0 ) ),
@@ -212,30 +318,86 @@ class PackableSample( ABC ):
212
318
  }
213
319
 
214
320
  def _batch_aggregate( xs: Sequence ):
321
+ """Aggregate a sequence of values into a batch-appropriate format.
322
+
323
+ Args:
324
+ xs: A sequence of values to aggregate. If the first element is a numpy
325
+ array, all elements are stacked into a single array. Otherwise,
326
+ returns a list.
327
+
328
+ Returns:
329
+ A numpy array (if elements are arrays) or a list (otherwise).
330
+ """
215
331
 
216
332
  if not xs:
217
333
  # Empty sequence
218
334
  return []
219
335
 
220
- # Aggregate
336
+ # Aggregate
221
337
  if isinstance( xs[0], np.ndarray ):
222
338
  return np.array( list( xs ) )
223
339
 
224
340
  return list( xs )
225
341
 
226
342
  class SampleBatch( Generic[DT] ):
343
+ """A batch of samples with automatic attribute aggregation.
344
+
345
+ This class wraps a sequence of samples and provides magic ``__getattr__``
346
+ access to aggregate sample attributes. When you access an attribute that
347
+ exists on the sample type, it automatically aggregates values across all
348
+ samples in the batch.
349
+
350
+ NDArray fields are stacked into a numpy array with a batch dimension.
351
+ Other fields are aggregated into a list.
352
+
353
+ Type Parameters:
354
+ DT: The sample type, must derive from ``PackableSample``.
355
+
356
+ Attributes:
357
+ samples: The list of sample instances in this batch.
358
+
359
+ Example:
360
+ >>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
361
+ >>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
362
+ >>> batch.name # Returns list of names
363
+ """
227
364
 
228
365
  def __init__( self, samples: Sequence[DT] ):
229
- """TODO"""
366
+ """Create a batch from a sequence of samples.
367
+
368
+ Args:
369
+ samples: A sequence of sample instances to aggregate into a batch.
370
+ Each sample must be an instance of a type derived from
371
+ ``PackableSample``.
372
+ """
230
373
  self.samples = list( samples )
231
374
  self._aggregate_cache = dict()
232
375
 
233
376
  @property
234
377
  def sample_type( self ) -> Type:
235
- """The type of each sample in this batch"""
378
+ """The type of each sample in this batch.
379
+
380
+ Returns:
381
+ The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
382
+ """
236
383
  return typing.get_args( self.__orig_class__)[0]
237
384
 
238
385
  def __getattr__( self, name ):
386
+ """Aggregate an attribute across all samples in the batch.
387
+
388
+ This magic method enables attribute-style access to aggregated sample
389
+ fields. Results are cached for efficiency.
390
+
391
+ Args:
392
+ name: The attribute name to aggregate across samples.
393
+
394
+ Returns:
395
+ For NDArray fields: a stacked numpy array with batch dimension.
396
+ For other fields: a list of values from each sample.
397
+
398
+ Raises:
399
+ AttributeError: If the attribute doesn't exist on the sample type.
400
+ """
239
401
  # Aggregate named params of sample type
240
402
  if name in vars( self.sample_type )['__annotations__']:
241
403
  if name not in self._aggregate_cache:
@@ -243,9 +405,9 @@ class SampleBatch( Generic[DT] ):
243
405
  [ getattr( x, name )
244
406
  for x in self.samples ]
245
407
  )
246
-
408
+
247
409
  return self._aggregate_cache[name]
248
-
410
+
249
411
  raise AttributeError( f'No sample attribute named {name}' )
250
412
 
251
413
 
@@ -268,9 +430,32 @@ RT = TypeVar( 'RT', bound = PackableSample )
268
430
  # IT = TypeVar( 'IT', default = Any )
269
431
 
270
432
  class Dataset( Generic[ST] ):
271
- """A dataset that ingests and formats raw samples from a WebDataset
272
-
273
- (Abstract base for subclassing)
433
+ """A typed dataset built on WebDataset with lens transformations.
434
+
435
+ This class wraps WebDataset tar archives and provides type-safe iteration
436
+ over samples of a specific ``PackableSample`` type. Samples are stored as
437
+ msgpack-serialized data within WebDataset shards.
438
+
439
+ The dataset supports:
440
+ - Ordered and shuffled iteration
441
+ - Automatic batching with ``SampleBatch``
442
+ - Type transformations via the lens system (``as_type()``)
443
+ - Export to parquet format
444
+
445
+ Type Parameters:
446
+ ST: The sample type for this dataset, must derive from ``PackableSample``.
447
+
448
+ Attributes:
449
+ url: WebDataset brace-notation URL for the tar file(s).
450
+
451
+ Example:
452
+ >>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
453
+ >>> for sample in ds.ordered(batch_size=32):
454
+ ... # sample is SampleBatch[MyData] with batch_size samples
455
+ ... embeddings = sample.embeddings # shape: (32, ...)
456
+ ...
457
+ >>> # Transform to a different view
458
+ >>> ds_view = ds.as_type(MyDataView)
274
459
  """
275
460
 
276
461
  # sample_class: Type = get_parameters( )
@@ -280,12 +465,23 @@ class Dataset( Generic[ST] ):
280
465
 
281
466
  @property
282
467
  def sample_type( self ) -> Type:
283
- """The type of each returned sample from this `Dataset`'s iterator"""
284
- # TODO Figure out why linting fails here
468
+ """The type of each returned sample from this dataset's iterator.
469
+
470
+ Returns:
471
+ The type parameter ``ST`` used when creating this ``Dataset[ST]``.
472
+
473
+ Note:
474
+ Extracts the type parameter at runtime using ``__orig_class__``.
475
+ """
476
+ # NOTE: Linting may fail here due to __orig_class__ being a runtime attribute
285
477
  return typing.get_args( self.__orig_class__ )[0]
286
478
  @property
287
479
  def batch_type( self ) -> Type:
288
- """The type of a batch built from `sample_class`"""
480
+ """The type of batches produced by this dataset.
481
+
482
+ Returns:
483
+ ``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
484
+ """
289
485
  # return self.__orig_class__.__args__[1]
290
486
  return SampleBatch[self.sample_type]
291
487
 
@@ -296,7 +492,13 @@ class Dataset( Generic[ST] ):
296
492
  #
297
493
 
298
494
  def __init__( self, url: str ) -> None:
299
- """TODO"""
495
+ """Create a dataset from a WebDataset URL.
496
+
497
+ Args:
498
+ url: WebDataset brace-notation URL pointing to tar files, e.g.,
499
+ ``"path/to/file-{000000..000009}.tar"`` for multiple shards or
500
+ ``"path/to/file-000000.tar"`` for a single shard.
501
+ """
300
502
  super().__init__()
301
503
  self.url = url
302
504
 
@@ -304,7 +506,21 @@ class Dataset( Generic[ST] ):
304
506
  self._output_lens: Lens | None = None
305
507
 
306
508
  def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
307
- """TODO"""
509
+ """View this dataset through a different sample type using a registered lens.
510
+
511
+ Args:
512
+ other: The target sample type to transform into. Must be a type
513
+ derived from ``PackableSample``.
514
+
515
+ Returns:
516
+ A new ``Dataset`` instance that yields samples of type ``other``
517
+ by applying the appropriate lens transformation from the global
518
+ ``LensNetwork`` registry.
519
+
520
+ Raises:
521
+ ValueError: If no registered lens exists between the current
522
+ sample type and the target type.
523
+ """
308
524
  ret = Dataset[other]( self.url )
309
525
  # Get the singleton lens registry
310
526
  lenses = LensNetwork()
@@ -384,18 +600,23 @@ class Dataset( Generic[ST] ):
384
600
  buffer_samples: int = 10_000,
385
601
  batch_size: int | None = 1,
386
602
  ) -> Iterable[ST]:
387
- """Iterate over the dataset in random order
388
-
603
+ """Iterate over the dataset in random order.
604
+
389
605
  Args:
390
- buffer_shards (int): Asdf
391
- batch_size (:obj:`int`, optional) The size of iterated batches.
392
- Default: 1. If ``None``, iterates over one sample at a time
393
- with no batch dimension.
394
-
606
+ buffer_shards: Number of shards to buffer for shuffling at the
607
+ shard level. Larger values increase randomness but use more
608
+ memory. Default: 100.
609
+ buffer_samples: Number of samples to buffer for shuffling within
610
+ shards. Larger values increase randomness but use more memory.
611
+ Default: 10,000.
612
+ batch_size: The size of iterated batches. Default: 1. If ``None``,
613
+ iterates over one sample at a time with no batch dimension.
614
+
395
615
  Returns:
396
- :obj:`webdataset.DataPipeline` A data pipeline that iterates over
397
- the dataset in its original sample order
398
-
616
+ A WebDataset data pipeline that iterates over the dataset in
617
+ randomized order. If ``batch_size`` is not ``None``, yields
618
+ ``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
619
+ samples.
399
620
  """
400
621
 
401
622
  if batch_size is None:
@@ -500,7 +721,16 @@ class Dataset( Generic[ST] ):
500
721
  # @classmethod
501
722
  # TODO replace Any with IT
502
723
  def wrap( self, sample: MsgpackRawSample ) -> ST:
503
- """Wrap a `sample` into the appropriate dataset-specific type"""
724
+ """Wrap a raw msgpack sample into the appropriate dataset-specific type.
725
+
726
+ Args:
727
+ sample: A dictionary containing at minimum a ``'msgpack'`` key with
728
+ serialized sample bytes.
729
+
730
+ Returns:
731
+ A deserialized sample of type ``ST``, optionally transformed through
732
+ a lens if ``as_type()`` was called.
733
+ """
504
734
  assert 'msgpack' in sample
505
735
  assert type( sample['msgpack'] ) == bytes
506
736
 
@@ -524,9 +754,19 @@ class Dataset( Generic[ST] ):
524
754
  # )
525
755
 
526
756
  def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
527
- """Wrap a `batch` of samples into the appropriate dataset-specific type
528
-
529
- This default implementation simply creates a list one sample at a time
757
+ """Wrap a batch of raw msgpack samples into a typed SampleBatch.
758
+
759
+ Args:
760
+ batch: A dictionary containing a ``'msgpack'`` key with a list of
761
+ serialized sample bytes.
762
+
763
+ Returns:
764
+ A ``SampleBatch[ST]`` containing deserialized samples, optionally
765
+ transformed through a lens if ``as_type()`` was called.
766
+
767
+ Note:
768
+ This implementation deserializes samples one at a time, then
769
+ aggregates them into a batch.
530
770
  """
531
771
 
532
772
  assert 'msgpack' in batch
@@ -572,8 +812,30 @@ class Dataset( Generic[ST] ):
572
812
  # return decorator
573
813
 
574
814
  def packable( cls ):
575
- """TODO"""
576
-
815
+ """Decorator to convert a regular class into a ``PackableSample``.
816
+
817
+ This decorator transforms a class into a dataclass that inherits from
818
+ ``PackableSample``, enabling automatic msgpack serialization/deserialization
819
+ with special handling for NDArray fields.
820
+
821
+ Args:
822
+ cls: The class to convert. Should have type annotations for its fields.
823
+
824
+ Returns:
825
+ A new dataclass that inherits from ``PackableSample`` with the same
826
+ name and annotations as the original class.
827
+
828
+ Example:
829
+ >>> @packable
830
+ ... class MyData:
831
+ ... name: str
832
+ ... values: NDArray
833
+ ...
834
+ >>> sample = MyData(name="test", values=np.array([1, 2, 3]))
835
+ >>> bytes_data = sample.packed
836
+ >>> restored = MyData.from_bytes(bytes_data)
837
+ """
838
+
577
839
  ##
578
840
 
579
841
  class_name = cls.__name__
atdata/lens.py CHANGED
@@ -1,4 +1,42 @@
1
- """Lenses between typed datasets"""
1
+ """Lens-based type transformations for datasets.
2
+
3
+ This module implements a lens system for bidirectional transformations between
4
+ different sample types. Lenses enable viewing a dataset through different type
5
+ schemas without duplicating the underlying data.
6
+
7
+ Key components:
8
+
9
+ - ``Lens``: Bidirectional transformation with getter (S -> V) and optional
10
+ putter (V, S -> S)
11
+ - ``LensNetwork``: Global singleton registry for lens transformations
12
+ - ``@lens``: Decorator to create and register lens transformations
13
+
14
+ Lenses follow the functional-programming notion of getter/putter pairs; well-behaved
15
+ lenses should satisfy the lens laws (GetPut and PutGet), which this module does not enforce.
16
+
17
+ Example:
18
+ >>> @packable
19
+ ... class FullData:
20
+ ... name: str
21
+ ... age: int
22
+ ... embedding: NDArray
23
+ ...
24
+ >>> @packable
25
+ ... class NameOnly:
26
+ ... name: str
27
+ ...
28
+ >>> @lens
29
+ ... def name_view(full: FullData) -> NameOnly:
30
+ ... return NameOnly(name=full.name)
31
+ ...
32
+ >>> @name_view.putter
33
+ ... def name_view_put(view: NameOnly, source: FullData) -> FullData:
34
+ ... return FullData(name=view.name, age=source.age,
35
+ ... embedding=source.embedding)
36
+ ...
37
+ >>> ds = Dataset[FullData]("data.tar")
38
+ >>> ds_names = ds.as_type(NameOnly) # Uses registered lens
39
+ """
2
40
 
3
41
  ##
4
42
  # Imports
@@ -39,24 +77,45 @@ type LensPutter[S, V] = Callable[[V, S], S]
39
77
  # Shortcut decorators
40
78
 
41
79
  class Lens( Generic[S, V] ):
42
- """TODO"""
43
-
44
- # @property
45
- # def source_type( self ) -> Type[S]:
46
- # """The source type (S) for the lens; what is put to"""
47
- # # TODO Figure out why linting fails here
48
- # return self.__orig_class__.__args__[0]
49
-
50
- # @property
51
- # def view_type( self ) -> Type[V]:
52
- # """The view type (V) for the lens; what is get'd from"""
53
- # # TODO FIgure out why linting fails here
54
- # return self.__orig_class__.__args__[1]
80
+ """A bidirectional transformation between two sample types.
81
+
82
+ A lens provides a way to view and update data of type ``S`` (source) as if
83
+ it were type ``V`` (view). It consists of a getter that transforms ``S -> V``
84
+ and an optional putter that transforms ``(V, S) -> S``, enabling updates to
85
+ the view to be reflected back in the source.
86
+
87
+ Type Parameters:
88
+ S: The source type, must derive from ``PackableSample``.
89
+ V: The view type, must derive from ``PackableSample``.
90
+
91
+ Example:
92
+ >>> @lens
93
+ ... def name_lens(full: FullData) -> NameOnly:
94
+ ... return NameOnly(name=full.name)
95
+ ...
96
+ >>> @name_lens.putter
97
+ ... def name_lens_put(view: NameOnly, source: FullData) -> FullData:
98
+ ... return FullData(name=view.name, age=source.age)
99
+ """
55
100
 
56
101
  def __init__( self, get: LensGetter[S, V],
57
102
  put: Optional[LensPutter[S, V]] = None
58
103
  ) -> None:
59
- """TODO"""
104
+ """Initialize a lens with a getter and optional putter function.
105
+
106
+ Args:
107
+ get: A function that transforms from source type ``S`` to view type
108
+ ``V``. Must accept exactly one parameter annotated with the
109
+ source type.
110
+ put: An optional function that updates the source based on a modified
111
+ view. Takes a view of type ``V`` and original source of type ``S``,
112
+ and returns an updated source of type ``S``. If not provided, a
113
+ trivial putter is used that ignores updates to the view.
114
+
115
+ Raises:
116
+ AssertionError: If the getter function doesn't have exactly one
117
+ parameter.
118
+ """
60
119
  ##
61
120
 
62
121
  # Check argument validity
@@ -70,11 +129,11 @@ class Lens( Generic[S, V] ):
70
129
  functools.update_wrapper( self, get )
71
130
 
72
131
  self.source_type: Type[PackableSample] = input_types[0].annotation
73
- self.view_type = sig.return_annotation
132
+ self.view_type: Type[PackableSample] = sig.return_annotation
74
133
 
75
134
  # Store the getter
76
135
  self._getter = get
77
-
136
+
78
137
  # Determine and store the putter
79
138
  if put is None:
80
139
  # Trivial putter does not update the source
@@ -86,7 +145,20 @@ class Lens( Generic[S, V] ):
86
145
  #
87
146
 
88
147
  def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
89
- """TODO"""
148
+ """Decorator to register a putter function for this lens.
149
+
150
+ Args:
151
+ put: A function that takes a view of type ``V`` and source of type
152
+ ``S``, and returns an updated source of type ``S``.
153
+
154
+ Returns:
155
+ The putter function, allowing this to be used as a decorator.
156
+
157
+ Example:
158
+ >>> @my_lens.putter
159
+ ... def my_lens_put(view: ViewType, source: SourceType) -> SourceType:
160
+ ... return SourceType(...)
161
+ """
90
162
  ##
91
163
  self._putter = put
92
164
  return put
@@ -94,16 +166,39 @@ class Lens( Generic[S, V] ):
94
166
  # Methods to actually execute transformations
95
167
 
96
168
  def put( self, v: V, s: S ) -> S:
97
- """TODO"""
169
+ """Update the source based on a modified view.
170
+
171
+ Args:
172
+ v: The modified view of type ``V``.
173
+ s: The original source of type ``S``.
174
+
175
+ Returns:
176
+ An updated source of type ``S`` that reflects changes from the view.
177
+ """
98
178
  return self._putter( v, s )
99
179
 
100
180
  def get( self, s: S ) -> V:
101
- """TODO"""
181
+ """Transform the source into the view type.
182
+
183
+ Args:
184
+ s: The source sample of type ``S``.
185
+
186
+ Returns:
187
+ A view of the source as type ``V``.
188
+ """
102
189
  return self( s )
103
190
 
104
191
  # Convenience to enable calling the lens as its getter
105
-
192
+
106
193
  def __call__( self, s: S ) -> V:
194
+ """Apply the lens transformation (same as ``get()``).
195
+
196
+ Args:
197
+ s: The source sample of type ``S``.
198
+
199
+ Returns:
200
+ A view of the source as type ``V``.
201
+ """
107
202
  return self._getter( s )
108
203
 
109
204
  # TODO Figure out how to properly parameterize this
@@ -124,6 +219,28 @@ class Lens( Generic[S, V] ):
124
219
  # lens = _lens_factory
125
220
 
126
221
  def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
222
+ """Decorator to create and register a lens transformation.
223
+
224
+ This decorator converts a getter function into a ``Lens`` object and
225
+ automatically registers it in the global ``LensNetwork`` registry.
226
+
227
+ Args:
228
+ f: A getter function that transforms from source type ``S`` to view
229
+ type ``V``. Must have exactly one parameter with a type annotation.
230
+
231
+ Returns:
232
+ A ``Lens[S, V]`` object that can be called to apply the transformation
233
+ or decorated with ``@lens_name.putter`` to add a putter function.
234
+
235
+ Example:
236
+ >>> @lens
237
+ ... def extract_name(full: FullData) -> NameOnly:
238
+ ... return NameOnly(name=full.name)
239
+ ...
240
+ >>> @extract_name.putter
241
+ ... def extract_name_put(view: NameOnly, source: FullData) -> FullData:
242
+ ... return FullData(name=view.name, age=source.age)
243
+ """
127
244
  ret = Lens[S, V]( f )
128
245
  _network.register( ret )
129
246
  return ret
@@ -136,25 +253,46 @@ def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
136
253
  # """TODO"""
137
254
 
138
255
  class LensNetwork:
139
- """TODO"""
256
+ """Global registry for lens transformations between sample types.
257
+
258
+ This class implements a singleton pattern to maintain a global registry of
259
+ all lenses decorated with ``@lens``. It enables looking up transformations
260
+ between different ``PackableSample`` types.
261
+
262
+ Attributes:
263
+ _instance: The singleton instance of this class.
264
+ _registry: Dictionary mapping ``(source_type, view_type)`` tuples to
265
+ their corresponding ``Lens`` objects.
266
+ """
140
267
 
141
268
  _instance = None
142
269
  """The singleton instance"""
143
270
 
144
271
  def __new__(cls, *args, **kwargs):
272
+ """Ensure only one instance of LensNetwork exists (singleton pattern)."""
145
273
  if cls._instance is None:
146
274
  # If no instance exists, create a new one
147
275
  cls._instance = super().__new__(cls)
148
276
  return cls._instance # Return the existing (or newly created) instance
149
277
 
150
278
  def __init__(self):
279
+ """Initialize the lens registry (only on first instantiation)."""
151
280
  if not hasattr(self, '_initialized'): # Check if already initialized
152
281
  self._registry: Dict[LensSignature, Lens] = dict()
153
282
  self._initialized = True
154
283
 
155
284
  def register( self, _lens: Lens ):
156
- """Set `lens` as the canonical view between its source and view types"""
157
-
285
+ """Register a lens as the canonical transformation between two types.
286
+
287
+ Args:
288
+ _lens: The lens to register. Will be stored in the registry under
289
+ the key ``(_lens.source_type, _lens.view_type)``.
290
+
291
+ Note:
292
+ If a lens already exists for the same type pair, it will be
293
+ overwritten.
294
+ """
295
+
158
296
  # sig = inspect.signature( _lens.get )
159
297
  # input_types = list( sig.parameters.values() )
160
298
  # assert len( input_types ) == 1, \
@@ -169,13 +307,28 @@ class LensNetwork:
169
307
  self._registry[_lens.source_type, _lens.view_type] = _lens
170
308
 
171
309
  def transform( self, source: DatasetType, view: DatasetType ) -> Lens:
172
- """TODO"""
310
+ """Look up the lens transformation between two sample types.
311
+
312
+ Args:
313
+ source: The source sample type (must derive from ``PackableSample``).
314
+ view: The target view type (must derive from ``PackableSample``).
315
+
316
+ Returns:
317
+ The registered ``Lens`` that transforms from ``source`` to ``view``.
318
+
319
+ Raises:
320
+ ValueError: If no lens has been registered for the given type pair.
321
+
322
+ Note:
323
+ Currently only supports direct transformations. Compositional
324
+ transformations (chaining multiple lenses) are not yet implemented.
325
+ """
173
326
 
174
327
  # TODO Handle compositional closure
175
328
  ret = self._registry.get( (source, view), None )
176
329
  if ret is None:
177
330
  raise ValueError( f'No registered lens from source {source} to view {view}' )
178
-
331
+
179
332
  return ret
180
333
 
181
334
 
@@ -0,0 +1,172 @@
1
+ Metadata-Version: 2.4
2
+ Name: atdata
3
+ Version: 0.1.3b4
4
+ Summary: A loose federation of distributed, typed datasets
5
+ Author-email: Maxine Levesque <hello@maxine.science>
6
+ License-File: LICENSE
7
+ Requires-Python: >=3.12
8
+ Requires-Dist: fastparquet>=2024.11.0
9
+ Requires-Dist: msgpack>=1.1.2
10
+ Requires-Dist: numpy>=2.3.4
11
+ Requires-Dist: ormsgpack>=1.11.0
12
+ Requires-Dist: pandas>=2.3.3
13
+ Requires-Dist: tqdm>=4.67.1
14
+ Requires-Dist: webdataset>=1.0.2
15
+ Description-Content-Type: text/markdown
16
+
17
+ # atdata
18
+
19
+ [![codecov](https://codecov.io/gh/foundation-ac/atdata/branch/main/graph/badge.svg)](https://codecov.io/gh/foundation-ac/atdata)
20
+
21
+ A loose federation of distributed, typed datasets built on WebDataset.
22
+
23
+ **atdata** provides a type-safe, composable framework for working with large-scale datasets. It combines the efficiency of WebDataset's tar-based storage with Python's type system and functional programming patterns.
24
+
25
+ ## Features
26
+
27
+ - **Typed Samples** - Define dataset schemas using Python dataclasses with automatic msgpack serialization
28
+ - **Lens Transformations** - Bidirectional transformations between different dataset views (chaining multiple lenses is not yet supported)
29
+ - **Automatic Batching** - Smart batch aggregation with numpy array stacking
30
+ - **WebDataset Integration** - Efficient storage and streaming for large-scale datasets
31
+
32
+ ## Installation
33
+
34
+ ```bash
35
+ pip install atdata
36
+ ```
37
+
38
+ Requires Python 3.12 or later.
39
+
40
+ ## Quick Start
41
+
42
+ ### Defining Sample Types
43
+
44
+ Use the `@packable` decorator to create typed dataset samples:
45
+
46
+ ```python
47
+ import atdata
48
+ from numpy.typing import NDArray
49
+
50
+ @atdata.packable
51
+ class ImageSample:
52
+ image: NDArray
53
+ label: str
54
+ metadata: dict
55
+ ```
56
+
57
+ ### Creating Datasets
58
+
59
+ ```python
60
+ # Create a dataset
61
+ dataset = atdata.Dataset[ImageSample]("path/to/data-{000000..000009}.tar")
62
+
63
+ # Iterate over samples in order
64
+ for sample in dataset.ordered(batch_size=None):
65
+ print(f"Label: {sample.label}, Image shape: {sample.image.shape}")
66
+
67
+ # Iterate with shuffling and batching
68
+ for batch in dataset.shuffled(batch_size=32):
69
+ # batch.image is automatically stacked into shape (32, ...)
70
+ # batch.label is a list of 32 labels
71
+ process_batch(batch.image, batch.label)
72
+ ```
73
+
74
+ ### Lens Transformations
75
+
76
+ Define reusable transformations between sample types:
77
+
78
+ ```python
79
+ @atdata.packable
80
+ class ProcessedSample:
81
+ features: NDArray
82
+ label: str
83
+
84
+ @atdata.lens
85
+ def preprocess(sample: ImageSample) -> ProcessedSample:
86
+ features = extract_features(sample.image)
87
+ return ProcessedSample(features=features, label=sample.label)
88
+
89
+ # Apply lens to view dataset as ProcessedSample
90
+ processed_ds = dataset.as_type(ProcessedSample)
91
+
92
+ for sample in processed_ds.ordered(batch_size=None):
93
+ # sample is now a ProcessedSample
94
+ print(sample.features.shape)
95
+ ```
96
+
97
+ ## Core Concepts
98
+
99
+ ### PackableSample
100
+
101
+ Base class for msgpack-serializable samples. Fields annotated as `NDArray` are automatically converted to and from bytes during serialization:
102
+
103
+ ```python
104
+ @atdata.packable
105
+ class MySample:
106
+ array_field: NDArray # Automatically serialized
107
+ optional_array: NDArray | None
108
+ regular_field: str
109
+ ```
110
+
111
+ ### Lens
112
+
113
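+ Bidirectional transformations with getter/putter semantics. The decorated function is the getter (source → view); the function registered with `@my_lens.putter` maps a view back onto its source: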
114
+
115
+ ```python
116
+ @atdata.lens
117
+ def my_lens(source: SourceType) -> ViewType:
118
+ # Transform source -> view
119
+ return ViewType(...)
120
+
121
+ @my_lens.putter
122
+ def my_lens_put(view: ViewType, source: SourceType) -> SourceType:
123
+ # Transform view -> source
124
+ return SourceType(...)
125
+ ```
126
+
127
+ ### Dataset URLs
128
+
129
+ Dataset URLs use WebDataset brace expansion to address sharded datasets:
130
+
131
+ - Single file: `"data/dataset-000000.tar"`
132
+ - Multiple shards: `"data/dataset-{000000..000099}.tar"`
133
+ - Multiple patterns: `"data/{train,val}/dataset-{000000..000009}.tar"`
134
+
135
+ ## Development
136
+
137
+ ### Setup
138
+
139
+ ```bash
140
+ # Install uv if not already available
141
+ python -m pip install uv
142
+
143
+ # Install dependencies
144
+ uv sync
145
+ ```
146
+
147
+ ### Testing
148
+
149
+ ```bash
150
+ # Run all tests with coverage
151
+ pytest
152
+
153
+ # Run specific test file
154
+ pytest tests/test_dataset.py
155
+
156
+ # Run single test
157
+ pytest tests/test_lens.py::test_lens
158
+ ```
159
+
160
+ ### Building
161
+
162
+ ```bash
163
+ uv build
164
+ ```
165
+
166
+ ## Contributing
167
+
168
+ Contributions are welcome! This project is in beta, so the API may still evolve.
169
+
170
+ ## License
171
+
172
+ This project is licensed under the Mozilla Public License 2.0. See [LICENSE](LICENSE) for details.
@@ -0,0 +1,9 @@
1
+ atdata/__init__.py,sha256=_363ZuJfwbBQTMYsoKOiyoBe4AHr3iplK-EQyrAeTdg,1545
2
+ atdata/_helpers.py,sha256=RvA-Xlj3AvgSWuiPdS8YTBp8AJT-u32BaLpxsu4PIIA,1564
3
+ atdata/dataset.py,sha256=O_7b3ub_M4IMRuhv95oz1PVFdsOhNiyXgtY8NphPdBk,27842
4
+ atdata/lens.py,sha256=ynn1DQkR89eRL6JV9EsawuPY9JTrZ67pAX4cRvZ6UVk,11157
5
+ atdata-0.1.3b4.dist-info/METADATA,sha256=SdZSI_SonE-pt4nhmFh5bz9zKD79wT2CKXKFxrTfvgc,4162
6
+ atdata-0.1.3b4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ atdata-0.1.3b4.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
8
+ atdata-0.1.3b4.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
9
+ atdata-0.1.3b4.dist-info/RECORD,,
@@ -1,18 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: atdata
3
- Version: 0.1.3b3
4
- Summary: A loose federation of distributed, typed datasets
5
- Author-email: Maxine Levesque <hello@maxine.science>
6
- License-File: LICENSE
7
- Requires-Python: >=3.12
8
- Requires-Dist: fastparquet>=2024.11.0
9
- Requires-Dist: msgpack>=1.1.2
10
- Requires-Dist: numpy>=2.3.4
11
- Requires-Dist: ormsgpack>=1.11.0
12
- Requires-Dist: pandas>=2.3.3
13
- Requires-Dist: tqdm>=4.67.1
14
- Requires-Dist: webdataset>=1.0.2
15
- Description-Content-Type: text/markdown
16
-
17
- # atdata
18
- A loose federation of distributed, typed datasets
@@ -1,9 +0,0 @@
1
- atdata/__init__.py,sha256=V2qBg7i2mfCNG9nww6Gi_fDp7iwolDMrNzhmNO6VA7M,233
2
- atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
3
- atdata/dataset.py,sha256=qyAiKSjjYqFVWmaLz5LAIZ3_YVHbm5lg32zmctqjjlE,18085
4
- atdata/lens.py,sha256=HvXuRqYTeJBpMyIQVdGZXxEvbGKBuFCF8lbiib4TqsA,5306
5
- atdata-0.1.3b3.dist-info/METADATA,sha256=jrGZ592QbkJdZCq8FLmXOznQ0LkTUyUkqLVIH3ZRj4U,529
6
- atdata-0.1.3b3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- atdata-0.1.3b3.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
8
- atdata-0.1.3b3.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
9
- atdata-0.1.3b3.dist-info/RECORD,,