atdata 0.1.2a3__py3-none-any.whl → 0.1.2b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +5 -0
- atdata/dataset.py +102 -22
- atdata/lens.py +122 -0
- {atdata-0.1.2a3.dist-info → atdata-0.1.2b1.dist-info}/METADATA +4 -1
- atdata-0.1.2b1.dist-info/RECORD +9 -0
- atdata-0.1.2a3.dist-info/RECORD +0 -8
- {atdata-0.1.2a3.dist-info → atdata-0.1.2b1.dist-info}/WHEEL +0 -0
- {atdata-0.1.2a3.dist-info → atdata-0.1.2b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.1.2a3.dist-info → atdata-0.1.2b1.dist-info}/licenses/LICENSE +0 -0
atdata/__init__.py
CHANGED
atdata/dataset.py
CHANGED
@@ -5,21 +5,29 @@
 
 import webdataset as wds
 
-import
-from dataclasses import dataclass
+from pathlib import Path
 import uuid
-
-
-
+import functools
+from dataclasses import (
+    dataclass,
+    asdict,
+)
 from abc import (
     ABC,
    abstractmethod,
 )
+
+from tqdm import tqdm
+import numpy as np
+import pandas as pd
+
 from typing import (
     Any,
     Optional,
     Dict,
     Sequence,
+    Iterable,
+    Callable,
     #
     Self,
     Generic,
@@ -45,9 +53,14 @@ from . import _helpers as eh
 ##
 # Typing help
 
+Pathlike = str | Path
+
 WDSRawSample: TypeAlias = Dict[str, Any]
 WDSRawBatch: TypeAlias = Dict[str, Any]
 
+SampleExportRow: TypeAlias = Dict[str, Any]
+SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
+
 
 ##
 # Main base classes
@@ -94,6 +107,7 @@ def _make_packable( x ):
         return eh.array_to_bytes( x )
     return x
 
+@dataclass
 class PackableSample( ABC ):
     """A sample that can be packed and unpacked with msgpack"""
 
@@ -235,6 +249,7 @@ class Dataset( Generic[ST] ):
     @property
     def sample_type( self ) -> Type:
         """The type of each returned sample from this `Dataset`'s iterator"""
+        # TODO Figure out why linting fails here
         return self.__orig_class__.__args__[0]
     @property
     def batch_type( self ) -> Type:
@@ -286,7 +301,7 @@ class Dataset( Generic[ST] ):
 
     def ordered( self,
             batch_size: int | None = 1,
-            ) ->
+            ) -> Iterable[ST]:
        """Iterate over the dataset in order
 
        Args:
@@ -325,7 +340,7 @@ class Dataset( Generic[ST] ):
             buffer_shards: int = 100,
             buffer_samples: int = 10_000,
             batch_size: int | None = 1,
-            ) ->
+            ) -> Iterable[ST]:
        """Iterate over the dataset in random order
 
        Args:
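
Both iterator entry points now declare an `Iterable[ST]` return type. A minimal usage sketch of the ordered iterator, assuming an already constructed `Dataset` instance named `ds` (not part of this diff):

```python
# Sketch only: `ds` stands in for a hypothetical, already constructed Dataset.
# Unbatched sequential pass; batch_size = None yields individual samples,
# as the new to_parquet method below does internally.
for sample in ds.ordered( batch_size = None ):
    ...

# Batched sequential pass; batches are wrapped via wrap_batch in the pipeline.
for batch in ds.ordered( batch_size = 32 ):
    ...
```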
@@ -366,6 +381,64 @@ class Dataset( Generic[ST] ):
             wds.batched( batch_size ),
             wds.map( self.wrap_batch ),
         )
+
+    # TODO Rewrite to eliminate `pandas` dependency directly calling
+    # `fastparquet`
+    def to_parquet( self, path: Pathlike,
+            sample_map: Optional[SampleExportMap] = None,
+            maxcount: Optional[int] = None,
+            **kwargs,
+            ):
+        """Save dataset contents to a `parquet` file at `path`
+
+        `kwargs` sent to `pandas.to_parquet`
+        """
+        ##
+
+        # Normalize args
+        path = Path( path )
+        if sample_map is None:
+            sample_map = asdict
+
+        verbose = kwargs.get( 'verbose', False )
+
+        it = self.ordered( batch_size = None )
+        if verbose:
+            it = tqdm( it )
+
+        #
+
+        if maxcount is None:
+            # Load and save full dataset
+            df = pd.DataFrame( [ sample_map( x )
+                for x in self.ordered( batch_size = None ) ] )
+            df.to_parquet( path, **kwargs )
+
+        else:
+            # Load and save dataset in segments of size `maxcount`
+
+            cur_segment = 0
+            cur_buffer = []
+            path_template = (path.parent / f'{path.stem}-%06d.{path.suffix}').as_posix()
+
+            for x in self.ordered( batch_size = None ):
+                cur_buffer.append( sample_map( x ) )
+
+                if len( cur_buffer ) >= maxcount:
+                    # Write current segment
+                    cur_path = path_template.format( cur_segment )
+                    df = pd.DataFrame( cur_buffer )
+                    df.to_parquet( cur_path, **kwargs )
+
+                    cur_segment += 1
+                    cur_buffer = []
+
+            if len( cur_buffer ) > 0:
+                # Write one last segment with remainder
+                cur_path = path_template.format( cur_segment )
+                df = pd.DataFrame( cur_buffer )
+                df.to_parquet( cur_path, **kwargs )
+
 
     # Implemented by specific subclasses
 
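
A hedged usage sketch of the new `to_parquet` export; the dataset instance and the row field are illustrative, not taken from this diff:

```python
# Sketch only: `ds` is a hypothetical Dataset instance.
# Single-file export; keyword arguments are forwarded to pandas' to_parquet.
ds.to_parquet( 'samples.parquet' )

# Segmented export: start a new parquet file every 10_000 samples and map each
# sample to a row dict instead of relying on the default dataclasses.asdict.
ds.to_parquet(
    'samples.parquet',
    sample_map = lambda s: { 'label': s.label },  # illustrative field
    maxcount = 10_000,
)
```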
@@ -390,18 +463,18 @@ class Dataset( Generic[ST] ):
 
         return self.sample_type.from_bytes( sample['msgpack'] )
 
-        try:
-            assert type( sample ) == dict
-            return cls.sample_class( **{
-                k: v
-                for k, v in sample.items() if k != '__key__'
-            } )
+        # try:
+        #     assert type( sample ) == dict
+        #     return cls.sample_class( **{
+        #         k: v
+        #         for k, v in sample.items() if k != '__key__'
+        #     } )
 
-        except Exception as e:
-            # Sample constructor failed -- revert to default
-            return AnySample(
-                value = sample,
-            )
+        # except Exception as e:
+        #     # Sample constructor failed -- revert to default
+        #     return AnySample(
+        #         value = sample,
+        #     )
 
     def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
         """Wrap a `batch` of samples into the appropriate dataset-specific type
@@ -449,15 +522,22 @@ def packable( cls ):
 
     ##
 
+    class_name = cls.__name__
+    class_annotations = cls.__annotations__
+
+    # Add in dataclass niceness to original class
     as_dataclass = dataclass( cls )
 
-
+    # This triggers a bunch of behind-the-scenes stuff for the newly annotated class
+    @dataclass
+    class as_packable( as_dataclass, PackableSample ):
         def __post_init__( self ):
             return PackableSample.__post_init__( self )
 
-
-    as_packable.
+    # TODO This doesn't properly carry over the original
+    as_packable.__name__ = class_name
+    as_packable.__annotations__ = class_annotations
 
     ##
-
+
     return as_packable
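
For context, a hedged sketch of how the `packable` decorator above would typically be applied; the sample class and its fields are illustrative, not taken from this diff:

```python
import numpy as np

from atdata.dataset import packable

# Sketch only: the class and field names are illustrative.
@packable
class ImageSample:
    image: np.ndarray
    label: str

# The decorated class gains dataclass behavior plus PackableSample's msgpack
# round-trip machinery (e.g. the from_bytes path used by wrap_sample).
sample = ImageSample( image = np.zeros( (2, 2) ), label = 'blank' )
```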
atdata/lens.py
ADDED
@@ -0,0 +1,122 @@
+"""Lenses between typed datasets"""
+
+##
+# Imports
+
+from .dataset import PackableSample
+
+import functools
+import inspect
+
+from typing import (
+    TypeAlias,
+    Type,
+    TypeVar,
+    Tuple,
+    Dict,
+    Callable,
+    Optional,
+    Generic,
+)
+
+
+##
+# Typing helpers
+
+DatasetType: TypeAlias = Type[PackableSample]
+LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
+
+S = TypeVar( 'S', bound = PackableSample )
+V = TypeVar( 'V', bound = PackableSample )
+type LensGetter[S, V] = Callable[[S], V]
+type LensPutter[S, V] = Callable[[V, S], S]
+
+
+##
+# Shortcut decorators
+
+class Lens( Generic[S, V] ):
+    """TODO"""
+
+    def __init__( self, get: LensGetter[S, V],
+            put: Optional[LensPutter[S, V]] = None
+            ) -> None:
+        """TODO"""
+        ##
+
+        # Update
+        functools.update_wrapper( self, get )
+
+        # Store the getter
+        self._getter = get
+
+        # Determine and store the putter
+        if put is None:
+            # Trivial putter does not update the source
+            def _trivial_put( v: V, s: S ) -> S:
+                return s
+            put = _trivial_put
+
+        self._putter = put
+
+        # Register this lens for this type signature
+
+        sig = inspect.signature( get )
+        input_types = list( sig.parameters.values() )
+        assert len( input_types ) == 1, \
+            'Wrong number of input args for lens: should only have one'
+
+        input_type = input_types[0].annotation
+        output_type = sig.return_annotation
+
+        _registered_lenses[(input_type, output_type)] = self
+        print( _registered_lenses )
+
+    #
+
+    def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
+        """TODO"""
+        ##
+        self._putter = put
+        return put
+
+    def put( self, v: V, s: S ) -> S:
+        """TODO"""
+        return self._putter( v, s )
+
+    def get( self, s: S ) -> V:
+        """TODO"""
+        return self( s )
+
+    #
+
+    def __call__( self, s: S ) -> V:
+        return self._getter( s )
+
+def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
+    """Register the annotated function `f` as the getter of a sample lens"""
+    return Lens[S, V]( f )
+
+
+##
+# Global registration of used lenses
+
+_registered_lenses: Dict[LensSignature, Lens] = dict()
+"""TODO"""
+
+# def lens( f: LensPutter ) -> Lens:
+#     """Register the annotated function `f` as a sample lens"""
+#     ##
+
+#     sig = inspect.signature( f )
+
+#     input_types = list( sig.parameters.values() )
+#     output_type = sig.return_annotation
+
+#     _registered_lenses[]
+
+#     f.lens = Lens(
+
+#     )
+
+#     return f
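
A hedged sketch of how the new lens API appears intended to be used, combining the `lens` decorator with `Lens.putter` registration; the sample types and fields are illustrative, and they assume the `packable` decorator from `atdata.dataset`:

```python
from atdata.dataset import packable
from atdata.lens import lens

# Sketch only: these sample types and fields are illustrative, not from this diff.
@packable
class WideSample:
    text: str
    label: str

@packable
class NarrowSample:
    label: str

@lens
def to_narrow( s: WideSample ) -> NarrowSample:
    # Getter: project the wide sample down to its label-only view. The Lens
    # constructor registers this lens in _registered_lenses under the
    # (WideSample, NarrowSample) signature read from these annotations.
    return NarrowSample( label = s.label )

@to_narrow.putter
def _put_narrow( v: NarrowSample, s: WideSample ) -> WideSample:
    # Putter: push an edited view back into the source sample.
    s.label = v.label
    return s

wide = WideSample( text = 'hello', label = 'greeting' )
narrow = to_narrow( wide )             # calls the registered getter
wide2 = to_narrow.put( narrow, wide )  # calls the registered putter
```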
{atdata-0.1.2a3.dist-info → atdata-0.1.2b1.dist-info}/METADATA
CHANGED
@@ -1,13 +1,16 @@
 Metadata-Version: 2.4
 Name: atdata
-Version: 0.1.2a3
+Version: 0.1.2b1
 Summary: A loose federation of distributed, typed datasets
 Author-email: Maxine Levesque <hello@maxine.science>
 License-File: LICENSE
 Requires-Python: >=3.12
+Requires-Dist: fastparquet>=2024.11.0
 Requires-Dist: msgpack>=1.1.2
 Requires-Dist: numpy>=2.3.4
 Requires-Dist: ormsgpack>=1.11.0
+Requires-Dist: pandas>=2.3.3
+Requires-Dist: tqdm>=4.67.1
 Requires-Dist: webdataset>=1.0.2
 Description-Content-Type: text/markdown
 
atdata-0.1.2b1.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+atdata/__init__.py,sha256=YnlohxQwTUK6V84XHm2gdeCQH5sIrTHVLSApB-nt_z8,216
+atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
+atdata/dataset.py,sha256=pBaND2D33JiJoiL9CtCTBa3octtifa21P19K076sW3Q,15905
+atdata/lens.py,sha256=ikExMWdGP3QH-bEuUDNAYO_ZjeaKJTfL9lpaN9CrRB4,2624
+atdata-0.1.2b1.dist-info/METADATA,sha256=WtHM3N0kMxJKwPXl5cluRNnvk49WU_e72lVhmcInraY,529
+atdata-0.1.2b1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+atdata-0.1.2b1.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
+atdata-0.1.2b1.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
+atdata-0.1.2b1.dist-info/RECORD,,
atdata-0.1.2a3.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-atdata/__init__.py,sha256=jPZVd_6UIo0DSbCnXAnYZ2eMwHYzOk--5vtEDTZvwqw,173
-atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
-atdata/dataset.py,sha256=HXctGwIbU5kr2pqiQCYDyGP1mkph1gIt-x1_PRtWyew,13372
-atdata-0.1.2a3.dist-info/METADATA,sha256=Jj5vP4NW-HtckIsPRzzpXVQXgcQ8HaFSGehdAu4Vfbo,434
-atdata-0.1.2a3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-atdata-0.1.2a3.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
-atdata-0.1.2a3.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
-atdata-0.1.2a3.dist-info/RECORD,,
{atdata-0.1.2a3.dist-info → atdata-0.1.2b1.dist-info}/WHEEL
File without changes
{atdata-0.1.2a3.dist-info → atdata-0.1.2b1.dist-info}/entry_points.txt
File without changes
{atdata-0.1.2a3.dist-info → atdata-0.1.2b1.dist-info}/licenses/LICENSE
File without changes