atdata 0.1.3b3__py3-none-any.whl → 0.1.3b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +36 -1
- atdata/_helpers.py +39 -3
- atdata/dataset.py +299 -37
- atdata/lens.py +179 -26
- atdata-0.1.3b4.dist-info/METADATA +172 -0
- atdata-0.1.3b4.dist-info/RECORD +9 -0
- atdata-0.1.3b3.dist-info/METADATA +0 -18
- atdata-0.1.3b3.dist-info/RECORD +0 -9
- {atdata-0.1.3b3.dist-info → atdata-0.1.3b4.dist-info}/WHEEL +0 -0
- {atdata-0.1.3b3.dist-info → atdata-0.1.3b4.dist-info}/entry_points.txt +0 -0
- {atdata-0.1.3b3.dist-info → atdata-0.1.3b4.dist-info}/licenses/LICENSE +0 -0
atdata/__init__.py
CHANGED
|
@@ -1,4 +1,39 @@
|
|
|
1
|
-
"""A loose federation of distributed, typed datasets
|
|
1
|
+
"""A loose federation of distributed, typed datasets.
|
|
2
|
+
|
|
3
|
+
``atdata`` provides a typed dataset abstraction built on WebDataset, with support
|
|
4
|
+
for:
|
|
5
|
+
|
|
6
|
+
- **Typed samples** with automatic msgpack serialization
|
|
7
|
+
- **NDArray handling** with transparent bytes conversion
|
|
8
|
+
- **Lens transformations** for viewing datasets through different type schemas
|
|
9
|
+
- **Batch aggregation** with automatic numpy array stacking
|
|
10
|
+
- **WebDataset integration** for efficient large-scale dataset storage
|
|
11
|
+
|
|
12
|
+
Quick Start:
|
|
13
|
+
>>> import atdata
|
|
14
|
+
>>> import numpy as np
|
|
15
|
+
>>>
|
|
16
|
+
>>> @atdata.packable
|
|
17
|
+
... class MyData:
|
|
18
|
+
... features: np.ndarray
|
|
19
|
+
... label: str
|
|
20
|
+
>>>
|
|
21
|
+
>>> # Create dataset from WebDataset tar files
|
|
22
|
+
>>> ds = atdata.Dataset[MyData]("path/to/data-{000000..000009}.tar")
|
|
23
|
+
>>>
|
|
24
|
+
>>> # Iterate with automatic batching
|
|
25
|
+
>>> for batch in ds.shuffled(batch_size=32):
|
|
26
|
+
... features = batch.features # numpy array (32, ...)
|
|
27
|
+
... labels = batch.label # list of 32 strings
|
|
28
|
+
|
|
29
|
+
Main Components:
|
|
30
|
+
- ``PackableSample``: Base class for msgpack-serializable samples
|
|
31
|
+
- ``Dataset``: Typed dataset wrapper for WebDataset
|
|
32
|
+
- ``SampleBatch``: Automatic batch aggregation
|
|
33
|
+
- ``Lens``: Bidirectional type transformations
|
|
34
|
+
- ``@packable``: Decorator for creating PackableSample classes
|
|
35
|
+
- ``@lens``: Decorator for creating lens transformations
|
|
36
|
+
"""
|
|
2
37
|
|
|
3
38
|
##
|
|
4
39
|
# Expose components
|
atdata/_helpers.py
CHANGED
|
@@ -1,4 +1,16 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Helper utilities for numpy array serialization.
|
|
2
|
+
|
|
3
|
+
This module provides utility functions for converting numpy arrays to and from
|
|
4
|
+
bytes for msgpack serialization. The functions use numpy's native save/load
|
|
5
|
+
format to preserve array dtype and shape information.
|
|
6
|
+
|
|
7
|
+
Functions:
|
|
8
|
+
- ``array_to_bytes()``: Serialize numpy array to bytes
|
|
9
|
+
- ``bytes_to_array()``: Deserialize bytes to numpy array
|
|
10
|
+
|
|
11
|
+
These helpers are used internally by ``PackableSample`` to enable transparent
|
|
12
|
+
handling of NDArray fields during msgpack packing/unpacking.
|
|
13
|
+
"""
|
|
2
14
|
|
|
3
15
|
##
|
|
4
16
|
# Imports
|
|
@@ -11,12 +23,36 @@ import numpy as np
|
|
|
11
23
|
##
|
|
12
24
|
|
|
13
25
|
def array_to_bytes( x: np.ndarray ) -> bytes:
|
|
14
|
-
"""Convert
|
|
26
|
+
"""Convert a numpy array to bytes for msgpack serialization.
|
|
27
|
+
|
|
28
|
+
Uses numpy's native ``save()`` format to preserve array dtype and shape.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
x: A numpy array to serialize.
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
Raw bytes representing the serialized array.
|
|
35
|
+
|
|
36
|
+
Note:
|
|
37
|
+
Uses ``allow_pickle=True`` to support object dtypes.
|
|
38
|
+
"""
|
|
15
39
|
np_bytes = BytesIO()
|
|
16
40
|
np.save( np_bytes, x, allow_pickle = True )
|
|
17
41
|
return np_bytes.getvalue()
|
|
18
42
|
|
|
19
43
|
def bytes_to_array( b: bytes ) -> np.ndarray:
|
|
20
|
-
"""Convert
|
|
44
|
+
"""Convert serialized bytes back to a numpy array.
|
|
45
|
+
|
|
46
|
+
Reverses the serialization performed by ``array_to_bytes()``.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
b: Raw bytes from a serialized numpy array.
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
The deserialized numpy array with original dtype and shape.
|
|
53
|
+
|
|
54
|
+
Note:
|
|
55
|
+
Uses ``allow_pickle=True`` to support object dtypes. Because unpickling can execute arbitrary code, only pass bytes from trusted sources.
|
|
56
|
+
"""
|
|
21
57
|
np_bytes = BytesIO( b )
|
|
22
58
|
return np.load( np_bytes, allow_pickle = True )
|
atdata/dataset.py
CHANGED
|
@@ -1,4 +1,29 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Core dataset and sample infrastructure for typed WebDatasets.
|
|
2
|
+
|
|
3
|
+
This module provides the core components for working with typed, msgpack-serialized
|
|
4
|
+
samples in WebDataset format:
|
|
5
|
+
|
|
6
|
+
- ``PackableSample``: Base class for msgpack-serializable samples with automatic
|
|
7
|
+
NDArray handling
|
|
8
|
+
- ``SampleBatch``: Automatic batching with attribute aggregation
|
|
9
|
+
- ``Dataset``: Generic typed dataset wrapper for WebDataset tar files
|
|
10
|
+
- ``@packable``: Decorator to convert regular classes into PackableSample subclasses
|
|
11
|
+
|
|
12
|
+
The implementation handles automatic conversion between numpy arrays and bytes
|
|
13
|
+
during serialization, enabling efficient storage of numerical data in WebDataset
|
|
14
|
+
archives.
|
|
15
|
+
|
|
16
|
+
Example:
|
|
17
|
+
>>> @packable
|
|
18
|
+
... class ImageSample:
|
|
19
|
+
... image: NDArray
|
|
20
|
+
... label: str
|
|
21
|
+
...
|
|
22
|
+
>>> ds = Dataset[ImageSample]("data-{000000..000009}.tar")
|
|
23
|
+
>>> for batch in ds.shuffled(batch_size=32):
|
|
24
|
+
... images = batch.image # Stacked numpy array (32, H, W, C)
|
|
25
|
+
... labels = batch.label # List of 32 strings
|
|
26
|
+
"""
|
|
2
27
|
|
|
3
28
|
##
|
|
4
29
|
# Imports
|
|
@@ -107,6 +132,15 @@ MsgpackRawSample: TypeAlias = Dict[str, Any]
|
|
|
107
132
|
# return eh.bytes_to_array( self.raw_bytes )
|
|
108
133
|
|
|
109
134
|
def _make_packable( x ):
|
|
135
|
+
"""Convert a value to a msgpack-compatible format.
|
|
136
|
+
|
|
137
|
+
Args:
|
|
138
|
+
x: A value to convert. If it's a numpy array, converts to bytes.
|
|
139
|
+
Otherwise returns the value unchanged.
|
|
140
|
+
|
|
141
|
+
Returns:
|
|
142
|
+
The value in a format suitable for msgpack serialization.
|
|
143
|
+
"""
|
|
110
144
|
# if isinstance( x, ArrayBytes ):
|
|
111
145
|
# return x.raw_bytes
|
|
112
146
|
if isinstance( x, np.ndarray ):
|
|
@@ -114,7 +148,15 @@ def _make_packable( x ):
|
|
|
114
148
|
return x
|
|
115
149
|
|
|
116
150
|
def _is_possibly_ndarray_type( t ):
|
|
117
|
-
"""
|
|
151
|
+
"""Check if a type annotation is or contains NDArray.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
t: A type annotation to check.
|
|
155
|
+
|
|
156
|
+
Returns:
|
|
157
|
+
``True`` if the type is ``NDArray`` or a union containing ``NDArray``
|
|
158
|
+
(e.g., ``NDArray | None``), ``False`` otherwise.
|
|
159
|
+
"""
|
|
118
160
|
|
|
119
161
|
# Directly an NDArray
|
|
120
162
|
if t == NDArray:
|
|
@@ -133,10 +175,40 @@ def _is_possibly_ndarray_type( t ):
|
|
|
133
175
|
|
|
134
176
|
@dataclass
|
|
135
177
|
class PackableSample( ABC ):
|
|
136
|
-
"""
|
|
178
|
+
"""Base class for samples that can be serialized with msgpack.
|
|
179
|
+
|
|
180
|
+
This abstract base class provides automatic serialization/deserialization
|
|
181
|
+
for dataclass-based samples. Fields annotated as ``NDArray`` or
|
|
182
|
+
``NDArray | None`` are automatically converted between numpy arrays and
|
|
183
|
+
bytes during packing/unpacking.
|
|
184
|
+
|
|
185
|
+
Subclasses should be defined either by:
|
|
186
|
+
1. Direct inheritance with the ``@dataclass`` decorator
|
|
187
|
+
2. Using the ``@packable`` decorator (recommended)
|
|
188
|
+
|
|
189
|
+
Example:
|
|
190
|
+
>>> @packable
|
|
191
|
+
... class MyData:
|
|
192
|
+
... name: str
|
|
193
|
+
... embeddings: NDArray
|
|
194
|
+
...
|
|
195
|
+
>>> sample = MyData(name="test", embeddings=np.array([1.0, 2.0]))
|
|
196
|
+
>>> packed = sample.packed # Serialize to bytes
|
|
197
|
+
>>> restored = MyData.from_bytes(packed) # Deserialize
|
|
198
|
+
"""
|
|
137
199
|
|
|
138
200
|
def _ensure_good( self ):
|
|
139
|
-
"""
|
|
201
|
+
"""Auto-convert annotated NDArray fields from bytes to numpy arrays.
|
|
202
|
+
|
|
203
|
+
This method scans all dataclass fields and for any field annotated as
|
|
204
|
+
``NDArray`` or ``NDArray | None``, automatically converts bytes values
|
|
205
|
+
to numpy arrays using the helper deserialization function. This enables
|
|
206
|
+
transparent handling of array serialization in msgpack data.
|
|
207
|
+
|
|
208
|
+
Note:
|
|
209
|
+
This is called during ``__post_init__`` to ensure proper type
|
|
210
|
+
conversion after deserialization.
|
|
211
|
+
"""
|
|
140
212
|
|
|
141
213
|
# Auto-convert known types when annotated
|
|
142
214
|
# for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
|
|
@@ -173,19 +245,45 @@ class PackableSample( ABC ):
|
|
|
173
245
|
|
|
174
246
|
@classmethod
|
|
175
247
|
def from_data( cls, data: MsgpackRawSample ) -> Self:
|
|
176
|
-
"""Create a sample instance from unpacked msgpack data
|
|
248
|
+
"""Create a sample instance from unpacked msgpack data.
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
data: A dictionary of unpacked msgpack data with keys matching
|
|
252
|
+
the sample's field names.
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
A new instance of this sample class with fields populated from
|
|
256
|
+
the data dictionary and NDArray fields auto-converted from bytes.
|
|
257
|
+
"""
|
|
177
258
|
ret = cls( **data )
|
|
178
259
|
ret._ensure_good()
|
|
179
260
|
return ret
|
|
180
261
|
|
|
181
262
|
@classmethod
|
|
182
263
|
def from_bytes( cls, bs: bytes ) -> Self:
|
|
183
|
-
"""Create a sample instance from raw msgpack bytes
|
|
264
|
+
"""Create a sample instance from raw msgpack bytes.
|
|
265
|
+
|
|
266
|
+
Args:
|
|
267
|
+
bs: Raw bytes from a msgpack-serialized sample.
|
|
268
|
+
|
|
269
|
+
Returns:
|
|
270
|
+
A new instance of this sample class deserialized from the bytes.
|
|
271
|
+
"""
|
|
184
272
|
return cls.from_data( ormsgpack.unpackb( bs ) )
|
|
185
273
|
|
|
186
274
|
@property
|
|
187
275
|
def packed( self ) -> bytes:
|
|
188
|
-
"""Pack this sample's data into msgpack bytes
|
|
276
|
+
"""Pack this sample's data into msgpack bytes.
|
|
277
|
+
|
|
278
|
+
NDArray fields are automatically converted to bytes before packing.
|
|
279
|
+
All other fields are packed as-is if they're msgpack-compatible.
|
|
280
|
+
|
|
281
|
+
Returns:
|
|
282
|
+
Raw msgpack bytes representing this sample's data.
|
|
283
|
+
|
|
284
|
+
Raises:
|
|
285
|
+
RuntimeError: If msgpack serialization fails.
|
|
286
|
+
"""
|
|
189
287
|
|
|
190
288
|
# Make sure that all of our (possibly unpackable) data is in a packable
|
|
191
289
|
# format
|
|
@@ -204,7 +302,15 @@ class PackableSample( ABC ):
|
|
|
204
302
|
# TODO Expand to allow for specifying explicit __key__
|
|
205
303
|
@property
|
|
206
304
|
def as_wds( self ) -> WDSRawSample:
|
|
207
|
-
"""Pack this sample's data for writing to
|
|
305
|
+
"""Pack this sample's data for writing to WebDataset.
|
|
306
|
+
|
|
307
|
+
Returns:
|
|
308
|
+
A dictionary with ``__key__`` (UUID v1 for sortable keys) and
|
|
309
|
+
``msgpack`` (packed sample data) fields suitable for WebDataset.
|
|
310
|
+
|
|
311
|
+
Note:
|
|
312
|
+
TODO: Expand to allow specifying explicit ``__key__`` values.
|
|
313
|
+
"""
|
|
208
314
|
return {
|
|
209
315
|
# Generates a UUID that is timelike-sortable
|
|
210
316
|
'__key__': str( uuid.uuid1( 0, 0 ) ),
|
|
@@ -212,30 +318,86 @@ class PackableSample( ABC ):
|
|
|
212
318
|
}
|
|
213
319
|
|
|
214
320
|
def _batch_aggregate( xs: Sequence ):
|
|
321
|
+
"""Aggregate a sequence of values into a batch-appropriate format.
|
|
322
|
+
|
|
323
|
+
Args:
|
|
324
|
+
xs: A sequence of values to aggregate. If the first element is a numpy
|
|
325
|
+
array, all elements are stacked into a single array. Otherwise,
|
|
326
|
+
returns a list.
|
|
327
|
+
|
|
328
|
+
Returns:
|
|
329
|
+
A numpy array (if elements are arrays) or a list (otherwise).
|
|
330
|
+
"""
|
|
215
331
|
|
|
216
332
|
if not xs:
|
|
217
333
|
# Empty sequence
|
|
218
334
|
return []
|
|
219
335
|
|
|
220
|
-
# Aggregate
|
|
336
|
+
# Aggregate
|
|
221
337
|
if isinstance( xs[0], np.ndarray ):
|
|
222
338
|
return np.array( list( xs ) )
|
|
223
339
|
|
|
224
340
|
return list( xs )
|
|
225
341
|
|
|
226
342
|
class SampleBatch( Generic[DT] ):
|
|
343
|
+
"""A batch of samples with automatic attribute aggregation.
|
|
344
|
+
|
|
345
|
+
This class wraps a sequence of samples and provides magic ``__getattr__``
|
|
346
|
+
access to aggregate sample attributes. When you access an attribute that
|
|
347
|
+
exists on the sample type, it automatically aggregates values across all
|
|
348
|
+
samples in the batch.
|
|
349
|
+
|
|
350
|
+
NDArray fields are stacked into a numpy array with a batch dimension.
|
|
351
|
+
Other fields are aggregated into a list.
|
|
352
|
+
|
|
353
|
+
Type Parameters:
|
|
354
|
+
DT: The sample type, must derive from ``PackableSample``.
|
|
355
|
+
|
|
356
|
+
Attributes:
|
|
357
|
+
samples: The list of sample instances in this batch.
|
|
358
|
+
|
|
359
|
+
Example:
|
|
360
|
+
>>> batch = SampleBatch[MyData]([sample1, sample2, sample3])
|
|
361
|
+
>>> batch.embeddings # Returns stacked numpy array of shape (3, ...)
|
|
362
|
+
>>> batch.name # Returns list of names
|
|
363
|
+
"""
|
|
227
364
|
|
|
228
365
|
def __init__( self, samples: Sequence[DT] ):
|
|
229
|
-
"""
|
|
366
|
+
"""Create a batch from a sequence of samples.
|
|
367
|
+
|
|
368
|
+
Args:
|
|
369
|
+
samples: A sequence of sample instances to aggregate into a batch.
|
|
370
|
+
Each sample must be an instance of a type derived from
|
|
371
|
+
``PackableSample``.
|
|
372
|
+
"""
|
|
230
373
|
self.samples = list( samples )
|
|
231
374
|
self._aggregate_cache = dict()
|
|
232
375
|
|
|
233
376
|
@property
|
|
234
377
|
def sample_type( self ) -> Type:
|
|
235
|
-
"""The type of each sample in this batch
|
|
378
|
+
"""The type of each sample in this batch.
|
|
379
|
+
|
|
380
|
+
Returns:
|
|
381
|
+
The type parameter ``DT`` used when creating this ``SampleBatch[DT]``.
|
|
382
|
+
"""
|
|
236
383
|
return typing.get_args( self.__orig_class__)[0]
|
|
237
384
|
|
|
238
385
|
def __getattr__( self, name ):
|
|
386
|
+
"""Aggregate an attribute across all samples in the batch.
|
|
387
|
+
|
|
388
|
+
This magic method enables attribute-style access to aggregated sample
|
|
389
|
+
fields. Results are cached for efficiency.
|
|
390
|
+
|
|
391
|
+
Args:
|
|
392
|
+
name: The attribute name to aggregate across samples.
|
|
393
|
+
|
|
394
|
+
Returns:
|
|
395
|
+
For NDArray fields: a stacked numpy array with batch dimension.
|
|
396
|
+
For other fields: a list of values from each sample.
|
|
397
|
+
|
|
398
|
+
Raises:
|
|
399
|
+
AttributeError: If the attribute doesn't exist on the sample type.
|
|
400
|
+
"""
|
|
239
401
|
# Aggregate named params of sample type
|
|
240
402
|
if name in vars( self.sample_type )['__annotations__']:
|
|
241
403
|
if name not in self._aggregate_cache:
|
|
@@ -243,9 +405,9 @@ class SampleBatch( Generic[DT] ):
|
|
|
243
405
|
[ getattr( x, name )
|
|
244
406
|
for x in self.samples ]
|
|
245
407
|
)
|
|
246
|
-
|
|
408
|
+
|
|
247
409
|
return self._aggregate_cache[name]
|
|
248
|
-
|
|
410
|
+
|
|
249
411
|
raise AttributeError( f'No sample attribute named {name}' )
|
|
250
412
|
|
|
251
413
|
|
|
@@ -268,9 +430,32 @@ RT = TypeVar( 'RT', bound = PackableSample )
|
|
|
268
430
|
# IT = TypeVar( 'IT', default = Any )
|
|
269
431
|
|
|
270
432
|
class Dataset( Generic[ST] ):
|
|
271
|
-
"""A dataset
|
|
272
|
-
|
|
273
|
-
|
|
433
|
+
"""A typed dataset built on WebDataset with lens transformations.
|
|
434
|
+
|
|
435
|
+
This class wraps WebDataset tar archives and provides type-safe iteration
|
|
436
|
+
over samples of a specific ``PackableSample`` type. Samples are stored as
|
|
437
|
+
msgpack-serialized data within WebDataset shards.
|
|
438
|
+
|
|
439
|
+
The dataset supports:
|
|
440
|
+
- Ordered and shuffled iteration
|
|
441
|
+
- Automatic batching with ``SampleBatch``
|
|
442
|
+
- Type transformations via the lens system (``as_type()``)
|
|
443
|
+
- Export to parquet format
|
|
444
|
+
|
|
445
|
+
Type Parameters:
|
|
446
|
+
ST: The sample type for this dataset, must derive from ``PackableSample``.
|
|
447
|
+
|
|
448
|
+
Attributes:
|
|
449
|
+
url: WebDataset brace-notation URL for the tar file(s).
|
|
450
|
+
|
|
451
|
+
Example:
|
|
452
|
+
>>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
|
|
453
|
+
>>> for sample in ds.ordered(batch_size=32):
|
|
454
|
+
... # sample is SampleBatch[MyData] with batch_size samples
|
|
455
|
+
... embeddings = sample.embeddings # shape: (32, ...)
|
|
456
|
+
...
|
|
457
|
+
>>> # Transform to a different view
|
|
458
|
+
>>> ds_view = ds.as_type(MyDataView)
|
|
274
459
|
"""
|
|
275
460
|
|
|
276
461
|
# sample_class: Type = get_parameters( )
|
|
@@ -280,12 +465,23 @@ class Dataset( Generic[ST] ):
|
|
|
280
465
|
|
|
281
466
|
@property
|
|
282
467
|
def sample_type( self ) -> Type:
|
|
283
|
-
"""The type of each returned sample from this
|
|
284
|
-
|
|
468
|
+
"""The type of each returned sample from this dataset's iterator.
|
|
469
|
+
|
|
470
|
+
Returns:
|
|
471
|
+
The type parameter ``ST`` used when creating this ``Dataset[ST]``.
|
|
472
|
+
|
|
473
|
+
Note:
|
|
474
|
+
Extracts the type parameter at runtime using ``__orig_class__``.
|
|
475
|
+
"""
|
|
476
|
+
# NOTE: Linting may fail here due to __orig_class__ being a runtime attribute
|
|
285
477
|
return typing.get_args( self.__orig_class__ )[0]
|
|
286
478
|
@property
|
|
287
479
|
def batch_type( self ) -> Type:
|
|
288
|
-
"""The type of
|
|
480
|
+
"""The type of batches produced by this dataset.
|
|
481
|
+
|
|
482
|
+
Returns:
|
|
483
|
+
``SampleBatch[ST]`` where ``ST`` is this dataset's sample type.
|
|
484
|
+
"""
|
|
289
485
|
# return self.__orig_class__.__args__[1]
|
|
290
486
|
return SampleBatch[self.sample_type]
|
|
291
487
|
|
|
@@ -296,7 +492,13 @@ class Dataset( Generic[ST] ):
|
|
|
296
492
|
#
|
|
297
493
|
|
|
298
494
|
def __init__( self, url: str ) -> None:
|
|
299
|
-
"""
|
|
495
|
+
"""Create a dataset from a WebDataset URL.
|
|
496
|
+
|
|
497
|
+
Args:
|
|
498
|
+
url: WebDataset brace-notation URL pointing to tar files, e.g.,
|
|
499
|
+
``"path/to/file-{000000..000009}.tar"`` for multiple shards or
|
|
500
|
+
``"path/to/file-000000.tar"`` for a single shard.
|
|
501
|
+
"""
|
|
300
502
|
super().__init__()
|
|
301
503
|
self.url = url
|
|
302
504
|
|
|
@@ -304,7 +506,21 @@ class Dataset( Generic[ST] ):
|
|
|
304
506
|
self._output_lens: Lens | None = None
|
|
305
507
|
|
|
306
508
|
def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
|
|
307
|
-
"""
|
|
509
|
+
"""View this dataset through a different sample type using a registered lens.
|
|
510
|
+
|
|
511
|
+
Args:
|
|
512
|
+
other: The target sample type to transform into. Must be a type
|
|
513
|
+
derived from ``PackableSample``.
|
|
514
|
+
|
|
515
|
+
Returns:
|
|
516
|
+
A new ``Dataset`` instance that yields samples of type ``other``
|
|
517
|
+
by applying the appropriate lens transformation from the global
|
|
518
|
+
``LensNetwork`` registry.
|
|
519
|
+
|
|
520
|
+
Raises:
|
|
521
|
+
ValueError: If no registered lens exists between the current
|
|
522
|
+
sample type and the target type.
|
|
523
|
+
"""
|
|
308
524
|
ret = Dataset[other]( self.url )
|
|
309
525
|
# Get the singleton lens registry
|
|
310
526
|
lenses = LensNetwork()
|
|
@@ -384,18 +600,23 @@ class Dataset( Generic[ST] ):
|
|
|
384
600
|
buffer_samples: int = 10_000,
|
|
385
601
|
batch_size: int | None = 1,
|
|
386
602
|
) -> Iterable[ST]:
|
|
387
|
-
"""Iterate over the dataset in random order
|
|
388
|
-
|
|
603
|
+
"""Iterate over the dataset in random order.
|
|
604
|
+
|
|
389
605
|
Args:
|
|
390
|
-
buffer_shards
|
|
391
|
-
|
|
392
|
-
Default:
|
|
393
|
-
|
|
394
|
-
|
|
606
|
+
buffer_shards: Number of shards to buffer for shuffling at the
|
|
607
|
+
shard level. Larger values increase randomness but use more
|
|
608
|
+
memory. Default: 100.
|
|
609
|
+
buffer_samples: Number of samples to buffer for shuffling within
|
|
610
|
+
shards. Larger values increase randomness but use more memory.
|
|
611
|
+
Default: 10,000.
|
|
612
|
+
batch_size: The size of iterated batches. Default: 1. If ``None``,
|
|
613
|
+
iterates over one sample at a time with no batch dimension.
|
|
614
|
+
|
|
395
615
|
Returns:
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
616
|
+
A WebDataset data pipeline that iterates over the dataset in
|
|
617
|
+
randomized order. If ``batch_size`` is not ``None``, yields
|
|
618
|
+
``SampleBatch[ST]`` instances; otherwise yields individual ``ST``
|
|
619
|
+
samples.
|
|
399
620
|
"""
|
|
400
621
|
|
|
401
622
|
if batch_size is None:
|
|
@@ -500,7 +721,16 @@ class Dataset( Generic[ST] ):
|
|
|
500
721
|
# @classmethod
|
|
501
722
|
# TODO replace Any with IT
|
|
502
723
|
def wrap( self, sample: MsgpackRawSample ) -> ST:
|
|
503
|
-
"""Wrap a
|
|
724
|
+
"""Wrap a raw msgpack sample into the appropriate dataset-specific type.
|
|
725
|
+
|
|
726
|
+
Args:
|
|
727
|
+
sample: A dictionary containing at minimum a ``'msgpack'`` key with
|
|
728
|
+
serialized sample bytes.
|
|
729
|
+
|
|
730
|
+
Returns:
|
|
731
|
+
A deserialized sample of type ``ST``, optionally transformed through
|
|
732
|
+
a lens if ``as_type()`` was called.
|
|
733
|
+
"""
|
|
504
734
|
assert 'msgpack' in sample
|
|
505
735
|
assert type( sample['msgpack'] ) == bytes
|
|
506
736
|
|
|
@@ -524,9 +754,19 @@ class Dataset( Generic[ST] ):
|
|
|
524
754
|
# )
|
|
525
755
|
|
|
526
756
|
def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
|
|
527
|
-
"""Wrap a
|
|
528
|
-
|
|
529
|
-
|
|
757
|
+
"""Wrap a batch of raw msgpack samples into a typed SampleBatch.
|
|
758
|
+
|
|
759
|
+
Args:
|
|
760
|
+
batch: A dictionary containing a ``'msgpack'`` key with a list of
|
|
761
|
+
serialized sample bytes.
|
|
762
|
+
|
|
763
|
+
Returns:
|
|
764
|
+
A ``SampleBatch[ST]`` containing deserialized samples, optionally
|
|
765
|
+
transformed through a lens if ``as_type()`` was called.
|
|
766
|
+
|
|
767
|
+
Note:
|
|
768
|
+
This implementation deserializes samples one at a time, then
|
|
769
|
+
aggregates them into a batch.
|
|
530
770
|
"""
|
|
531
771
|
|
|
532
772
|
assert 'msgpack' in batch
|
|
@@ -572,8 +812,30 @@ class Dataset( Generic[ST] ):
|
|
|
572
812
|
# return decorator
|
|
573
813
|
|
|
574
814
|
def packable( cls ):
|
|
575
|
-
"""
|
|
576
|
-
|
|
815
|
+
"""Decorator to convert a regular class into a ``PackableSample``.
|
|
816
|
+
|
|
817
|
+
This decorator transforms a class into a dataclass that inherits from
|
|
818
|
+
``PackableSample``, enabling automatic msgpack serialization/deserialization
|
|
819
|
+
with special handling for NDArray fields.
|
|
820
|
+
|
|
821
|
+
Args:
|
|
822
|
+
cls: The class to convert. Should have type annotations for its fields.
|
|
823
|
+
|
|
824
|
+
Returns:
|
|
825
|
+
A new dataclass that inherits from ``PackableSample`` with the same
|
|
826
|
+
name and annotations as the original class.
|
|
827
|
+
|
|
828
|
+
Example:
|
|
829
|
+
>>> @packable
|
|
830
|
+
... class MyData:
|
|
831
|
+
... name: str
|
|
832
|
+
... values: NDArray
|
|
833
|
+
...
|
|
834
|
+
>>> sample = MyData(name="test", values=np.array([1, 2, 3]))
|
|
835
|
+
>>> bytes_data = sample.packed
|
|
836
|
+
>>> restored = MyData.from_bytes(bytes_data)
|
|
837
|
+
"""
|
|
838
|
+
|
|
577
839
|
##
|
|
578
840
|
|
|
579
841
|
class_name = cls.__name__
|
atdata/lens.py
CHANGED
|
@@ -1,4 +1,42 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Lens-based type transformations for datasets.
|
|
2
|
+
|
|
3
|
+
This module implements a lens system for bidirectional transformations between
|
|
4
|
+
different sample types. Lenses enable viewing a dataset through different type
|
|
5
|
+
schemas without duplicating the underlying data.
|
|
6
|
+
|
|
7
|
+
Key components:
|
|
8
|
+
|
|
9
|
+
- ``Lens``: Bidirectional transformation with getter (S -> V) and optional
|
|
10
|
+
putter (V, S -> S)
|
|
11
|
+
- ``LensNetwork``: Global singleton registry for lens transformations
|
|
12
|
+
- ``@lens``: Decorator to create and register lens transformations
|
|
13
|
+
|
|
14
|
+
Lenses support the functional programming concept of composable, well-behaved
|
|
15
|
+
transformations that satisfy lens laws (GetPut and PutGet).
|
|
16
|
+
|
|
17
|
+
Example:
|
|
18
|
+
>>> @packable
|
|
19
|
+
... class FullData:
|
|
20
|
+
... name: str
|
|
21
|
+
... age: int
|
|
22
|
+
... embedding: NDArray
|
|
23
|
+
...
|
|
24
|
+
>>> @packable
|
|
25
|
+
... class NameOnly:
|
|
26
|
+
... name: str
|
|
27
|
+
...
|
|
28
|
+
>>> @lens
|
|
29
|
+
... def name_view(full: FullData) -> NameOnly:
|
|
30
|
+
... return NameOnly(name=full.name)
|
|
31
|
+
...
|
|
32
|
+
>>> @name_view.putter
|
|
33
|
+
... def name_view_put(view: NameOnly, source: FullData) -> FullData:
|
|
34
|
+
... return FullData(name=view.name, age=source.age,
|
|
35
|
+
... embedding=source.embedding)
|
|
36
|
+
...
|
|
37
|
+
>>> ds = Dataset[FullData]("data.tar")
|
|
38
|
+
>>> ds_names = ds.as_type(NameOnly) # Uses registered lens
|
|
39
|
+
"""
|
|
2
40
|
|
|
3
41
|
##
|
|
4
42
|
# Imports
|
|
@@ -39,24 +77,45 @@ type LensPutter[S, V] = Callable[[V, S], S]
|
|
|
39
77
|
# Shortcut decorators
|
|
40
78
|
|
|
41
79
|
class Lens( Generic[S, V] ):
|
|
42
|
-
"""
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
80
|
+
"""A bidirectional transformation between two sample types.
|
|
81
|
+
|
|
82
|
+
A lens provides a way to view and update data of type ``S`` (source) as if
|
|
83
|
+
it were type ``V`` (view). It consists of a getter that transforms ``S -> V``
|
|
84
|
+
and an optional putter that transforms ``(V, S) -> S``, enabling updates to
|
|
85
|
+
the view to be reflected back in the source.
|
|
86
|
+
|
|
87
|
+
Type Parameters:
|
|
88
|
+
S: The source type, must derive from ``PackableSample``.
|
|
89
|
+
V: The view type, must derive from ``PackableSample``.
|
|
90
|
+
|
|
91
|
+
Example:
|
|
92
|
+
>>> @lens
|
|
93
|
+
... def name_lens(full: FullData) -> NameOnly:
|
|
94
|
+
... return NameOnly(name=full.name)
|
|
95
|
+
...
|
|
96
|
+
>>> @name_lens.putter
|
|
97
|
+
... def name_lens_put(view: NameOnly, source: FullData) -> FullData:
|
|
98
|
+
... return FullData(name=view.name, age=source.age, embedding=source.embedding)
|
|
99
|
+
"""
|
|
55
100
|
|
|
56
101
|
def __init__( self, get: LensGetter[S, V],
|
|
57
102
|
put: Optional[LensPutter[S, V]] = None
|
|
58
103
|
) -> None:
|
|
59
|
-
"""
|
|
104
|
+
"""Initialize a lens with a getter and optional putter function.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
get: A function that transforms from source type ``S`` to view type
|
|
108
|
+
``V``. Must accept exactly one parameter annotated with the
|
|
109
|
+
source type.
|
|
110
|
+
put: An optional function that updates the source based on a modified
|
|
111
|
+
view. Takes a view of type ``V`` and original source of type ``S``,
|
|
112
|
+
and returns an updated source of type ``S``. If not provided, a
|
|
113
|
+
trivial putter is used that ignores updates to the view.
|
|
114
|
+
|
|
115
|
+
Raises:
|
|
116
|
+
AssertionError: If the getter function doesn't have exactly one
|
|
117
|
+
parameter.
|
|
118
|
+
"""
|
|
60
119
|
##
|
|
61
120
|
|
|
62
121
|
# Check argument validity
|
|
@@ -70,11 +129,11 @@ class Lens( Generic[S, V] ):
|
|
|
70
129
|
functools.update_wrapper( self, get )
|
|
71
130
|
|
|
72
131
|
self.source_type: Type[PackableSample] = input_types[0].annotation
|
|
73
|
-
self.view_type = sig.return_annotation
|
|
132
|
+
self.view_type: Type[PackableSample] = sig.return_annotation
|
|
74
133
|
|
|
75
134
|
# Store the getter
|
|
76
135
|
self._getter = get
|
|
77
|
-
|
|
136
|
+
|
|
78
137
|
# Determine and store the putter
|
|
79
138
|
if put is None:
|
|
80
139
|
# Trivial putter does not update the source
|
|
@@ -86,7 +145,20 @@ class Lens( Generic[S, V] ):
|
|
|
86
145
|
#
|
|
87
146
|
|
|
88
147
|
def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
|
|
89
|
-
"""
|
|
148
|
+
"""Decorator to register a putter function for this lens.
|
|
149
|
+
|
|
150
|
+
Args:
|
|
151
|
+
put: A function that takes a view of type ``V`` and source of type
|
|
152
|
+
``S``, and returns an updated source of type ``S``.
|
|
153
|
+
|
|
154
|
+
Returns:
|
|
155
|
+
The putter function, allowing this to be used as a decorator.
|
|
156
|
+
|
|
157
|
+
Example:
|
|
158
|
+
>>> @my_lens.putter
|
|
159
|
+
... def my_lens_put(view: ViewType, source: SourceType) -> SourceType:
|
|
160
|
+
... return SourceType(...)
|
|
161
|
+
"""
|
|
90
162
|
##
|
|
91
163
|
self._putter = put
|
|
92
164
|
return put
|
|
@@ -94,16 +166,39 @@ class Lens( Generic[S, V] ):
|
|
|
94
166
|
# Methods to actually execute transformations
|
|
95
167
|
|
|
96
168
|
def put( self, v: V, s: S ) -> S:
|
|
97
|
-
"""
|
|
169
|
+
"""Update the source based on a modified view.
|
|
170
|
+
|
|
171
|
+
Args:
|
|
172
|
+
v: The modified view of type ``V``.
|
|
173
|
+
s: The original source of type ``S``.
|
|
174
|
+
|
|
175
|
+
Returns:
|
|
176
|
+
An updated source of type ``S`` that reflects changes from the view.
|
|
177
|
+
"""
|
|
98
178
|
return self._putter( v, s )
|
|
99
179
|
|
|
100
180
|
def get( self, s: S ) -> V:
|
|
101
|
-
"""
|
|
181
|
+
"""Transform the source into the view type.
|
|
182
|
+
|
|
183
|
+
Args:
|
|
184
|
+
s: The source sample of type ``S``.
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
A view of the source as type ``V``.
|
|
188
|
+
"""
|
|
102
189
|
return self( s )
|
|
103
190
|
|
|
104
191
|
# Convenience to enable calling the lens as its getter
|
|
105
|
-
|
|
192
|
+
|
|
106
193
|
def __call__( self, s: S ) -> V:
|
|
194
|
+
"""Apply the lens transformation (same as ``get()``).
|
|
195
|
+
|
|
196
|
+
Args:
|
|
197
|
+
s: The source sample of type ``S``.
|
|
198
|
+
|
|
199
|
+
Returns:
|
|
200
|
+
A view of the source as type ``V``.
|
|
201
|
+
"""
|
|
107
202
|
return self._getter( s )
|
|
108
203
|
|
|
109
204
|
# TODO Figure out how to properly parameterize this
|
|
@@ -124,6 +219,28 @@ class Lens( Generic[S, V] ):
|
|
|
124
219
|
# lens = _lens_factory
|
|
125
220
|
|
|
126
221
|
def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
|
|
222
|
+
"""Decorator to create and register a lens transformation.
|
|
223
|
+
|
|
224
|
+
This decorator converts a getter function into a ``Lens`` object and
|
|
225
|
+
automatically registers it in the global ``LensNetwork`` registry.
|
|
226
|
+
|
|
227
|
+
Args:
|
|
228
|
+
f: A getter function that transforms from source type ``S`` to view
|
|
229
|
+
type ``V``. Must have exactly one parameter with a type annotation.
|
|
230
|
+
|
|
231
|
+
Returns:
|
|
232
|
+
A ``Lens[S, V]`` object that can be called to apply the transformation
|
|
233
|
+
or decorated with ``@lens_name.putter`` to add a putter function.
|
|
234
|
+
|
|
235
|
+
Example:
|
|
236
|
+
>>> @lens
|
|
237
|
+
... def extract_name(full: FullData) -> NameOnly:
|
|
238
|
+
... return NameOnly(name=full.name)
|
|
239
|
+
...
|
|
240
|
+
>>> @extract_name.putter
|
|
241
|
+
... def extract_name_put(view: NameOnly, source: FullData) -> FullData:
|
|
242
|
+
... return FullData(name=view.name, age=source.age, embedding=source.embedding)
|
|
243
|
+
"""
|
|
127
244
|
ret = Lens[S, V]( f )
|
|
128
245
|
_network.register( ret )
|
|
129
246
|
return ret
|
|
@@ -136,25 +253,46 @@ def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
|
|
|
136
253
|
# """TODO"""
|
|
137
254
|
|
|
138
255
|
class LensNetwork:
|
|
139
|
-
"""
|
|
256
|
+
"""Global registry for lens transformations between sample types.
|
|
257
|
+
|
|
258
|
+
This class implements a singleton pattern to maintain a global registry of
|
|
259
|
+
all lenses decorated with ``@lens``. It enables looking up transformations
|
|
260
|
+
between different ``PackableSample`` types.
|
|
261
|
+
|
|
262
|
+
Attributes:
|
|
263
|
+
_instance: The singleton instance of this class.
|
|
264
|
+
_registry: Dictionary mapping ``(source_type, view_type)`` tuples to
|
|
265
|
+
their corresponding ``Lens`` objects.
|
|
266
|
+
"""
|
|
140
267
|
|
|
141
268
|
_instance = None
|
|
142
269
|
"""The singleton instance"""
|
|
143
270
|
|
|
144
271
|
def __new__(cls, *args, **kwargs):
|
|
272
|
+
"""Ensure only one instance of LensNetwork exists (singleton pattern)."""
|
|
145
273
|
if cls._instance is None:
|
|
146
274
|
# If no instance exists, create a new one
|
|
147
275
|
cls._instance = super().__new__(cls)
|
|
148
276
|
return cls._instance # Return the existing (or newly created) instance
|
|
149
277
|
|
|
150
278
|
def __init__(self):
|
|
279
|
+
"""Initialize the lens registry (only on first instantiation)."""
|
|
151
280
|
if not hasattr(self, '_initialized'): # Check if already initialized
|
|
152
281
|
self._registry: Dict[LensSignature, Lens] = dict()
|
|
153
282
|
self._initialized = True
|
|
154
283
|
|
|
155
284
|
def register( self, _lens: Lens ):
|
|
156
|
-
"""
|
|
157
|
-
|
|
285
|
+
"""Register a lens as the canonical transformation between two types.
|
|
286
|
+
|
|
287
|
+
Args:
|
|
288
|
+
_lens: The lens to register. Will be stored in the registry under
|
|
289
|
+
the key ``(_lens.source_type, _lens.view_type)``.
|
|
290
|
+
|
|
291
|
+
Note:
|
|
292
|
+
If a lens already exists for the same type pair, it will be
|
|
293
|
+
overwritten.
|
|
294
|
+
"""
|
|
295
|
+
|
|
158
296
|
# sig = inspect.signature( _lens.get )
|
|
159
297
|
# input_types = list( sig.parameters.values() )
|
|
160
298
|
# assert len( input_types ) == 1, \
|
|
@@ -169,13 +307,28 @@ class LensNetwork:
|
|
|
169
307
|
self._registry[_lens.source_type, _lens.view_type] = _lens
|
|
170
308
|
|
|
171
309
|
def transform( self, source: DatasetType, view: DatasetType ) -> Lens:
|
|
172
|
-
"""
|
|
310
|
+
"""Look up the lens transformation between two sample types.
|
|
311
|
+
|
|
312
|
+
Args:
|
|
313
|
+
source: The source sample type (must derive from ``PackableSample``).
|
|
314
|
+
view: The target view type (must derive from ``PackableSample``).
|
|
315
|
+
|
|
316
|
+
Returns:
|
|
317
|
+
The registered ``Lens`` that transforms from ``source`` to ``view``.
|
|
318
|
+
|
|
319
|
+
Raises:
|
|
320
|
+
ValueError: If no lens has been registered for the given type pair.
|
|
321
|
+
|
|
322
|
+
Note:
|
|
323
|
+
Currently only supports direct transformations. Compositional
|
|
324
|
+
transformations (chaining multiple lenses) are not yet implemented.
|
|
325
|
+
"""
|
|
173
326
|
|
|
174
327
|
# TODO Handle compositional closure
|
|
175
328
|
ret = self._registry.get( (source, view), None )
|
|
176
329
|
if ret is None:
|
|
177
330
|
raise ValueError( f'No registered lens from source {source} to view {view}' )
|
|
178
|
-
|
|
331
|
+
|
|
179
332
|
return ret
|
|
180
333
|
|
|
181
334
|
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: atdata
|
|
3
|
+
Version: 0.1.3b4
|
|
4
|
+
Summary: A loose federation of distributed, typed datasets
|
|
5
|
+
Author-email: Maxine Levesque <hello@maxine.science>
|
|
6
|
+
License-File: LICENSE
|
|
7
|
+
Requires-Python: >=3.12
|
|
8
|
+
Requires-Dist: fastparquet>=2024.11.0
|
|
9
|
+
Requires-Dist: msgpack>=1.1.2
|
|
10
|
+
Requires-Dist: numpy>=2.3.4
|
|
11
|
+
Requires-Dist: ormsgpack>=1.11.0
|
|
12
|
+
Requires-Dist: pandas>=2.3.3
|
|
13
|
+
Requires-Dist: tqdm>=4.67.1
|
|
14
|
+
Requires-Dist: webdataset>=1.0.2
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
|
|
17
|
+
# atdata
|
|
18
|
+
|
|
19
|
+
[codecov](https://codecov.io/gh/foundation-ac/atdata)
|
|
20
|
+
|
|
21
|
+
A loose federation of distributed, typed datasets built on WebDataset.
|
|
22
|
+
|
|
23
|
+
**atdata** provides a type-safe, composable framework for working with large-scale datasets. It combines the efficiency of WebDataset's tar-based storage with Python's type system and functional programming patterns.
|
|
24
|
+
|
|
25
|
+
## Features
|
|
26
|
+
|
|
27
|
+
- **Typed Samples** - Define dataset schemas using Python dataclasses with automatic msgpack serialization
|
|
28
|
+
- **Lens Transformations** - Bidirectional transformations between different dataset views (automatic chaining of multiple lenses is not yet implemented)
|
|
29
|
+
- **Automatic Batching** - Smart batch aggregation with numpy array stacking
|
|
30
|
+
- **WebDataset Integration** - Efficient storage and streaming for large-scale datasets
|
|
31
|
+
|
|
32
|
+
## Installation
|
|
33
|
+
|
|
34
|
+
```bash
|
|
35
|
+
pip install atdata
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
Requires Python 3.12 or later.
|
|
39
|
+
|
|
40
|
+
## Quick Start
|
|
41
|
+
|
|
42
|
+
### Defining Sample Types
|
|
43
|
+
|
|
44
|
+
Use the `@packable` decorator to create typed dataset samples:
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
import atdata
|
|
48
|
+
from numpy.typing import NDArray
|
|
49
|
+
|
|
50
|
+
@atdata.packable
|
|
51
|
+
class ImageSample:
|
|
52
|
+
image: NDArray
|
|
53
|
+
label: str
|
|
54
|
+
metadata: dict
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
### Creating Datasets
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
# Create a dataset
|
|
61
|
+
dataset = atdata.Dataset[ImageSample]("path/to/data-{000000..000009}.tar")
|
|
62
|
+
|
|
63
|
+
# Iterate over samples in order
|
|
64
|
+
for sample in dataset.ordered(batch_size=None):
|
|
65
|
+
print(f"Label: {sample.label}, Image shape: {sample.image.shape}")
|
|
66
|
+
|
|
67
|
+
# Iterate with shuffling and batching
|
|
68
|
+
for batch in dataset.shuffled(batch_size=32):
|
|
69
|
+
# batch.image is automatically stacked into shape (32, ...)
|
|
70
|
+
# batch.label is a list of 32 labels
|
|
71
|
+
process_batch(batch.image, batch.label)
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
### Lens Transformations
|
|
75
|
+
|
|
76
|
+
Define reusable transformations between sample types:
|
|
77
|
+
|
|
78
|
+
```python
|
|
79
|
+
@atdata.packable
|
|
80
|
+
class ProcessedSample:
|
|
81
|
+
features: NDArray
|
|
82
|
+
label: str
|
|
83
|
+
|
|
84
|
+
@atdata.lens
|
|
85
|
+
def preprocess(sample: ImageSample) -> ProcessedSample:
|
|
86
|
+
features = extract_features(sample.image)
|
|
87
|
+
return ProcessedSample(features=features, label=sample.label)
|
|
88
|
+
|
|
89
|
+
# Apply lens to view dataset as ProcessedSample
|
|
90
|
+
processed_ds = dataset.as_type(ProcessedSample)
|
|
91
|
+
|
|
92
|
+
for sample in processed_ds.ordered(batch_size=None):
|
|
93
|
+
# sample is now a ProcessedSample
|
|
94
|
+
print(sample.features.shape)
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
## Core Concepts
|
|
98
|
+
|
|
99
|
+
### PackableSample
|
|
100
|
+
|
|
101
|
+
Base class for serializable samples. Fields annotated as `NDArray` are automatically converted to and from bytes during serialization:
|
|
102
|
+
|
|
103
|
+
```python
|
|
104
|
+
@atdata.packable
|
|
105
|
+
class MySample:
|
|
106
|
+
array_field: NDArray # Automatically serialized
|
|
107
|
+
optional_array: NDArray | None
|
|
108
|
+
regular_field: str
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
### Lens
|
|
112
|
+
|
|
113
|
+
Bidirectional transformations with getter/putter semantics:
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
@atdata.lens
|
|
117
|
+
def my_lens(source: SourceType) -> ViewType:
|
|
118
|
+
# Transform source -> view
|
|
119
|
+
return ViewType(...)
|
|
120
|
+
|
|
121
|
+
@my_lens.putter
|
|
122
|
+
def my_lens_put(view: ViewType, source: SourceType) -> SourceType:
|
|
123
|
+
# Write the (possibly modified) view back into the original source, returning an updated source
|
|
124
|
+
return SourceType(...)
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
### Dataset URLs
|
|
128
|
+
|
|
129
|
+
Dataset URLs use WebDataset brace expansion to address sharded datasets:
|
|
130
|
+
|
|
131
|
+
- Single file: `"data/dataset-000000.tar"`
|
|
132
|
+
- Multiple shards: `"data/dataset-{000000..000099}.tar"`
|
|
133
|
+
- Multiple patterns: `"data/{train,val}/dataset-{000000..000009}.tar"`
|
|
134
|
+
|
|
135
|
+
## Development
|
|
136
|
+
|
|
137
|
+
### Setup
|
|
138
|
+
|
|
139
|
+
```bash
|
|
140
|
+
# Install uv if not already available
|
|
141
|
+
python -m pip install uv
|
|
142
|
+
|
|
143
|
+
# Install dependencies
|
|
144
|
+
uv sync
|
|
145
|
+
```
|
|
146
|
+
|
|
147
|
+
### Testing
|
|
148
|
+
|
|
149
|
+
```bash
|
|
150
|
+
# Run all tests with coverage
|
|
151
|
+
pytest
|
|
152
|
+
|
|
153
|
+
# Run specific test file
|
|
154
|
+
pytest tests/test_dataset.py
|
|
155
|
+
|
|
156
|
+
# Run single test
|
|
157
|
+
pytest tests/test_lens.py::test_lens
|
|
158
|
+
```
|
|
159
|
+
|
|
160
|
+
### Building
|
|
161
|
+
|
|
162
|
+
```bash
|
|
163
|
+
uv build
|
|
164
|
+
```
|
|
165
|
+
|
|
166
|
+
## Contributing
|
|
167
|
+
|
|
168
|
+
Contributions are welcome! This project is in beta, so the API may still evolve.
|
|
169
|
+
|
|
170
|
+
## License
|
|
171
|
+
|
|
172
|
+
This project is licensed under the Mozilla Public License 2.0. See [LICENSE](LICENSE) for details.
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
atdata/__init__.py,sha256=_363ZuJfwbBQTMYsoKOiyoBe4AHr3iplK-EQyrAeTdg,1545
|
|
2
|
+
atdata/_helpers.py,sha256=RvA-Xlj3AvgSWuiPdS8YTBp8AJT-u32BaLpxsu4PIIA,1564
|
|
3
|
+
atdata/dataset.py,sha256=O_7b3ub_M4IMRuhv95oz1PVFdsOhNiyXgtY8NphPdBk,27842
|
|
4
|
+
atdata/lens.py,sha256=ynn1DQkR89eRL6JV9EsawuPY9JTrZ67pAX4cRvZ6UVk,11157
|
|
5
|
+
atdata-0.1.3b4.dist-info/METADATA,sha256=SdZSI_SonE-pt4nhmFh5bz9zKD79wT2CKXKFxrTfvgc,4162
|
|
6
|
+
atdata-0.1.3b4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
atdata-0.1.3b4.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
|
|
8
|
+
atdata-0.1.3b4.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
|
|
9
|
+
atdata-0.1.3b4.dist-info/RECORD,,
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.4
|
|
2
|
-
Name: atdata
|
|
3
|
-
Version: 0.1.3b3
|
|
4
|
-
Summary: A loose federation of distributed, typed datasets
|
|
5
|
-
Author-email: Maxine Levesque <hello@maxine.science>
|
|
6
|
-
License-File: LICENSE
|
|
7
|
-
Requires-Python: >=3.12
|
|
8
|
-
Requires-Dist: fastparquet>=2024.11.0
|
|
9
|
-
Requires-Dist: msgpack>=1.1.2
|
|
10
|
-
Requires-Dist: numpy>=2.3.4
|
|
11
|
-
Requires-Dist: ormsgpack>=1.11.0
|
|
12
|
-
Requires-Dist: pandas>=2.3.3
|
|
13
|
-
Requires-Dist: tqdm>=4.67.1
|
|
14
|
-
Requires-Dist: webdataset>=1.0.2
|
|
15
|
-
Description-Content-Type: text/markdown
|
|
16
|
-
|
|
17
|
-
# atdata
|
|
18
|
-
A loose federation of distributed, typed datasets
|
atdata-0.1.3b3.dist-info/RECORD
DELETED
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
atdata/__init__.py,sha256=V2qBg7i2mfCNG9nww6Gi_fDp7iwolDMrNzhmNO6VA7M,233
|
|
2
|
-
atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
|
|
3
|
-
atdata/dataset.py,sha256=qyAiKSjjYqFVWmaLz5LAIZ3_YVHbm5lg32zmctqjjlE,18085
|
|
4
|
-
atdata/lens.py,sha256=HvXuRqYTeJBpMyIQVdGZXxEvbGKBuFCF8lbiib4TqsA,5306
|
|
5
|
-
atdata-0.1.3b3.dist-info/METADATA,sha256=jrGZ592QbkJdZCq8FLmXOznQ0LkTUyUkqLVIH3ZRj4U,529
|
|
6
|
-
atdata-0.1.3b3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
atdata-0.1.3b3.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
|
|
8
|
-
atdata-0.1.3b3.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
|
|
9
|
-
atdata-0.1.3b3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|