atdata 0.1.2a4__py3-none-any.whl → 0.1.3a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atdata/__init__.py CHANGED
@@ -10,5 +10,11 @@ from .dataset import (
10
10
  packable,
11
11
  )
12
12
 
13
+ from .lens import (
14
+ Lens,
15
+ LensNetwork,
16
+ lens,
17
+ )
18
+
13
19
 
14
20
  #
atdata/dataset.py CHANGED
@@ -5,21 +5,29 @@
5
5
 
6
6
  import webdataset as wds
7
7
 
8
- import functools
9
- from dataclasses import dataclass
8
+ from pathlib import Path
10
9
  import uuid
11
-
12
- import numpy as np
13
-
10
+ import functools
11
+ from dataclasses import (
12
+ dataclass,
13
+ asdict,
14
+ )
14
15
  from abc import (
15
16
  ABC,
16
17
  abstractmethod,
17
18
  )
19
+
20
+ from tqdm import tqdm
21
+ import numpy as np
22
+ import pandas as pd
23
+
18
24
  from typing import (
19
25
  Any,
20
26
  Optional,
21
27
  Dict,
22
28
  Sequence,
29
+ Iterable,
30
+ Callable,
23
31
  #
24
32
  Self,
25
33
  Generic,
@@ -40,14 +48,20 @@ from numpy.typing import (
40
48
  import msgpack
41
49
  import ormsgpack
42
50
  from . import _helpers as eh
51
+ from .lens import Lens, LensNetwork
43
52
 
44
53
 
45
54
  ##
46
55
  # Typing help
47
56
 
57
+ Pathlike = str | Path
58
+
48
59
  WDSRawSample: TypeAlias = Dict[str, Any]
49
60
  WDSRawBatch: TypeAlias = Dict[str, Any]
50
61
 
62
+ SampleExportRow: TypeAlias = Dict[str, Any]
63
+ SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
64
+
51
65
 
52
66
  ##
53
67
  # Main base classes
@@ -94,6 +108,7 @@ def _make_packable( x ):
94
108
  return eh.array_to_bytes( x )
95
109
  return x
96
110
 
111
+ @dataclass
97
112
  class PackableSample( ABC ):
98
113
  """A sample that can be packed and unpacked with msgpack"""
99
114
 
@@ -217,6 +232,8 @@ class SampleBatch( Generic[DT] ):
217
232
  ST = TypeVar( 'ST', bound = PackableSample )
218
233
  # BT = TypeVar( 'BT' )
219
234
 
235
+ RT = TypeVar( 'RT', bound = PackableSample )
236
+
220
237
  # TODO For python 3.13
221
238
  # BT = TypeVar( 'BT', default = None )
222
239
  # IT = TypeVar( 'IT', default = Any )
@@ -235,6 +252,7 @@ class Dataset( Generic[ST] ):
235
252
  @property
236
253
  def sample_type( self ) -> Type:
237
254
  """The type of each returned sample from this `Dataset`'s iterator"""
255
+ # TODO Figure out why linting fails here
238
256
  return self.__orig_class__.__args__[0]
239
257
  @property
240
258
  def batch_type( self ) -> Type:
@@ -253,6 +271,17 @@ class Dataset( Generic[ST] ):
253
271
  super().__init__()
254
272
  self.url = url
255
273
 
274
+ # Allow addition of automatic transformation of raw underlying data
275
+ self._output_lens: Lens | None = None
276
+
277
def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
    """View this dataset's samples as another sample type

    A new `Dataset` parameterized on `other` is created over the same
    `url`, and the lens registered from this dataset's sample type to
    `other` is attached as its output lens.

    Args:
        other: The sample type the returned dataset should produce.

    Returns:
        A `Dataset[other]` pointing at the same `url` as this one.

    Raises:
        ValueError: Propagated from `LensNetwork.transform` when no lens
            from this dataset's sample type to `other` is registered.
    """
    # `LensNetwork()` always hands back the shared singleton registry
    registry = LensNetwork()

    converted = Dataset[other]( self.url )
    converted._output_lens = registry.transform(
        self.sample_type,
        converted.sample_type,
    )
    return converted
284
+
256
285
  # @classmethod
257
286
  # def register( cls, uri: str,
258
287
  # sample_class: Type,
@@ -278,15 +307,15 @@ class Dataset( Generic[ST] ):
278
307
  A full (non-lazy) list of the individual ``tar`` files within the
279
308
  source WebDataset.
280
309
  """
281
- pipe = wds.DataPipeline(
282
- wds.SimpleShardList( self.url ),
283
- wds.map( lambda x: x['url'] )
310
+ pipe = wds.pipeline.DataPipeline(
311
+ wds.shardlists.SimpleShardList( self.url ),
312
+ wds.filters.map( lambda x: x['url'] )
284
313
  )
285
314
  return list( pipe )
286
315
 
287
316
  def ordered( self,
288
317
  batch_size: int | None = 1,
289
- ) -> wds.DataPipeline:
318
+ ) -> Iterable[ST]:
290
319
  """Iterate over the dataset in order
291
320
 
292
321
  Args:
@@ -302,30 +331,30 @@ class Dataset( Generic[ST] ):
302
331
 
303
332
  if batch_size is None:
304
333
  # TODO Duplication here
305
- return wds.DataPipeline(
306
- wds.SimpleShardList( self.url ),
307
- wds.split_by_worker,
334
+ return wds.pipeline.DataPipeline(
335
+ wds.shardlists.SimpleShardList( self.url ),
336
+ wds.shardlists.split_by_worker,
308
337
  #
309
- wds.tarfile_to_samples(),
338
+ wds.tariterators.tarfile_to_samples(),
310
339
  # wds.map( self.preprocess ),
311
- wds.map( self.wrap ),
340
+ wds.filters.map( self.wrap ),
312
341
  )
313
342
 
314
- return wds.DataPipeline(
315
- wds.SimpleShardList( self.url ),
316
- wds.split_by_worker,
343
+ return wds.pipeline.DataPipeline(
344
+ wds.shardlists.SimpleShardList( self.url ),
345
+ wds.shardlists.split_by_worker,
317
346
  #
318
- wds.tarfile_to_samples(),
347
+ wds.tariterators.tarfile_to_samples(),
319
348
  # wds.map( self.preprocess ),
320
- wds.batched( batch_size ),
321
- wds.map( self.wrap_batch ),
349
+ wds.filters.batched( batch_size ),
350
+ wds.filters.map( self.wrap_batch ),
322
351
  )
323
352
 
324
353
  def shuffled( self,
325
354
  buffer_shards: int = 100,
326
355
  buffer_samples: int = 10_000,
327
356
  batch_size: int | None = 1,
328
- ) -> wds.DataPipeline:
357
+ ) -> Iterable[ST]:
329
358
  """Iterate over the dataset in random order
330
359
 
331
360
  Args:
@@ -366,6 +395,64 @@ class Dataset( Generic[ST] ):
366
395
  wds.batched( batch_size ),
367
396
  wds.map( self.wrap_batch ),
368
397
  )
398
+
399
+ # TODO Rewrite to eliminate `pandas` dependency directly calling
400
+ # `fastparquet`
401
def to_parquet( self, path: Pathlike,
            sample_map: Optional[SampleExportMap] = None,
            maxcount: Optional[int] = None,
            **kwargs,
        ):
    """Save dataset contents to one or more `parquet` files

    Samples are read in order (unbatched), converted to row dictionaries
    with `sample_map`, collected into a `pandas.DataFrame` and written out.

    Args:
        path: Destination file. When `maxcount` is given, `path` is only
            used as a template and the segments are written next to it as
            ``<stem>-<index:06d><suffix>`` (e.g. ``out-000000.parquet``).
        sample_map: Function turning one sample into a row dictionary.
            Defaults to `dataclasses.asdict`.
        maxcount: Maximum number of rows per output file. ``None`` writes
            the whole dataset to a single file at `path`.
        **kwargs: Forwarded to `pandas.DataFrame.to_parquet`. The one
            exception is `verbose`, which is consumed here: when truthy,
            iteration is wrapped in a `tqdm` progress bar.
    """
    ##

    # Normalize args
    path = Path( path )
    if sample_map is None:
        sample_map = asdict

    # `verbose` is this method's own flag -- pop it so that it is not
    # forwarded to `DataFrame.to_parquet`
    verbose = kwargs.pop( 'verbose', False )

    samples = self.ordered( batch_size = None )
    if verbose:
        samples = tqdm( samples )

    #

    if maxcount is None:
        # Load and save full dataset
        df = pd.DataFrame( [ sample_map( x ) for x in samples ] )
        df.to_parquet( path, **kwargs )
        return

    # Load and save dataset in segments of size `maxcount`

    def _write_segment( rows, index ):
        # `path.suffix` already carries its leading dot
        segment_path = path.with_name( f'{path.stem}-{index:06d}{path.suffix}' )
        pd.DataFrame( rows ).to_parquet( segment_path, **kwargs )

    cur_segment = 0
    cur_buffer = []

    for x in samples:
        cur_buffer.append( sample_map( x ) )

        if len( cur_buffer ) >= maxcount:
            # Write current segment
            _write_segment( cur_buffer, cur_segment )
            cur_segment += 1
            cur_buffer = []

    if cur_buffer:
        # Write one last segment with remainder
        _write_segment( cur_buffer, cur_segment )
455
+
369
456
 
370
457
  # Implemented by specific subclasses
371
458
 
@@ -388,20 +475,24 @@ class Dataset( Generic[ST] ):
388
475
  assert 'msgpack' in sample
389
476
  assert type( sample['msgpack'] ) == bytes
390
477
 
391
- return self.sample_type.from_bytes( sample['msgpack'] )
478
+ if self._output_lens is None:
479
+ return self.sample_type.from_bytes( sample['msgpack'] )
480
+
481
+ source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
482
+ return self._output_lens( source_sample )
392
483
 
393
- try:
394
- assert type( sample ) == dict
395
- return cls.sample_class( **{
396
- k: v
397
- for k, v in sample.items() if k != '__key__'
398
- } )
484
+ # try:
485
+ # assert type( sample ) == dict
486
+ # return cls.sample_class( **{
487
+ # k: v
488
+ # for k, v in sample.items() if k != '__key__'
489
+ # } )
399
490
 
400
- except Exception as e:
401
- # Sample constructor failed -- revert to default
402
- return AnySample(
403
- value = sample,
404
- )
491
+ # except Exception as e:
492
+ # # Sample constructor failed -- revert to default
493
+ # return AnySample(
494
+ # value = sample,
495
+ # )
405
496
 
406
497
  def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
407
498
  """Wrap a `batch` of samples into the appropriate dataset-specific type
@@ -449,6 +540,9 @@ def packable( cls ):
449
540
 
450
541
  ##
451
542
 
543
+ class_name = cls.__name__
544
+ class_annotations = cls.__annotations__
545
+
452
546
  # Add in dataclass niceness to original class
453
547
  as_dataclass = dataclass( cls )
454
548
 
@@ -458,8 +552,9 @@ def packable( cls ):
458
552
  def __post_init__( self ):
459
553
  return PackableSample.__post_init__( self )
460
554
 
461
- as_packable.__name__ = cls.__name__
462
- as_packable.__annotations__ = cls.__annotations__
555
+ # TODO This doesn't properly carry over the original
556
+ as_packable.__name__ = class_name
557
+ as_packable.__annotations__ = class_annotations
463
558
 
464
559
  ##
465
560
 
atdata/lens.py ADDED
@@ -0,0 +1,200 @@
1
+ """Lenses between typed datasets"""
2
+
3
+ ##
4
+ # Imports
5
+
6
+ import functools
7
+ import inspect
8
+
9
+ from typing import (
10
+ TypeAlias,
11
+ Type,
12
+ TypeVar,
13
+ Tuple,
14
+ Dict,
15
+ Callable,
16
+ Optional,
17
+ Generic,
18
+ #
19
+ TYPE_CHECKING
20
+ )
21
+
22
+ if TYPE_CHECKING:
23
+ from .dataset import PackableSample
24
+
25
+
26
+ ##
27
+ # Typing helpers
28
+
29
+ DatasetType: TypeAlias = Type['PackableSample']
30
+ LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
31
+
32
+ S = TypeVar( 'S', bound = 'PackableSample' )
33
+ V = TypeVar( 'V', bound = 'PackableSample' )
34
+ type LensGetter[S, V] = Callable[[S], V]
35
+ type LensPutter[S, V] = Callable[[V, S], S]
36
+
37
+
38
+ ##
39
+ # Shortcut decorators
40
+
41
class Lens( Generic[S, V] ):
    """A getter / putter pair relating a source sample type to a view type

    The getter maps a source sample (`S`) to its view (`V`); the putter
    takes a view and a source and returns a source. Calling the lens
    object directly is the same as calling its getter.

    Attributes set by `__init__`:
        source_type: Annotation of the getter's single parameter.
        view_type: Return annotation of the getter.
    """

    # NOTE An earlier draft exposed `source_type` / `view_type` as properties
    # reading `self.__orig_class__.__args__`; they are currently taken from
    # the getter's annotations instead.

    def __init__( self, get: LensGetter[S, V],
                put: Optional[LensPutter[S, V]] = None
            ) -> None:
        """Build a lens from its getter and, optionally, its putter

        Args:
            get: Function of exactly one parameter producing the view.
            put: Function ``(view, source) -> source``. When omitted, a
                putter that returns the source unchanged is used.
        """
        ##

        # The getter must take exactly one argument: the source sample
        getter_sig = inspect.signature( get )
        params = tuple( getter_sig.parameters.values() )
        assert len( params ) == 1, \
            'Wrong number of input args for lens: should only have one'

        # Make this object carry the getter's name, docstring, etc.
        functools.update_wrapper( self, get )

        (source_param,) = params
        self.source_type: Type[PackableSample] = source_param.annotation
        self.view_type = getter_sig.return_annotation

        self._getter = get

        if put is not None:
            self._putter = put
        else:
            # Fall back to a putter that leaves the source untouched
            def _trivial_put( v: V, s: S ) -> S:
                return s
            self._putter = _trivial_put

    #

    def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
        """Install `put` as this lens' putter

        `put` is handed back unchanged, which allows this method to be
        applied as a decorator.
        """
        ##
        self._putter = put
        return put

    # Running the transformations

    def get( self, s: S ) -> V:
        """Produce the view of source sample `s`"""
        return self( s )

    def put( self, v: V, s: S ) -> S:
        """Apply the putter to view `v` and source `s`, returning a source"""
        return self._putter( v, s )

    def __call__( self, s: S ) -> V:
        """Shorthand for the getter, so the lens can be used as a function"""
        return self._getter( s )
108
+
109
+ # TODO Figure out how to properly parameterize this
110
+ # def _lens_factory[S, V]( register: bool = True ):
111
+ # """Register the annotated function `f` as the getter of a sample lens"""
112
+
113
+ # # The actual lens decorator taking a lens getter function to a lens object
114
+ # def _decorator( f: LensGetter[S, V] ) -> Lens[S, V]:
115
+ # ret = Lens[S, V]( f )
116
+ # if register:
117
+ # _network.register( ret )
118
+ # return ret
119
+
120
+ # # Return the lens decorator
121
+ # return _decorator
122
+
123
+ # # For convenience
124
+ # lens = _lens_factory
125
+
126
def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
    """Wrap the getter `f` in a `Lens` and register it with the global network

    The source and view types are read by `Lens` from the annotations of
    `f`. The resulting lens is registered in the module-level `_network`.

    Returns:
        The newly created `Lens`, which takes the place of `f` when this
        function is applied as a decorator.
    """
    registered = Lens[S, V]( f )
    _network.register( registered )
    return registered
130
+
131
+
132
+ ##
133
+ # Global registry of used lenses
134
+
135
+ # _registered_lenses: Dict[LensSignature, Lens] = dict()
136
+ # """TODO"""
137
+
138
class LensNetwork:
    """Singleton registry of lenses keyed by ``(source type, view type)``

    Every call to ``LensNetwork()`` returns the same object, so a lens
    registered through one reference can be looked up through any other.
    """

    # The singleton instance; created on first construction by `__new__`
    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the existing instance, creating it on first use"""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # `__init__` runs on every `LensNetwork()` call; only set up the
        # registry the first time so that registered lenses are kept
        if not hasattr(self, '_initialized'):
            self._registry: Dict[LensSignature, Lens] = dict()
            self._initialized = True

    def register( self, _lens: 'Lens' ):
        """Set `_lens` as the canonical view between its source and view types

        A lens previously registered for the same ``(source_type,
        view_type)`` pair is replaced.
        """
        self._registry[_lens.source_type, _lens.view_type] = _lens

    def transform( self, source: 'DatasetType', view: 'DatasetType' ) -> 'Lens':
        """Look up the lens registered from `source` to `view`

        Args:
            source: The source sample type.
            view: The view sample type.

        Returns:
            The lens registered for exactly this pair of types.

        Raises:
            ValueError: If no lens is registered for ``(source, view)``.
        """
        # TODO Handle compositional closure
        ret = self._registry.get( (source, view), None )
        if ret is None:
            raise ValueError( f'No registered lens from source {source} to view {view}' )

        return ret
180
+
181
+
182
+ # Create global singleton registry instance
183
+ _network = LensNetwork()
184
+
185
+ # def lens( f: LensPutter ) -> Lens:
186
+ # """Register the annotated function `f` as a sample lens"""
187
+ # ##
188
+
189
+ # sig = inspect.signature( f )
190
+
191
+ # input_types = list( sig.parameters.values() )
192
+ # output_type = sig.return_annotation
193
+
194
+ # _registered_lenses[]
195
+
196
+ # f.lens = Lens(
197
+
198
+ # )
199
+
200
+ # return f
@@ -1,13 +1,16 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: atdata
3
- Version: 0.1.2a4
3
+ Version: 0.1.3a2
4
4
  Summary: A loose federation of distributed, typed datasets
5
5
  Author-email: Maxine Levesque <hello@maxine.science>
6
6
  License-File: LICENSE
7
7
  Requires-Python: >=3.12
8
+ Requires-Dist: fastparquet>=2024.11.0
8
9
  Requires-Dist: msgpack>=1.1.2
9
10
  Requires-Dist: numpy>=2.3.4
10
11
  Requires-Dist: ormsgpack>=1.11.0
12
+ Requires-Dist: pandas>=2.3.3
13
+ Requires-Dist: tqdm>=4.67.1
11
14
  Requires-Dist: webdataset>=1.0.2
12
15
  Description-Content-Type: text/markdown
13
16
 
@@ -0,0 +1,9 @@
1
+ atdata/__init__.py,sha256=V2qBg7i2mfCNG9nww6Gi_fDp7iwolDMrNzhmNO6VA7M,233
2
+ atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
3
+ atdata/dataset.py,sha256=brNKGMkA_au2nLF5oUmjwub1E08DVwBKl9PnzPV6rPM,16722
4
+ atdata/lens.py,sha256=bGlxQ6PEnLj5poQ41DHj1LfpsmI5fELnjnUf4qXOsCo,5304
5
+ atdata-0.1.3a2.dist-info/METADATA,sha256=pYLvssfCzaiqer4yZQM3vh7nRlP8ZcEutXA8po7GZC0,529
6
+ atdata-0.1.3a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ atdata-0.1.3a2.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
8
+ atdata-0.1.3a2.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
9
+ atdata-0.1.3a2.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- atdata/__init__.py,sha256=jPZVd_6UIo0DSbCnXAnYZ2eMwHYzOk--5vtEDTZvwqw,173
2
- atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
3
- atdata/dataset.py,sha256=4cfxyET8470RGKvHvseH8KZBQvTjevovPr_JGVwj854,13518
4
- atdata-0.1.2a4.dist-info/METADATA,sha256=igT4Js5SEl5IUhWX6AqDYdROQ022LCx_qfT3PMLifjI,434
5
- atdata-0.1.2a4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- atdata-0.1.2a4.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
7
- atdata-0.1.2a4.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
8
- atdata-0.1.2a4.dist-info/RECORD,,