atdata 0.1.2a4__py3-none-any.whl → 0.1.2b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +5 -0
- atdata/dataset.py +97 -20
- atdata/lens.py +122 -0
- {atdata-0.1.2a4.dist-info → atdata-0.1.2b1.dist-info}/METADATA +4 -1
- atdata-0.1.2b1.dist-info/RECORD +9 -0
- atdata-0.1.2a4.dist-info/RECORD +0 -8
- {atdata-0.1.2a4.dist-info → atdata-0.1.2b1.dist-info}/WHEEL +0 -0
- {atdata-0.1.2a4.dist-info → atdata-0.1.2b1.dist-info}/entry_points.txt +0 -0
- {atdata-0.1.2a4.dist-info → atdata-0.1.2b1.dist-info}/licenses/LICENSE +0 -0
atdata/__init__.py
CHANGED
atdata/dataset.py
CHANGED
|
@@ -5,21 +5,29 @@
|
|
|
5
5
|
|
|
6
6
|
import webdataset as wds
|
|
7
7
|
|
|
8
|
-
import
|
|
9
|
-
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
10
9
|
import uuid
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
10
|
+
import functools
|
|
11
|
+
from dataclasses import (
|
|
12
|
+
dataclass,
|
|
13
|
+
asdict,
|
|
14
|
+
)
|
|
14
15
|
from abc import (
|
|
15
16
|
ABC,
|
|
16
17
|
abstractmethod,
|
|
17
18
|
)
|
|
19
|
+
|
|
20
|
+
from tqdm import tqdm
|
|
21
|
+
import numpy as np
|
|
22
|
+
import pandas as pd
|
|
23
|
+
|
|
18
24
|
from typing import (
|
|
19
25
|
Any,
|
|
20
26
|
Optional,
|
|
21
27
|
Dict,
|
|
22
28
|
Sequence,
|
|
29
|
+
Iterable,
|
|
30
|
+
Callable,
|
|
23
31
|
#
|
|
24
32
|
Self,
|
|
25
33
|
Generic,
|
|
@@ -45,9 +53,14 @@ from . import _helpers as eh
|
|
|
45
53
|
##
|
|
46
54
|
# Typing help
|
|
47
55
|
|
|
56
|
+
Pathlike = str | Path
|
|
57
|
+
|
|
48
58
|
WDSRawSample: TypeAlias = Dict[str, Any]
|
|
49
59
|
WDSRawBatch: TypeAlias = Dict[str, Any]
|
|
50
60
|
|
|
61
|
+
SampleExportRow: TypeAlias = Dict[str, Any]
|
|
62
|
+
SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
|
|
63
|
+
|
|
51
64
|
|
|
52
65
|
##
|
|
53
66
|
# Main base classes
|
|
@@ -94,6 +107,7 @@ def _make_packable( x ):
|
|
|
94
107
|
return eh.array_to_bytes( x )
|
|
95
108
|
return x
|
|
96
109
|
|
|
110
|
+
@dataclass
|
|
97
111
|
class PackableSample( ABC ):
|
|
98
112
|
"""A sample that can be packed and unpacked with msgpack"""
|
|
99
113
|
|
|
@@ -235,6 +249,7 @@ class Dataset( Generic[ST] ):
|
|
|
235
249
|
@property
def sample_type( self ) -> Type:
    """The type of each returned sample from this `Dataset`'s iterator

    The type is taken from the generic parameter supplied at construction,
    e.g. ``Dataset[MySample]( ... )``. `typing` records that parameter on
    the instance as ``__orig_class__`` only when the class is instantiated
    through a subscripted alias.

    Raises:
        AttributeError: If the dataset was not constructed through a
            subscripted alias such as ``Dataset[MySample]( ... )``.
    """
    # `__orig_class__` is attached dynamically by `typing`, so static
    # analyzers cannot see it (presumably the lint failure noted earlier);
    # read it defensively and explain what went wrong if it is missing
    orig_class = getattr( self, '__orig_class__', None )
    if orig_class is None:
        raise AttributeError(
            f'{type( self ).__name__} has no sample type; construct it as '
            '`Dataset[SampleType]( ... )` so the type parameter is recorded'
        )
    return orig_class.__args__[0]
|
|
239
254
|
@property
|
|
240
255
|
def batch_type( self ) -> Type:
|
|
@@ -286,7 +301,7 @@ class Dataset( Generic[ST] ):
|
|
|
286
301
|
|
|
287
302
|
def ordered( self,
|
|
288
303
|
batch_size: int | None = 1,
|
|
289
|
-
) ->
|
|
304
|
+
) -> Iterable[ST]:
|
|
290
305
|
"""Iterate over the dataset in order
|
|
291
306
|
|
|
292
307
|
Args:
|
|
@@ -325,7 +340,7 @@ class Dataset( Generic[ST] ):
|
|
|
325
340
|
buffer_shards: int = 100,
|
|
326
341
|
buffer_samples: int = 10_000,
|
|
327
342
|
batch_size: int | None = 1,
|
|
328
|
-
) ->
|
|
343
|
+
) -> Iterable[ST]:
|
|
329
344
|
"""Iterate over the dataset in random order
|
|
330
345
|
|
|
331
346
|
Args:
|
|
@@ -366,6 +381,64 @@ class Dataset( Generic[ST] ):
|
|
|
366
381
|
wds.batched( batch_size ),
|
|
367
382
|
wds.map( self.wrap_batch ),
|
|
368
383
|
)
|
|
384
|
+
|
|
385
|
+
# TODO Rewrite to eliminate `pandas` dependency directly calling
|
|
386
|
+
# `fastparquet`
|
|
387
|
+
def to_parquet( self, path: Pathlike,
            sample_map: Optional[SampleExportMap] = None,
            maxcount: Optional[int] = None,
            **kwargs,
        ):
    """Save dataset contents to `parquet` file(s) at `path`

    Samples are read in order (unbatched), converted to row dicts with
    `sample_map`, collected into a `pandas.DataFrame`, and written with
    `DataFrame.to_parquet`.

    Args:
        path: Output file. When `maxcount` is given it acts as a name
            template, and segments are written alongside it as
            ``<stem>-000000<suffix>``, ``<stem>-000001<suffix>``, ...
        sample_map: Converts one sample into a row dict. Defaults to
            `dataclasses.asdict`.
        maxcount: Maximum number of rows per output file. If `None`, the
            whole dataset is written to the single file `path`.
        **kwargs: Forwarded to `pandas.DataFrame.to_parquet`, except for
            `verbose` (default `False`), which is consumed here and only
            enables a `tqdm` progress bar.

    Raises:
        ValueError: If `maxcount` is given but is less than 1.
    """
    ##

    # Normalize args
    path = Path( path )
    if sample_map is None:
        sample_map = asdict

    # `verbose` belongs to this method, not to the parquet writer: pop it
    # so it is not forwarded through `**kwargs` to `DataFrame.to_parquet`
    verbose = kwargs.pop( 'verbose', False )

    if maxcount is not None and maxcount < 1:
        raise ValueError( f'`maxcount` must be at least 1, got {maxcount}' )

    it = self.ordered( batch_size = None )
    if verbose:
        it = tqdm( it )

    #

    if maxcount is None:
        # Load and save full dataset
        df = pd.DataFrame( [ sample_map( x ) for x in it ] )
        df.to_parquet( path, **kwargs )

    else:
        # Load and save dataset in segments of size `maxcount`

        def _segment_path( index: int ) -> Path:
            # Zero-padded segment index; `path.suffix` already includes the
            # leading '.', so no extra dot is inserted before it
            return path.parent / f'{path.stem}-{index:06d}{path.suffix}'

        def _write_segment( index: int, rows: list ) -> None:
            # Write one buffered segment of rows to its numbered file
            pd.DataFrame( rows ).to_parquet( _segment_path( index ), **kwargs )

        cur_segment = 0
        cur_buffer = []

        for x in it:
            cur_buffer.append( sample_map( x ) )

            if len( cur_buffer ) >= maxcount:
                # Write current segment
                _write_segment( cur_segment, cur_buffer )
                cur_segment += 1
                cur_buffer = []

        if cur_buffer:
            # Write one last segment with remainder
            _write_segment( cur_segment, cur_buffer )
