atdata 0.1.2b1__tar.gz → 0.1.3a2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: atdata
3
- Version: 0.1.2b1
3
+ Version: 0.1.3a2
4
4
  Summary: A loose federation of distributed, typed datasets
5
5
  Author-email: Maxine Levesque <hello@maxine.science>
6
6
  License-File: LICENSE
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "atdata"
3
- version = "0.1.2b1"
3
+ version = "0.1.3a2"
4
4
  description = "A loose federation of distributed, typed datasets"
5
5
  readme = "README.md"
6
6
  authors = [
@@ -12,6 +12,7 @@ from .dataset import (
12
12
 
13
13
  from .lens import (
14
14
  Lens,
15
+ LensNetwork,
15
16
  lens,
16
17
  )
17
18
 
@@ -48,6 +48,7 @@ from numpy.typing import (
48
48
  import msgpack
49
49
  import ormsgpack
50
50
  from . import _helpers as eh
51
+ from .lens import Lens, LensNetwork
51
52
 
52
53
 
53
54
  ##
@@ -231,6 +232,8 @@ class SampleBatch( Generic[DT] ):
231
232
  ST = TypeVar( 'ST', bound = PackableSample )
232
233
  # BT = TypeVar( 'BT' )
233
234
 
235
+ RT = TypeVar( 'RT', bound = PackableSample )
236
+
234
237
  # TODO For python 3.13
235
238
  # BT = TypeVar( 'BT', default = None )
236
239
  # IT = TypeVar( 'IT', default = Any )
@@ -268,6 +271,17 @@ class Dataset( Generic[ST] ):
268
271
  super().__init__()
269
272
  self.url = url
270
273
 
274
+ # Allow addition of automatic transformation of raw underlying data
275
+ self._output_lens: Lens | None = None
276
+
277
+ def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
278
+ """TODO"""
279
+ ret = Dataset[other]( self.url )
280
+ # Get the singleton lens registry
281
+ lenses = LensNetwork()
282
+ ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
283
+ return ret
284
+
271
285
  # @classmethod
272
286
  # def register( cls, uri: str,
273
287
  # sample_class: Type,
@@ -293,9 +307,9 @@ class Dataset( Generic[ST] ):
293
307
  A full (non-lazy) list of the individual ``tar`` files within the
294
308
  source WebDataset.
295
309
  """
296
- pipe = wds.DataPipeline(
297
- wds.SimpleShardList( self.url ),
298
- wds.map( lambda x: x['url'] )
310
+ pipe = wds.pipeline.DataPipeline(
311
+ wds.shardlists.SimpleShardList( self.url ),
312
+ wds.filters.map( lambda x: x['url'] )
299
313
  )
300
314
  return list( pipe )
301
315
 
@@ -317,23 +331,23 @@ class Dataset( Generic[ST] ):
317
331
 
318
332
  if batch_size is None:
319
333
  # TODO Duplication here
320
- return wds.DataPipeline(
321
- wds.SimpleShardList( self.url ),
322
- wds.split_by_worker,
334
+ return wds.pipeline.DataPipeline(
335
+ wds.shardlists.SimpleShardList( self.url ),
336
+ wds.shardlists.split_by_worker,
323
337
  #
324
- wds.tarfile_to_samples(),
338
+ wds.tariterators.tarfile_to_samples(),
325
339
  # wds.map( self.preprocess ),
326
- wds.map( self.wrap ),
340
+ wds.filters.map( self.wrap ),
327
341
  )
328
342
 
329
- return wds.DataPipeline(
330
- wds.SimpleShardList( self.url ),
331
- wds.split_by_worker,
343
+ return wds.pipeline.DataPipeline(
344
+ wds.shardlists.SimpleShardList( self.url ),
345
+ wds.shardlists.split_by_worker,
332
346
  #
333
- wds.tarfile_to_samples(),
347
+ wds.tariterators.tarfile_to_samples(),
334
348
  # wds.map( self.preprocess ),
335
- wds.batched( batch_size ),
336
- wds.map( self.wrap_batch ),
349
+ wds.filters.batched( batch_size ),
350
+ wds.filters.map( self.wrap_batch ),
337
351
  )
338
352
 
339
353
  def shuffled( self,
@@ -461,7 +475,11 @@ class Dataset( Generic[ST] ):
461
475
  assert 'msgpack' in sample
462
476
  assert type( sample['msgpack'] ) == bytes
463
477
 
464
- return self.sample_type.from_bytes( sample['msgpack'] )
478
+ if self._output_lens is None:
479
+ return self.sample_type.from_bytes( sample['msgpack'] )
480
+
481
+ source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
482
+ return self._output_lens( source_sample )
465
483
 
466
484
  # try:
467
485
  # assert type( sample ) == dict
@@ -0,0 +1,200 @@
1
+ """Lenses between typed datasets"""
2
+
3
+ ##
4
+ # Imports
5
+
6
+ import functools
7
+ import inspect
8
+
9
+ from typing import (
10
+ TypeAlias,
11
+ Type,
12
+ TypeVar,
13
+ Tuple,
14
+ Dict,
15
+ Callable,
16
+ Optional,
17
+ Generic,
18
+ #
19
+ TYPE_CHECKING
20
+ )
21
+
22
+ if TYPE_CHECKING:
23
+ from .dataset import PackableSample
24
+
25
+
26
+ ##
27
+ # Typing helpers
28
+
29
+ DatasetType: TypeAlias = Type['PackableSample']
30
+ LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
31
+
32
+ S = TypeVar( 'S', bound = 'PackableSample' )
33
+ V = TypeVar( 'V', bound = 'PackableSample' )
34
+ type LensGetter[S, V] = Callable[[S], V]
35
+ type LensPutter[S, V] = Callable[[V, S], S]
36
+
37
+
38
+ ##
39
+ # Shortcut decorators
40
+
41
+ class Lens( Generic[S, V] ):
42
+ """TODO"""
43
+
44
+ # @property
45
+ # def source_type( self ) -> Type[S]:
46
+ # """The source type (S) for the lens; what is put to"""
47
+ # # TODO Figure out why linting fails here
48
+ # return self.__orig_class__.__args__[0]
49
+
50
+ # @property
51
+ # def view_type( self ) -> Type[V]:
52
+ # """The view type (V) for the lens; what is get'd from"""
53
+ # # TODO FIgure out why linting fails here
54
+ # return self.__orig_class__.__args__[1]
55
+
56
+ def __init__( self, get: LensGetter[S, V],
57
+ put: Optional[LensPutter[S, V]] = None
58
+ ) -> None:
59
+ """TODO"""
60
+ ##
61
+
62
+ # Check argument validity
63
+
64
+ sig = inspect.signature( get )
65
+ input_types = list( sig.parameters.values() )
66
+ assert len( input_types ) == 1, \
67
+ 'Wrong number of input args for lens: should only have one'
68
+
69
+ # Update function details for this object as returned by annotation
70
+ functools.update_wrapper( self, get )
71
+
72
+ self.source_type: Type[PackableSample] = input_types[0].annotation
73
+ self.view_type = sig.return_annotation
74
+
75
+ # Store the getter
76
+ self._getter = get
77
+
78
+ # Determine and store the putter
79
+ if put is None:
80
+ # Trivial putter does not update the source
81
+ def _trivial_put( v: V, s: S ) -> S:
82
+ return s
83
+ put = _trivial_put
84
+ self._putter = put
85
+
86
+ #
87
+
88
+ def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
89
+ """TODO"""
90
+ ##
91
+ self._putter = put
92
+ return put
93
+
94
+ # Methods to actually execute transformations
95
+
96
+ def put( self, v: V, s: S ) -> S:
97
+ """TODO"""
98
+ return self._putter( v, s )
99
+
100
+ def get( self, s: S ) -> V:
101
+ """TODO"""
102
+ return self( s )
103
+
104
+ # Convenience to enable calling the lens as its getter
105
+
106
+ def __call__( self, s: S ) -> V:
107
+ return self._getter( s )
108
+
109
+ # TODO Figure out how to properly parameterize this
110
+ # def _lens_factory[S, V]( register: bool = True ):
111
+ # """Register the annotated function `f` as the getter of a sample lens"""
112
+
113
+ # # The actual lens decorator taking a lens getter function to a lens object
114
+ # def _decorator( f: LensGetter[S, V] ) -> Lens[S, V]:
115
+ # ret = Lens[S, V]( f )
116
+ # if register:
117
+ # _network.register( ret )
118
+ # return ret
119
+
120
+ # # Return the lens decorator
121
+ # return _decorator
122
+
123
+ # # For convenience
124
+ # lens = _lens_factory
125
+
126
+ def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
127
+ ret = Lens[S, V]( f )
128
+ _network.register( ret )
129
+ return ret
130
+
131
+
132
+ ##
133
+ # Global registry of used lenses
134
+
135
+ # _registered_lenses: Dict[LensSignature, Lens] = dict()
136
+ # """TODO"""
137
+
138
+ class LensNetwork:
139
+ """TODO"""
140
+
141
+ _instance = None
142
+ """The singleton instance"""
143
+
144
+ def __new__(cls, *args, **kwargs):
145
+ if cls._instance is None:
146
+ # If no instance exists, create a new one
147
+ cls._instance = super().__new__(cls)
148
+ return cls._instance # Return the existing (or newly created) instance
149
+
150
+ def __init__(self):
151
+ if not hasattr(self, '_initialized'): # Check if already initialized
152
+ self._registry: Dict[LensSignature, Lens] = dict()
153
+ self._initialized = True
154
+
155
+ def register( self, _lens: Lens ):
156
+ """Set `lens` as the canonical view between its source and view types"""
157
+
158
+ # sig = inspect.signature( _lens.get )
159
+ # input_types = list( sig.parameters.values() )
160
+ # assert len( input_types ) == 1, \
161
+ # 'Wrong number of input args for lens: should only have one'
162
+
163
+ # input_type = input_types[0].annotation
164
+ # print( input_type )
165
+ # output_type = sig.return_annotation
166
+
167
+ # self._registry[input_type, output_type] = _lens
168
+ print( _lens.source_type )
169
+ self._registry[_lens.source_type, _lens.view_type] = _lens
170
+
171
+ def transform( self, source: DatasetType, view: DatasetType ) -> Lens:
172
+ """TODO"""
173
+
174
+ # TODO Handle compositional closure
175
+ ret = self._registry.get( (source, view), None )
176
+ if ret is None:
177
+ raise ValueError( f'No registered lens from source {source} to view {view}' )
178
+
179
+ return ret
180
+
181
+
182
+ # Create global singleton registry instance
183
+ _network = LensNetwork()
184
+
185
+ # def lens( f: LensPutter ) -> Lens:
186
+ # """Register the annotated function `f` as a sample lens"""
187
+ # ##
188
+
189
+ # sig = inspect.signature( f )
190
+
191
+ # input_types = list( sig.parameters.values() )
192
+ # output_type = sig.return_annotation
193
+
194
+ # _registered_lenses[]
195
+
196
+ # f.lens = Lens(
197
+
198
+ # )
199
+
200
+ # return f
@@ -179,7 +179,7 @@ def test_wds(
179
179
  ).as_posix()
180
180
  file_wds_pattern = file_pattern.format( shard_id = '%06d' )
181
181
 
182
- with wds.ShardWriter(
182
+ with wds.writer.ShardWriter(
183
183
  pattern = file_wds_pattern,
184
184
  maxcount = shard_maxcount,
185
185
  ) as sink:
@@ -339,7 +339,7 @@ def test_wds(
339
339
  )
340
340
  for case in test_cases ]
341
341
  )
342
- def test_create_sample(
342
+ def test_parquet_export(
343
343
  SampleType: Type[atdata.PackableSample],
344
344
  sample_data: atds.MsgpackRawSample,
345
345
  sample_wds_stem: str,
@@ -360,7 +360,7 @@ def test_create_sample(
360
360
  ## Start out by writing tar dataset
361
361
 
362
362
  wds_filename = (tmp_path / f'{sample_wds_stem}.tar').as_posix()
363
- with wds.TarWriter( wds_filename ) as sink:
363
+ with wds.writer.TarWriter( wds_filename ) as sink:
364
364
  for _ in range( n_copies_dataset ):
365
365
  new_sample = SampleType.from_data( sample_data )
366
366
  sink.write( new_sample.as_wds )
@@ -371,73 +371,12 @@ def test_create_sample(
371
371
  parquet_filename = tmp_path / f'{sample_wds_stem}.parquet'
372
372
  dataset.to_parquet( parquet_filename )
373
373
 
374
+ parquet_filename = tmp_path / f'{sample_wds_stem}-segments.parquet'
375
+ dataset.to_parquet( parquet_filename, maxcount = n_per_file )
376
+
374
377
  ## Double-check our `parquet` export
375
378
 
376
379
  # TODO
377
380
 
378
- def test_lens():
379
- """Test a lens between sample types"""
380
-
381
- # Set up the lens scenario
382
-
383
- @atdata.packable
384
- class Source:
385
- name: str
386
- age: int
387
- height: float
388
-
389
- @atdata.packable
390
- class View:
391
- name: str
392
- height: float
393
-
394
- @atdata.lens
395
- def polite( s: Source ) -> View:
396
- return View(
397
- name = s.name,
398
- height = s.height,
399
- )
400
-
401
- @polite.putter
402
- def polite_update( v: View, s: Source ) -> Source:
403
- return Source(
404
- name = v.name,
405
- height = v.height,
406
- #
407
- age = s.age,
408
- )
409
-
410
- # Test with an example sample
411
-
412
- test_source = Source(
413
- name = 'Hello World',
414
- age = 42,
415
- height = 182.9,
416
- )
417
- correct_view = View(
418
- name = test_source.name,
419
- height = test_source.height,
420
- )
421
-
422
- test_view = polite( test_source )
423
- assert test_view == correct_view, \
424
- f'Incorrect lens behavior: {test_view}, and not {correct_view}'
425
-
426
- # This lens should be well-behaved
427
-
428
- update_view = View(
429
- name = 'Now Taller',
430
- height = 192.9,
431
- )
432
-
433
- x = polite( polite.put( update_view, test_source ) )
434
- assert x == update_view, \
435
- f'Violation of GetPut: {x} =/= {update_view}'
436
-
437
- y = polite.put( polite( test_source ), test_source )
438
- assert y == test_source, \
439
- f'Violation of PutGet: {y} =/= {test_source}'
440
-
441
- # TODO Test PutPut
442
381
 
443
382
  ##
@@ -0,0 +1,166 @@
1
+ """Test lens functionality."""
2
+
3
+ ##
4
+ # Imports
5
+
6
+ import pytest
7
+
8
+ from dataclasses import dataclass
9
+ import webdataset as wds
10
+ import atdata
11
+
12
+ import numpy as np
13
+ from numpy.typing import NDArray
14
+
15
+
16
+ ##
17
+ # Tests
18
+
19
+ def test_lens():
20
+ """Test a lens between sample types"""
21
+
22
+ # Set up the lens scenario
23
+
24
+ @atdata.packable
25
+ class Source:
26
+ name: str
27
+ age: int
28
+ height: float
29
+
30
+ @atdata.packable
31
+ class View:
32
+ name: str
33
+ height: float
34
+
35
+ @atdata.lens
36
+ def polite( s: Source ) -> View:
37
+ return View(
38
+ name = s.name,
39
+ height = s.height,
40
+ )
41
+
42
+ @polite.putter
43
+ def polite_update( v: View, s: Source ) -> Source:
44
+ return Source(
45
+ name = v.name,
46
+ height = v.height,
47
+ #
48
+ age = s.age,
49
+ )
50
+
51
+ # Test with an example sample
52
+
53
+ test_source = Source(
54
+ name = 'Hello World',
55
+ age = 42,
56
+ height = 182.9,
57
+ )
58
+ correct_view = View(
59
+ name = test_source.name,
60
+ height = test_source.height,
61
+ )
62
+
63
+ test_view = polite( test_source )
64
+ assert test_view == correct_view, \
65
+ f'Incorrect lens behavior: {test_view}, and not {correct_view}'
66
+
67
+ # This lens should be well-behaved
68
+
69
+ update_view = View(
70
+ name = 'Now Taller',
71
+ height = 192.9,
72
+ )
73
+
74
+ x = polite( polite.put( update_view, test_source ) )
75
+ assert x == update_view, \
76
+ f'Violation of GetPut: {x} =/= {update_view}'
77
+
78
+ y = polite.put( polite( test_source ), test_source )
79
+ assert y == test_source, \
80
+ f'Violation of PutGet: {y} =/= {test_source}'
81
+
82
+ # TODO Test PutPut
83
+
84
+ def test_conversion( tmp_path ):
85
+ """Test automatic interconversion between sample types"""
86
+
87
+ @dataclass
88
+ class Source( atdata.PackableSample ):
89
+ name: str
90
+ height: float
91
+ favorite_pizza: str
92
+ favorite_image: NDArray
93
+
94
+ @dataclass
95
+ class View( atdata.PackableSample ):
96
+ name: str
97
+ favorite_pizza: str
98
+ favorite_image: NDArray
99
+
100
+ @atdata.lens
101
+ def polite( s: Source ) -> View:
102
+ return View(
103
+ name = s.name,
104
+ favorite_pizza = s.favorite_pizza,
105
+ favorite_image = s.favorite_image,
106
+ )
107
+
108
+ lens_network = atdata.LensNetwork()
109
+ print( lens_network._registry )
110
+
111
+ # Map a test sample through the view
112
+ test_source = Source(
113
+ name = 'Larry',
114
+ height = 42.,
115
+ favorite_pizza = 'pineapple',
116
+ favorite_image = np.random.randn( 224, 224 )
117
+ )
118
+ test_view = polite( test_source )
119
+
120
+ # Create a test dataset
121
+
122
+ k_test = 100
123
+ test_filename = (
124
+ tmp_path
125
+ / 'test-source.tar'
126
+ ).as_posix()
127
+
128
+ with wds.writer.TarWriter( test_filename ) as dest:
129
+ for i in range( k_test ):
130
+ # Create a new copied sample
131
+ cur_sample = Source(
132
+ name = test_source.name,
133
+ height = test_source.height,
134
+ favorite_pizza = test_source.favorite_pizza,
135
+ favorite_image = test_source.favorite_image,
136
+ )
137
+ dest.write( cur_sample.as_wds )
138
+
139
+ # Try reading the test dataset
140
+
141
+ ds = (
142
+ atdata.Dataset[Source]( test_filename )
143
+ .as_type( View )
144
+ )
145
+
146
+ assert ds.sample_type == View, \
147
+ 'Auto-mapped'
148
+
149
+ sample: View | None = None
150
+ for sample in ds.ordered( batch_size = None ):
151
+ # Load only the first sample
152
+ break
153
+
154
+ assert sample is not None, \
155
+ 'Did not load any samples from `Source` dataset'
156
+
157
+ assert sample.name == test_view.name, \
158
+ f'Divergence on auto-mapped dataset: `name` should be {test_view.name}, but is {sample.name}'
159
+ # assert sample.height == test_view.height, \
160
+ # f'Divergence on auto-mapped dataset: `height` should be {test_view.height}, but is {sample.height}'
161
+ assert sample.favorite_pizza == test_view.favorite_pizza, \
162
+ f'Divergence on auto-mapped dataset: `favorite_pizza` should be {test_view.favorite_pizza}, but is {sample.favorite_pizza}'
163
+ assert np.all( sample.favorite_image == test_view.favorite_image ), \
164
+ f'Divergence on auto-mapped dataset: `favorite_image`'
165
+
166
+ ##
@@ -1,122 +0,0 @@
1
- """Lenses between typed datasets"""
2
-
3
- ##
4
- # Imports
5
-
6
- from .dataset import PackableSample
7
-
8
- import functools
9
- import inspect
10
-
11
- from typing import (
12
- TypeAlias,
13
- Type,
14
- TypeVar,
15
- Tuple,
16
- Dict,
17
- Callable,
18
- Optional,
19
- Generic,
20
- )
21
-
22
-
23
- ##
24
- # Typing helpers
25
-
26
- DatasetType: TypeAlias = Type[PackableSample]
27
- LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
28
-
29
- S = TypeVar( 'S', bound = PackableSample )
30
- V = TypeVar( 'V', bound = PackableSample )
31
- type LensGetter[S, V] = Callable[[S], V]
32
- type LensPutter[S, V] = Callable[[V, S], S]
33
-
34
-
35
- ##
36
- # Shortcut decorators
37
-
38
- class Lens( Generic[S, V] ):
39
- """TODO"""
40
-
41
- def __init__( self, get: LensGetter[S, V],
42
- put: Optional[LensPutter[S, V]] = None
43
- ) -> None:
44
- """TODO"""
45
- ##
46
-
47
- # Update
48
- functools.update_wrapper( self, get )
49
-
50
- # Store the getter
51
- self._getter = get
52
-
53
- # Determine and store the putter
54
- if put is None:
55
- # Trivial putter does not update the source
56
- def _trivial_put( v: V, s: S ) -> S:
57
- return s
58
- put = _trivial_put
59
-
60
- self._putter = put
61
-
62
- # Register this lens for this type signature
63
-
64
- sig = inspect.signature( get )
65
- input_types = list( sig.parameters.values() )
66
- assert len( input_types ) == 1, \
67
- 'Wrong number of input args for lens: should only have one'
68
-
69
- input_type = input_types[0].annotation
70
- output_type = sig.return_annotation
71
-
72
- _registered_lenses[(input_type, output_type)] = self
73
- print( _registered_lenses )
74
-
75
- #
76
-
77
- def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
78
- """TODO"""
79
- ##
80
- self._putter = put
81
- return put
82
-
83
- def put( self, v: V, s: S ) -> S:
84
- """TODO"""
85
- return self._putter( v, s )
86
-
87
- def get( self, s: S ) -> V:
88
- """TODO"""
89
- return self( s )
90
-
91
- #
92
-
93
- def __call__( self, s: S ) -> V:
94
- return self._getter( s )
95
-
96
- def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
97
- """Register the annotated function `f` as the getter of a sample lens"""
98
- return Lens[S, V]( f )
99
-
100
-
101
- ##
102
- # Global registration of used lenses
103
-
104
- _registered_lenses: Dict[LensSignature, Lens] = dict()
105
- """TODO"""
106
-
107
- # def lens( f: LensPutter ) -> Lens:
108
- # """Register the annotated function `f` as a sample lens"""
109
- # ##
110
-
111
- # sig = inspect.signature( f )
112
-
113
- # input_types = list( sig.parameters.values() )
114
- # output_type = sig.return_annotation
115
-
116
- # _registered_lenses[]
117
-
118
- # f.lens = Lens(
119
-
120
- # )
121
-
122
- # return f
File without changes
File without changes
File without changes
File without changes