atdata 0.1.2a3__py3-none-any.whl → 0.1.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atdata/__init__.py CHANGED
@@ -10,5 +10,10 @@ from .dataset import (
10
10
  packable,
11
11
  )
12
12
 
13
+ from .lens import (
14
+ Lens,
15
+ lens,
16
+ )
17
+
13
18
 
14
19
  #
atdata/dataset.py CHANGED
@@ -5,21 +5,29 @@
5
5
 
6
6
  import webdataset as wds
7
7
 
8
- import functools
9
- from dataclasses import dataclass
8
+ from pathlib import Path
10
9
  import uuid
11
-
12
- import numpy as np
13
-
10
+ import functools
11
+ from dataclasses import (
12
+ dataclass,
13
+ asdict,
14
+ )
14
15
  from abc import (
15
16
  ABC,
16
17
  abstractmethod,
17
18
  )
19
+
20
+ from tqdm import tqdm
21
+ import numpy as np
22
+ import pandas as pd
23
+
18
24
  from typing import (
19
25
  Any,
20
26
  Optional,
21
27
  Dict,
22
28
  Sequence,
29
+ Iterable,
30
+ Callable,
23
31
  #
24
32
  Self,
25
33
  Generic,
@@ -45,9 +53,14 @@ from . import _helpers as eh
45
53
  ##
46
54
  # Typing help
47
55
 
56
+ Pathlike = str | Path
57
+
48
58
  WDSRawSample: TypeAlias = Dict[str, Any]
49
59
  WDSRawBatch: TypeAlias = Dict[str, Any]
50
60
 
61
+ SampleExportRow: TypeAlias = Dict[str, Any]
62
+ SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
63
+
51
64
 
52
65
  ##
53
66
  # Main base classes
@@ -94,6 +107,7 @@ def _make_packable( x ):
94
107
  return eh.array_to_bytes( x )
95
108
  return x
96
109
 
110
+ @dataclass
97
111
  class PackableSample( ABC ):
98
112
  """A sample that can be packed and unpacked with msgpack"""
99
113
 
@@ -235,6 +249,7 @@ class Dataset( Generic[ST] ):
235
249
  @property
236
250
  def sample_type( self ) -> Type:
237
251
  """The type of each returned sample from this `Dataset`'s iterator"""
252
+ # TODO Figure out why linting fails here
238
253
  return self.__orig_class__.__args__[0]
239
254
  @property
240
255
  def batch_type( self ) -> Type:
@@ -286,7 +301,7 @@ class Dataset( Generic[ST] ):
286
301
 
287
302
  def ordered( self,
288
303
  batch_size: int | None = 1,
289
- ) -> wds.DataPipeline:
304
+ ) -> Iterable[ST]:
290
305
  """Iterate over the dataset in order
291
306
 
292
307
  Args:
@@ -325,7 +340,7 @@ class Dataset( Generic[ST] ):
325
340
  buffer_shards: int = 100,
326
341
  buffer_samples: int = 10_000,
327
342
  batch_size: int | None = 1,
328
- ) -> wds.DataPipeline:
343
+ ) -> Iterable[ST]:
329
344
  """Iterate over the dataset in random order
330
345
 
331
346
  Args:
@@ -366,6 +381,64 @@ class Dataset( Generic[ST] ):
366
381
  wds.batched( batch_size ),
367
382
  wds.map( self.wrap_batch ),
368
383
  )
384
+
385
+ # TODO Rewrite to eliminate `pandas` dependency directly calling
386
+ # `fastparquet`
387
def to_parquet( self, path: Pathlike,
            sample_map: Optional[SampleExportMap] = None,
            maxcount: Optional[int] = None,
            **kwargs,
        ):
    """Save dataset contents to one or more `parquet` files

    Samples are read with `self.ordered( batch_size = None )`, converted to
    row dictionaries with `sample_map`, gathered into a `pandas.DataFrame`
    and written with `DataFrame.to_parquet`.

    Args:
        path: Destination file. When `maxcount` is given it acts as a
            template instead: segment `i` is written next to it as
            `<stem>-<i, zero-padded to 6 digits><suffix>`, e.g.
            `out.parquet` -> `out-000000.parquet`, `out-000001.parquet`.
        sample_map: Function turning one sample into a row dictionary.
            Defaults to `dataclasses.asdict`.
        maxcount: If `None`, all rows go to a single file at `path`.
            Otherwise rows are written in segments of `maxcount` rows,
            with a final shorter segment for any remainder.
        **kwargs: `verbose` (default `False`) wraps the sample iterator in
            a `tqdm` progress bar and is consumed here. Every other
            keyword argument is forwarded to `DataFrame.to_parquet`.
    """
    ##

    # Normalize args
    path = Path( path )
    if sample_map is None:
        sample_map = asdict

    # `pop`, not `get`: `verbose` is our own option and must not leak into
    # the keyword arguments forwarded to `DataFrame.to_parquet`
    verbose = kwargs.pop( 'verbose', False )

    it = self.ordered( batch_size = None )
    if verbose:
        it = tqdm( it )

    def _write( rows, dest ):
        # Single place where rows are materialized and written out
        pd.DataFrame( rows ).to_parquet( dest, **kwargs )

    #

    if maxcount is None:
        # Load and save full dataset
        _write( [ sample_map( x ) for x in it ], path )
        return

    # Load and save dataset in segments of size `maxcount`

    def _segment_path( index ):
        # `path.suffix` already carries its leading dot, so no extra '.'
        return (path.parent / f'{path.stem}-{index:06d}{path.suffix}').as_posix()

    cur_segment = 0
    cur_buffer = []

    for x in it:
        cur_buffer.append( sample_map( x ) )

        if len( cur_buffer ) >= maxcount:
            # Write current segment
            _write( cur_buffer, _segment_path( cur_segment ) )

            cur_segment += 1
            cur_buffer = []

    if cur_buffer:
        # Write one last segment with remainder
        _write( cur_buffer, _segment_path( cur_segment ) )