|
|
441
|
+
|
|
369
442
|
|
|
370
443
|
# Implemented by specific subclasses
|
|
371
444
|
|
|
@@ -390,18 +463,18 @@ class Dataset( Generic[ST] ):
|
|
|
390
463
|
|
|
391
464
|
return self.sample_type.from_bytes( sample['msgpack'] )
|
|
392
465
|
|
|
393
|
-
try:
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
466
|
+
# try:
|
|
467
|
+
# assert type( sample ) == dict
|
|
468
|
+
# return cls.sample_class( **{
|
|
469
|
+
# k: v
|
|
470
|
+
# for k, v in sample.items() if k != '__key__'
|
|
471
|
+
# } )
|
|
399
472
|
|
|
400
|
-
except Exception as e:
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
473
|
+
# except Exception as e:
|
|
474
|
+
# # Sample constructor failed -- revert to default
|
|
475
|
+
# return AnySample(
|
|
476
|
+
# value = sample,
|
|
477
|
+
# )
|
|
405
478
|
|
|
406
479
|
def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
|
|
407
480
|
"""Wrap a `batch` of samples into the appropriate dataset-specific type
|
|
@@ -449,6 +522,9 @@ def packable( cls ):
|
|
|
449
522
|
|
|
450
523
|
##
|
|
451
524
|
|
|
525
|
+
class_name = cls.__name__
|
|
526
|
+
class_annotations = cls.__annotations__
|
|
527
|
+
|
|
452
528
|
# Add in dataclass niceness to original class
|
|
453
529
|
as_dataclass = dataclass( cls )
|
|
454
530
|
|
|
@@ -458,8 +534,9 @@ def packable( cls ):
|
|
|
458
534
|
def __post_init__( self ):
|
|
459
535
|
return PackableSample.__post_init__( self )
|
|
460
536
|
|
|
461
|
-
|
|
462
|
-
as_packable.
|
|
537
|
+
# TODO This doesn't properly carry over the original
|
|
538
|
+
as_packable.__name__ = class_name
|
|
539
|
+
as_packable.__annotations__ = class_annotations
|
|
463
540
|
|
|
464
541
|
##
|
|
465
542
|
|
atdata/lens.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""Lenses between typed datasets"""
|
|
2
|
+
|
|
3
|
+
##
|
|
4
|
+
# Imports
|
|
5
|
+
|
|
6
|
+
from .dataset import PackableSample
|
|
7
|
+
|
|
8
|
+
import functools
|
|
9
|
+
import inspect
|
|
10
|
+
|
|
11
|
+
from typing import (
|
|
12
|
+
TypeAlias,
|
|
13
|
+
Type,
|
|
14
|
+
TypeVar,
|
|
15
|
+
Tuple,
|
|
16
|
+
Dict,
|
|
17
|
+
Callable,
|
|
18
|
+
Optional,
|
|
19
|
+
Generic,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
##
|
|
24
|
+
# Typing helpers
|
|
25
|
+
|
|
26
|
+
DatasetType: TypeAlias = Type[PackableSample]
|
|
27
|
+
LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
|
|
28
|
+
|
|
29
|
+
S = TypeVar( 'S', bound = PackableSample )
|
|
30
|
+
V = TypeVar( 'V', bound = PackableSample )
|
|
31
|
+
type LensGetter[S, V] = Callable[[S], V]
|
|
32
|
+
type LensPutter[S, V] = Callable[[V, S], S]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
##
|
|
36
|
+
# Shortcut decorators
|
|
37
|
+
|
|
38
|
+
class Lens( Generic[S, V] ):
    """A getter/putter pair relating a source sample type `S` to a view `V`

    The getter maps ``S -> V``; the putter maps ``(V, S) -> S``. Calling the
    lens applies the getter. Constructing a lens registers it in the
    module-level lens registry, keyed by the getter's annotated
    ``(parameter type, return type)``.
    """

    def __init__( self, get: LensGetter[S, V],
                put: Optional[LensPutter[S, V]] = None
            ) -> None:
        """Wrap `get` (and optionally `put`) as a lens and register it

        Args:
            get: Function of exactly one parameter mapping a source sample
                to a view. Its parameter and return annotations form the
                registry key.
            put: Function ``(view, source) -> source``. If omitted, a
                trivial putter that returns `source` unchanged is used.

        Raises:
            TypeError: If `get` does not take exactly one parameter.
        """
        ##

        # Copy the getter's metadata (`__name__`, `__doc__`, ...) onto the lens
        functools.update_wrapper( self, get )

        # Store the getter
        self._getter = get

        # Determine and store the putter
        if put is None:
            # Trivial putter does not update the source
            def _trivial_put( v: V, s: S ) -> S:
                return s
            put = _trivial_put

        self._putter = put

        # Register this lens for this type signature

        sig = inspect.signature( get )
        params = list( sig.parameters.values() )
        # Raise rather than `assert` so the check survives `python -O`
        if len( params ) != 1:
            raise TypeError(
                'Wrong number of input args for lens: should only have one'
            )

        # Annotations are used exactly as `inspect.signature` reports them
        # (unresolved strings under postponed evaluation, or
        # `inspect.Parameter.empty` when missing) -- TODO confirm callers
        # always annotate getters with real classes
        input_type = params[0].annotation
        output_type = sig.return_annotation

        # A later lens with the same signature replaces an earlier one
        _registered_lenses[(input_type, output_type)] = self

    #

    def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
        """Replace this lens's putter with `put`

        Returns `put` unchanged, so this can be used as a decorator
        (``@my_lens.putter``).
        """
        ##
        self._putter = put
        return put

    def put( self, v: V, s: S ) -> S:
        """Apply the putter: fold view `v` back into source `s`"""
        return self._putter( v, s )

    def get( self, s: S ) -> V:
        """Apply the getter to source `s` (same as calling the lens)"""
        return self( s )

    #

    def __call__( self, s: S ) -> V:
        """Apply the getter to source `s`"""
        return self._getter( s )
|
|
95
|
+
|
|
96
|
+
def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
    """Decorator: turn the annotated function `f` into a registered `Lens`

    `f` becomes the lens getter, and its parameter and return annotations
    determine where the lens is registered. The lens's putter can be set
    afterwards through the returned object's `putter` method.
    """
    # Build through the subscripted generic so the instance records `[S, V]`
    parameterized_lens = Lens[S, V]
    return parameterized_lens( f )
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
##
|
|
102
|
+
# Global registration of used lenses
|
|
103
|
+
|
|
104
|
+
# Registry of every constructed `Lens`, keyed by the
# `(input type, output type)` annotations of its getter. Entries are added by
# `Lens.__init__`; a later lens with the same signature replaces an earlier one.
_registered_lenses: Dict[LensSignature, Lens] = {}
|
|
106
|
+
|
|
107
|
+
# def lens( f: LensPutter ) -> Lens:
|
|
108
|
+
# """Register the annotated function `f` as a sample lens"""
|
|
109
|
+
# ##
|
|
110
|
+
|
|
111
|
+
# sig = inspect.signature( f )
|
|
112
|
+
|
|
113
|
+
# input_types = list( sig.parameters.values() )
|
|
114
|
+
# output_type = sig.return_annotation
|
|
115
|
+
|
|
116
|
+
# _registered_lenses[]
|
|
117
|
+
|
|
118
|
+
# f.lens = Lens(
|
|
119
|
+
|
|
120
|
+
# )
|
|
121
|
+
|
|
122
|
+
# return f
|
|
@@ -1,13 +1,16 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: atdata
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.2b1
|
|
4
4
|
Summary: A loose federation of distributed, typed datasets
|
|
5
5
|
Author-email: Maxine Levesque <hello@maxine.science>
|
|
6
6
|
License-File: LICENSE
|
|
7
7
|
Requires-Python: >=3.12
|
|
8
|
+
Requires-Dist: fastparquet>=2024.11.0
|
|
8
9
|
Requires-Dist: msgpack>=1.1.2
|
|
9
10
|
Requires-Dist: numpy>=2.3.4
|
|
10
11
|
Requires-Dist: ormsgpack>=1.11.0
|
|
12
|
+
Requires-Dist: pandas>=2.3.3
|
|
13
|
+
Requires-Dist: tqdm>=4.67.1
|
|
11
14
|
Requires-Dist: webdataset>=1.0.2
|
|
12
15
|
Description-Content-Type: text/markdown
|
|
13
16
|
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
atdata/__init__.py,sha256=YnlohxQwTUK6V84XHm2gdeCQH5sIrTHVLSApB-nt_z8,216
|
|
2
|
+
atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
|
|
3
|
+
atdata/dataset.py,sha256=pBaND2D33JiJoiL9CtCTBa3octtifa21P19K076sW3Q,15905
|
|
4
|
+
atdata/lens.py,sha256=ikExMWdGP3QH-bEuUDNAYO_ZjeaKJTfL9lpaN9CrRB4,2624
|
|
5
|
+
atdata-0.1.2b1.dist-info/METADATA,sha256=WtHM3N0kMxJKwPXl5cluRNnvk49WU_e72lVhmcInraY,529
|
|
6
|
+
atdata-0.1.2b1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
atdata-0.1.2b1.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
|
|
8
|
+
atdata-0.1.2b1.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
|
|
9
|
+
atdata-0.1.2b1.dist-info/RECORD,,
|
atdata-0.1.2a4.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
atdata/__init__.py,sha256=jPZVd_6UIo0DSbCnXAnYZ2eMwHYzOk--5vtEDTZvwqw,173
|
|
2
|
-
atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
|
|
3
|
-
atdata/dataset.py,sha256=4cfxyET8470RGKvHvseH8KZBQvTjevovPr_JGVwj854,13518
|
|
4
|
-
atdata-0.1.2a4.dist-info/METADATA,sha256=igT4Js5SEl5IUhWX6AqDYdROQ022LCx_qfT3PMLifjI,434
|
|
5
|
-
atdata-0.1.2a4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
atdata-0.1.2a4.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
|
|
7
|
-
atdata-0.1.2a4.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
|
|
8
|
-
atdata-0.1.2a4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|