441
+
369
442
 
370
443
  # Implemented by specific subclasses
371
444
 
@@ -390,18 +463,18 @@ class Dataset( Generic[ST] ):
390
463
 
391
464
  return self.sample_type.from_bytes( sample['msgpack'] )
392
465
 
393
- try:
394
- assert type( sample ) == dict
395
- return cls.sample_class( **{
396
- k: v
397
- for k, v in sample.items() if k != '__key__'
398
- } )
466
+ # try:
467
+ # assert type( sample ) == dict
468
+ # return cls.sample_class( **{
469
+ # k: v
470
+ # for k, v in sample.items() if k != '__key__'
471
+ # } )
399
472
 
400
- except Exception as e:
401
- # Sample constructor failed -- revert to default
402
- return AnySample(
403
- value = sample,
404
- )
473
+ # except Exception as e:
474
+ # # Sample constructor failed -- revert to default
475
+ # return AnySample(
476
+ # value = sample,
477
+ # )
405
478
 
406
479
  def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
407
480
  """Wrap a `batch` of samples into the appropriate dataset-specific type
@@ -449,15 +522,22 @@ def packable( cls ):
449
522
 
450
523
  ##
451
524
 
525
+ class_name = cls.__name__
526
+ class_annotations = cls.__annotations__
527
+
528
+ # Add in dataclass niceness to original class
452
529
  as_dataclass = dataclass( cls )
453
530
 
454
- class as_packable( PackableSample, as_dataclass ):
531
+ # This triggers a bunch of behind-the-scenes stuff for the newly annotated class
532
+ @dataclass
533
+ class as_packable( as_dataclass, PackableSample ):
455
534
  def __post_init__( self ):
456
535
  return PackableSample.__post_init__( self )
457
536
 
458
- as_packable.__name__ = cls.__name__
459
- as_packable.__annotations__ = cls.__annotations__
537
+ # TODO This doesn't properly carry over the original
538
+ as_packable.__name__ = class_name
539
+ as_packable.__annotations__ = class_annotations
460
540
 
461
541
  ##
462
-
542
+
463
543
  return as_packable
atdata/lens.py ADDED
@@ -0,0 +1,122 @@
1
+ """Lenses between typed datasets"""
2
+
3
+ ##
4
+ # Imports
5
+
6
+ from .dataset import PackableSample
7
+
8
+ import functools
9
+ import inspect
10
+
11
+ from typing import (
12
+ TypeAlias,
13
+ Type,
14
+ TypeVar,
15
+ Tuple,
16
+ Dict,
17
+ Callable,
18
+ Optional,
19
+ Generic,
20
+ )
21
+
22
+
23
+ ##
24
+ # Typing helpers
25
+
26
+ DatasetType: TypeAlias = Type[PackableSample]
27
+ LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
28
+
29
+ S = TypeVar( 'S', bound = PackableSample )
30
+ V = TypeVar( 'V', bound = PackableSample )
31
+ type LensGetter[S, V] = Callable[[S], V]
32
+ type LensPutter[S, V] = Callable[[V, S], S]
33
+
34
+
35
+ ##
36
+ # Shortcut decorators
37
+
38
class Lens( Generic[S, V] ):
    """A getter/putter pair relating a source sample `S` to a view `V`

    Calling the lens (or its `get` method) applies the getter to a source
    sample. `put` applies the putter to a view together with the source it
    came from and returns a source. Constructing a lens also records it in
    the module-level `_registered_lenses` table, keyed by the getter's
    annotated `(input type, return type)` pair.
    """

    def __init__( self, get: LensGetter[S, V],
                put: Optional[LensPutter[S, V]] = None
            ) -> None:
        """Build a lens from a getter and an optional putter

        Args:
            get: One-parameter function mapping a source sample to a view.
                Its parameter annotation and return annotation form the
                key under which the lens is registered.
            put: Function `(view, source) -> source`. When omitted, a
                trivial putter is used that returns the source unchanged.

        Raises:
            AssertionError: If `get` does not take exactly one parameter.
        """
        ##

        # Make the lens look like the wrapped getter (name, docstring, ...)
        functools.update_wrapper( self, get )

        # Store the getter
        self._getter = get

        # Determine and store the putter
        if put is None:
            # Trivial putter does not update the source
            def _trivial_put( v: V, s: S ) -> S:
                return s
            put = _trivial_put

        self._putter = put

        # Register this lens for this type signature

        sig = inspect.signature( get )
        input_types = list( sig.parameters.values() )
        if len( input_types ) != 1:
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; exception type and message are
            # kept as they were for callers that rely on them.
            raise AssertionError(
                'Wrong number of input args for lens: should only have one'
            )

        # NOTE(review): these are the raw annotations as written on `get`;
        # an unannotated getter is keyed by `inspect` "empty" markers, so
        # such lenses would overwrite one another -- confirm this is intended.
        input_type = input_types[0].annotation
        output_type = sig.return_annotation

        _registered_lenses[(input_type, output_type)] = self

    #

    def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
        """Install `put` as this lens's putter and return it unchanged

        Returning the function makes this usable as a decorator.
        """
        ##
        self._putter = put
        return put

    def put( self, v: V, s: S ) -> S:
        """Apply the putter to view `v` and source `s`, returning a source"""
        return self._putter( v, s )

    def get( self, s: S ) -> V:
        """Apply the getter to source `s`; same as calling the lens"""
        return self( s )

    #

    def __call__( self, s: S ) -> V:
        """Apply the getter to source `s` and return the resulting view"""
        return self._getter( s )
95
+
96
def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
    """Decorator turning the annotated function `f` into a `Lens`

    `f` becomes the getter of the returned lens; no putter is supplied.
    """
    lens_type = Lens[S, V]
    return lens_type( f )
99
+
100
+
101
+ ##
102
+ # Global registration of used lenses
103
+
104
+ _registered_lenses: Dict[LensSignature, Lens] = dict()
105
+ """TODO"""
106
+
107
+ # def lens( f: LensPutter ) -> Lens:
108
+ # """Register the annotated function `f` as a sample lens"""
109
+ # ##
110
+
111
+ # sig = inspect.signature( f )
112
+
113
+ # input_types = list( sig.parameters.values() )
114
+ # output_type = sig.return_annotation
115
+
116
+ # _registered_lenses[]
117
+
118
+ # f.lens = Lens(
119
+
120
+ # )
121
+
122
+ # return f
@@ -1,13 +1,16 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: atdata
3
- Version: 0.1.2a3
3
+ Version: 0.1.2b1
4
4
  Summary: A loose federation of distributed, typed datasets
5
5
  Author-email: Maxine Levesque <hello@maxine.science>
6
6
  License-File: LICENSE
7
7
  Requires-Python: >=3.12
8
+ Requires-Dist: fastparquet>=2024.11.0
8
9
  Requires-Dist: msgpack>=1.1.2
9
10
  Requires-Dist: numpy>=2.3.4
10
11
  Requires-Dist: ormsgpack>=1.11.0
12
+ Requires-Dist: pandas>=2.3.3
13
+ Requires-Dist: tqdm>=4.67.1
11
14
  Requires-Dist: webdataset>=1.0.2
12
15
  Description-Content-Type: text/markdown
13
16
 
@@ -0,0 +1,9 @@
1
+ atdata/__init__.py,sha256=YnlohxQwTUK6V84XHm2gdeCQH5sIrTHVLSApB-nt_z8,216
2
+ atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
3
+ atdata/dataset.py,sha256=pBaND2D33JiJoiL9CtCTBa3octtifa21P19K076sW3Q,15905
4
+ atdata/lens.py,sha256=ikExMWdGP3QH-bEuUDNAYO_ZjeaKJTfL9lpaN9CrRB4,2624
5
+ atdata-0.1.2b1.dist-info/METADATA,sha256=WtHM3N0kMxJKwPXl5cluRNnvk49WU_e72lVhmcInraY,529
6
+ atdata-0.1.2b1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ atdata-0.1.2b1.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
8
+ atdata-0.1.2b1.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
9
+ atdata-0.1.2b1.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- atdata/__init__.py,sha256=jPZVd_6UIo0DSbCnXAnYZ2eMwHYzOk--5vtEDTZvwqw,173
2
- atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
3
- atdata/dataset.py,sha256=HXctGwIbU5kr2pqiQCYDyGP1mkph1gIt-x1_PRtWyew,13372
4
- atdata-0.1.2a3.dist-info/METADATA,sha256=Jj5vP4NW-HtckIsPRzzpXVQXgcQ8HaFSGehdAu4Vfbo,434
5
- atdata-0.1.2a3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- atdata-0.1.2a3.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
7
- atdata-0.1.2a3.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
8
- atdata-0.1.2a3.dist-info/RECORD,,