atdata 0.1.2b1__tar.gz → 0.1.3a2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atdata-0.1.2b1 → atdata-0.1.3a2}/PKG-INFO +1 -1
- {atdata-0.1.2b1 → atdata-0.1.3a2}/pyproject.toml +1 -1
- {atdata-0.1.2b1 → atdata-0.1.3a2}/src/atdata/__init__.py +1 -0
- {atdata-0.1.2b1 → atdata-0.1.3a2}/src/atdata/dataset.py +33 -15
- atdata-0.1.3a2/src/atdata/lens.py +200 -0
- {atdata-0.1.2b1 → atdata-0.1.3a2}/tests/test_dataset.py +6 -67
- atdata-0.1.3a2/tests/test_lens.py +166 -0
- atdata-0.1.2b1/src/atdata/lens.py +0 -122
- {atdata-0.1.2b1 → atdata-0.1.3a2}/.github/workflows/uv-publish-pypi.yml +0 -0
- {atdata-0.1.2b1 → atdata-0.1.3a2}/.github/workflows/uv-test.yml +0 -0
- {atdata-0.1.2b1 → atdata-0.1.3a2}/.gitignore +0 -0
- {atdata-0.1.2b1 → atdata-0.1.3a2}/.python-version +0 -0
- {atdata-0.1.2b1 → atdata-0.1.3a2}/LICENSE +0 -0
- {atdata-0.1.2b1 → atdata-0.1.3a2}/README.md +0 -0
- {atdata-0.1.2b1 → atdata-0.1.3a2}/src/atdata/_helpers.py +0 -0
|
@@ -48,6 +48,7 @@ from numpy.typing import (
|
|
|
48
48
|
import msgpack
|
|
49
49
|
import ormsgpack
|
|
50
50
|
from . import _helpers as eh
|
|
51
|
+
from .lens import Lens, LensNetwork
|
|
51
52
|
|
|
52
53
|
|
|
53
54
|
##
|
|
@@ -231,6 +232,8 @@ class SampleBatch( Generic[DT] ):
|
|
|
231
232
|
ST = TypeVar( 'ST', bound = PackableSample )
|
|
232
233
|
# BT = TypeVar( 'BT' )
|
|
233
234
|
|
|
235
|
+
RT = TypeVar( 'RT', bound = PackableSample )
|
|
236
|
+
|
|
234
237
|
# TODO For python 3.13
|
|
235
238
|
# BT = TypeVar( 'BT', default = None )
|
|
236
239
|
# IT = TypeVar( 'IT', default = Any )
|
|
@@ -268,6 +271,17 @@ class Dataset( Generic[ST] ):
|
|
|
268
271
|
super().__init__()
|
|
269
272
|
self.url = url
|
|
270
273
|
|
|
274
|
+
# Allow addition of automatic transformation of raw underlying data
|
|
275
|
+
self._output_lens: Lens | None = None
|
|
276
|
+
|
|
277
|
+
def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
|
|
278
|
+
"""TODO"""
|
|
279
|
+
ret = Dataset[other]( self.url )
|
|
280
|
+
# Get the singleton lens registry
|
|
281
|
+
lenses = LensNetwork()
|
|
282
|
+
ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
|
|
283
|
+
return ret
|
|
284
|
+
|
|
271
285
|
# @classmethod
|
|
272
286
|
# def register( cls, uri: str,
|
|
273
287
|
# sample_class: Type,
|
|
@@ -293,9 +307,9 @@ class Dataset( Generic[ST] ):
|
|
|
293
307
|
A full (non-lazy) list of the individual ``tar`` files within the
|
|
294
308
|
source WebDataset.
|
|
295
309
|
"""
|
|
296
|
-
pipe = wds.DataPipeline(
|
|
297
|
-
wds.SimpleShardList( self.url ),
|
|
298
|
-
wds.map( lambda x: x['url'] )
|
|
310
|
+
pipe = wds.pipeline.DataPipeline(
|
|
311
|
+
wds.shardlists.SimpleShardList( self.url ),
|
|
312
|
+
wds.filters.map( lambda x: x['url'] )
|
|
299
313
|
)
|
|
300
314
|
return list( pipe )
|
|
301
315
|
|
|
@@ -317,23 +331,23 @@ class Dataset( Generic[ST] ):
|
|
|
317
331
|
|
|
318
332
|
if batch_size is None:
|
|
319
333
|
# TODO Duplication here
|
|
320
|
-
return wds.DataPipeline(
|
|
321
|
-
wds.SimpleShardList( self.url ),
|
|
322
|
-
wds.split_by_worker,
|
|
334
|
+
return wds.pipeline.DataPipeline(
|
|
335
|
+
wds.shardlists.SimpleShardList( self.url ),
|
|
336
|
+
wds.shardlists.split_by_worker,
|
|
323
337
|
#
|
|
324
|
-
wds.tarfile_to_samples(),
|
|
338
|
+
wds.tariterators.tarfile_to_samples(),
|
|
325
339
|
# wds.map( self.preprocess ),
|
|
326
|
-
wds.map( self.wrap ),
|
|
340
|
+
wds.filters.map( self.wrap ),
|
|
327
341
|
)
|
|
328
342
|
|
|
329
|
-
return wds.DataPipeline(
|
|
330
|
-
wds.SimpleShardList( self.url ),
|
|
331
|
-
wds.split_by_worker,
|
|
343
|
+
return wds.pipeline.DataPipeline(
|
|
344
|
+
wds.shardlists.SimpleShardList( self.url ),
|
|
345
|
+
wds.shardlists.split_by_worker,
|
|
332
346
|
#
|
|
333
|
-
wds.tarfile_to_samples(),
|
|
347
|
+
wds.tariterators.tarfile_to_samples(),
|
|
334
348
|
# wds.map( self.preprocess ),
|
|
335
|
-
wds.batched( batch_size ),
|
|
336
|
-
wds.map( self.wrap_batch ),
|
|
349
|
+
wds.filters.batched( batch_size ),
|
|
350
|
+
wds.filters.map( self.wrap_batch ),
|
|
337
351
|
)
|
|
338
352
|
|
|
339
353
|
def shuffled( self,
|
|
@@ -461,7 +475,11 @@ class Dataset( Generic[ST] ):
|
|
|
461
475
|
assert 'msgpack' in sample
|
|
462
476
|
assert type( sample['msgpack'] ) == bytes
|
|
463
477
|
|
|
464
|
-
|
|
478
|
+
if self._output_lens is None:
|
|
479
|
+
return self.sample_type.from_bytes( sample['msgpack'] )
|
|
480
|
+
|
|
481
|
+
source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
|
|
482
|
+
return self._output_lens( source_sample )
|
|
465
483
|
|
|
466
484
|
# try:
|
|
467
485
|
# assert type( sample ) == dict
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
"""Lenses between typed datasets"""
|
|
2
|
+
|
|
3
|
+
##
|
|
4
|
+
# Imports
|
|
5
|
+
|
|
6
|
+
import functools
|
|
7
|
+
import inspect
|
|
8
|
+
|
|
9
|
+
from typing import (
|
|
10
|
+
TypeAlias,
|
|
11
|
+
Type,
|
|
12
|
+
TypeVar,
|
|
13
|
+
Tuple,
|
|
14
|
+
Dict,
|
|
15
|
+
Callable,
|
|
16
|
+
Optional,
|
|
17
|
+
Generic,
|
|
18
|
+
#
|
|
19
|
+
TYPE_CHECKING
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from .dataset import PackableSample
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
##
|
|
27
|
+
# Typing helpers
|
|
28
|
+
|
|
29
|
+
DatasetType: TypeAlias = Type['PackableSample']
|
|
30
|
+
LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
|
|
31
|
+
|
|
32
|
+
S = TypeVar( 'S', bound = 'PackableSample' )
|
|
33
|
+
V = TypeVar( 'V', bound = 'PackableSample' )
|
|
34
|
+
type LensGetter[S, V] = Callable[[S], V]
|
|
35
|
+
type LensPutter[S, V] = Callable[[V, S], S]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
##
|
|
39
|
+
# Shortcut decorators
|
|
40
|
+
|
|
41
|
+
class Lens( Generic[S, V] ):
|
|
42
|
+
"""TODO"""
|
|
43
|
+
|
|
44
|
+
# @property
|
|
45
|
+
# def source_type( self ) -> Type[S]:
|
|
46
|
+
# """The source type (S) for the lens; what is put to"""
|
|
47
|
+
# # TODO Figure out why linting fails here
|
|
48
|
+
# return self.__orig_class__.__args__[0]
|
|
49
|
+
|
|
50
|
+
# @property
|
|
51
|
+
# def view_type( self ) -> Type[V]:
|
|
52
|
+
# """The view type (V) for the lens; what is get'd from"""
|
|
53
|
+
# # TODO FIgure out why linting fails here
|
|
54
|
+
# return self.__orig_class__.__args__[1]
|
|
55
|
+
|
|
56
|
+
def __init__( self, get: LensGetter[S, V],
|
|
57
|
+
put: Optional[LensPutter[S, V]] = None
|
|
58
|
+
) -> None:
|
|
59
|
+
"""TODO"""
|
|
60
|
+
##
|
|
61
|
+
|
|
62
|
+
# Check argument validity
|
|
63
|
+
|
|
64
|
+
sig = inspect.signature( get )
|
|
65
|
+
input_types = list( sig.parameters.values() )
|
|
66
|
+
assert len( input_types ) == 1, \
|
|
67
|
+
'Wrong number of input args for lens: should only have one'
|
|
68
|
+
|
|
69
|
+
# Update function details for this object as returned by annotation
|
|
70
|
+
functools.update_wrapper( self, get )
|
|
71
|
+
|
|
72
|
+
self.source_type: Type[PackableSample] = input_types[0].annotation
|
|
73
|
+
self.view_type = sig.return_annotation
|
|
74
|
+
|
|
75
|
+
# Store the getter
|
|
76
|
+
self._getter = get
|
|
77
|
+
|
|
78
|
+
# Determine and store the putter
|
|
79
|
+
if put is None:
|
|
80
|
+
# Trivial putter does not update the source
|
|
81
|
+
def _trivial_put( v: V, s: S ) -> S:
|
|
82
|
+
return s
|
|
83
|
+
put = _trivial_put
|
|
84
|
+
self._putter = put
|
|
85
|
+
|
|
86
|
+
#
|
|
87
|
+
|
|
88
|
+
def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
|
|
89
|
+
"""TODO"""
|
|
90
|
+
##
|
|
91
|
+
self._putter = put
|
|
92
|
+
return put
|
|
93
|
+
|
|
94
|
+
# Methods to actually execute transformations
|
|
95
|
+
|
|
96
|
+
def put( self, v: V, s: S ) -> S:
|
|
97
|
+
"""TODO"""
|
|
98
|
+
return self._putter( v, s )
|
|
99
|
+
|
|
100
|
+
def get( self, s: S ) -> V:
|
|
101
|
+
"""TODO"""
|
|
102
|
+
return self( s )
|
|
103
|
+
|
|
104
|
+
# Convenience to enable calling the lens as its getter
|
|
105
|
+
|
|
106
|
+
def __call__( self, s: S ) -> V:
|
|
107
|
+
return self._getter( s )
|
|
108
|
+
|
|
109
|
+
# TODO Figure out how to properly parameterize this
|
|
110
|
+
# def _lens_factory[S, V]( register: bool = True ):
|
|
111
|
+
# """Register the annotated function `f` as the getter of a sample lens"""
|
|
112
|
+
|
|
113
|
+
# # The actual lens decorator taking a lens getter function to a lens object
|
|
114
|
+
# def _decorator( f: LensGetter[S, V] ) -> Lens[S, V]:
|
|
115
|
+
# ret = Lens[S, V]( f )
|
|
116
|
+
# if register:
|
|
117
|
+
# _network.register( ret )
|
|
118
|
+
# return ret
|
|
119
|
+
|
|
120
|
+
# # Return the lens decorator
|
|
121
|
+
# return _decorator
|
|
122
|
+
|
|
123
|
+
# # For convenience
|
|
124
|
+
# lens = _lens_factory
|
|
125
|
+
|
|
126
|
+
def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
|
|
127
|
+
ret = Lens[S, V]( f )
|
|
128
|
+
_network.register( ret )
|
|
129
|
+
return ret
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
##
|
|
133
|
+
# Global registry of used lenses
|
|
134
|
+
|
|
135
|
+
# _registered_lenses: Dict[LensSignature, Lens] = dict()
|
|
136
|
+
# """TODO"""
|
|
137
|
+
|
|
138
|
+
class LensNetwork:
|
|
139
|
+
"""TODO"""
|
|
140
|
+
|
|
141
|
+
_instance = None
|
|
142
|
+
"""The singleton instance"""
|
|
143
|
+
|
|
144
|
+
def __new__(cls, *args, **kwargs):
|
|
145
|
+
if cls._instance is None:
|
|
146
|
+
# If no instance exists, create a new one
|
|
147
|
+
cls._instance = super().__new__(cls)
|
|
148
|
+
return cls._instance # Return the existing (or newly created) instance
|
|
149
|
+
|
|
150
|
+
def __init__(self):
|
|
151
|
+
if not hasattr(self, '_initialized'): # Check if already initialized
|
|
152
|
+
self._registry: Dict[LensSignature, Lens] = dict()
|
|
153
|
+
self._initialized = True
|
|
154
|
+
|
|
155
|
+
def register( self, _lens: Lens ):
|
|
156
|
+
"""Set `lens` as the canonical view between its source and view types"""
|
|
157
|
+
|
|
158
|
+
# sig = inspect.signature( _lens.get )
|
|
159
|
+
# input_types = list( sig.parameters.values() )
|
|
160
|
+
# assert len( input_types ) == 1, \
|
|
161
|
+
# 'Wrong number of input args for lens: should only have one'
|
|
162
|
+
|
|
163
|
+
# input_type = input_types[0].annotation
|
|
164
|
+
# print( input_type )
|
|
165
|
+
# output_type = sig.return_annotation
|
|
166
|
+
|
|
167
|
+
# self._registry[input_type, output_type] = _lens
|
|
168
|
+
print( _lens.source_type )
|
|
169
|
+
self._registry[_lens.source_type, _lens.view_type] = _lens
|
|
170
|
+
|
|
171
|
+
def transform( self, source: DatasetType, view: DatasetType ) -> Lens:
|
|
172
|
+
"""TODO"""
|
|
173
|
+
|
|
174
|
+
# TODO Handle compositional closure
|
|
175
|
+
ret = self._registry.get( (source, view), None )
|
|
176
|
+
if ret is None:
|
|
177
|
+
raise ValueError( f'No registered lens from source {source} to view {view}' )
|
|
178
|
+
|
|
179
|
+
return ret
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
# Create global singleton registry instance
|
|
183
|
+
_network = LensNetwork()
|
|
184
|
+
|
|
185
|
+
# def lens( f: LensPutter ) -> Lens:
|
|
186
|
+
# """Register the annotated function `f` as a sample lens"""
|
|
187
|
+
# ##
|
|
188
|
+
|
|
189
|
+
# sig = inspect.signature( f )
|
|
190
|
+
|
|
191
|
+
# input_types = list( sig.parameters.values() )
|
|
192
|
+
# output_type = sig.return_annotation
|
|
193
|
+
|
|
194
|
+
# _registered_lenses[]
|
|
195
|
+
|
|
196
|
+
# f.lens = Lens(
|
|
197
|
+
|
|
198
|
+
# )
|
|
199
|
+
|
|
200
|
+
# return f
|
|
@@ -179,7 +179,7 @@ def test_wds(
|
|
|
179
179
|
).as_posix()
|
|
180
180
|
file_wds_pattern = file_pattern.format( shard_id = '%06d' )
|
|
181
181
|
|
|
182
|
-
with wds.ShardWriter(
|
|
182
|
+
with wds.writer.ShardWriter(
|
|
183
183
|
pattern = file_wds_pattern,
|
|
184
184
|
maxcount = shard_maxcount,
|
|
185
185
|
) as sink:
|
|
@@ -339,7 +339,7 @@ def test_wds(
|
|
|
339
339
|
)
|
|
340
340
|
for case in test_cases ]
|
|
341
341
|
)
|
|
342
|
-
def
|
|
342
|
+
def test_parquet_export(
|
|
343
343
|
SampleType: Type[atdata.PackableSample],
|
|
344
344
|
sample_data: atds.MsgpackRawSample,
|
|
345
345
|
sample_wds_stem: str,
|
|
@@ -360,7 +360,7 @@ def test_create_sample(
|
|
|
360
360
|
## Start out by writing tar dataset
|
|
361
361
|
|
|
362
362
|
wds_filename = (tmp_path / f'{sample_wds_stem}.tar').as_posix()
|
|
363
|
-
with wds.TarWriter( wds_filename ) as sink:
|
|
363
|
+
with wds.writer.TarWriter( wds_filename ) as sink:
|
|
364
364
|
for _ in range( n_copies_dataset ):
|
|
365
365
|
new_sample = SampleType.from_data( sample_data )
|
|
366
366
|
sink.write( new_sample.as_wds )
|
|
@@ -371,73 +371,12 @@ def test_create_sample(
|
|
|
371
371
|
parquet_filename = tmp_path / f'{sample_wds_stem}.parquet'
|
|
372
372
|
dataset.to_parquet( parquet_filename )
|
|
373
373
|
|
|
374
|
+
parquet_filename = tmp_path / f'{sample_wds_stem}-segments.parquet'
|
|
375
|
+
dataset.to_parquet( parquet_filename, maxcount = n_per_file )
|
|
376
|
+
|
|
374
377
|
## Double-check our `parquet` export
|
|
375
378
|
|
|
376
379
|
# TODO
|
|
377
380
|
|
|
378
|
-
def test_lens():
|
|
379
|
-
"""Test a lens between sample types"""
|
|
380
|
-
|
|
381
|
-
# Set up the lens scenario
|
|
382
|
-
|
|
383
|
-
@atdata.packable
|
|
384
|
-
class Source:
|
|
385
|
-
name: str
|
|
386
|
-
age: int
|
|
387
|
-
height: float
|
|
388
|
-
|
|
389
|
-
@atdata.packable
|
|
390
|
-
class View:
|
|
391
|
-
name: str
|
|
392
|
-
height: float
|
|
393
|
-
|
|
394
|
-
@atdata.lens
|
|
395
|
-
def polite( s: Source ) -> View:
|
|
396
|
-
return View(
|
|
397
|
-
name = s.name,
|
|
398
|
-
height = s.height,
|
|
399
|
-
)
|
|
400
|
-
|
|
401
|
-
@polite.putter
|
|
402
|
-
def polite_update( v: View, s: Source ) -> Source:
|
|
403
|
-
return Source(
|
|
404
|
-
name = v.name,
|
|
405
|
-
height = v.height,
|
|
406
|
-
#
|
|
407
|
-
age = s.age,
|
|
408
|
-
)
|
|
409
|
-
|
|
410
|
-
# Test with an example sample
|
|
411
|
-
|
|
412
|
-
test_source = Source(
|
|
413
|
-
name = 'Hello World',
|
|
414
|
-
age = 42,
|
|
415
|
-
height = 182.9,
|
|
416
|
-
)
|
|
417
|
-
correct_view = View(
|
|
418
|
-
name = test_source.name,
|
|
419
|
-
height = test_source.height,
|
|
420
|
-
)
|
|
421
|
-
|
|
422
|
-
test_view = polite( test_source )
|
|
423
|
-
assert test_view == correct_view, \
|
|
424
|
-
f'Incorrect lens behavior: {test_view}, and not {correct_view}'
|
|
425
|
-
|
|
426
|
-
# This lens should be well-behaved
|
|
427
|
-
|
|
428
|
-
update_view = View(
|
|
429
|
-
name = 'Now Taller',
|
|
430
|
-
height = 192.9,
|
|
431
|
-
)
|
|
432
|
-
|
|
433
|
-
x = polite( polite.put( update_view, test_source ) )
|
|
434
|
-
assert x == update_view, \
|
|
435
|
-
f'Violation of GetPut: {x} =/= {update_view}'
|
|
436
|
-
|
|
437
|
-
y = polite.put( polite( test_source ), test_source )
|
|
438
|
-
assert y == test_source, \
|
|
439
|
-
f'Violation of PutGet: {y} =/= {test_source}'
|
|
440
|
-
|
|
441
|
-
# TODO Test PutPut
|
|
442
381
|
|
|
443
382
|
##
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""Test lens functionality."""
|
|
2
|
+
|
|
3
|
+
##
|
|
4
|
+
# Imports
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
import webdataset as wds
|
|
10
|
+
import atdata
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
from numpy.typing import NDArray
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
##
|
|
17
|
+
# Tests
|
|
18
|
+
|
|
19
|
+
def test_lens():
|
|
20
|
+
"""Test a lens between sample types"""
|
|
21
|
+
|
|
22
|
+
# Set up the lens scenario
|
|
23
|
+
|
|
24
|
+
@atdata.packable
|
|
25
|
+
class Source:
|
|
26
|
+
name: str
|
|
27
|
+
age: int
|
|
28
|
+
height: float
|
|
29
|
+
|
|
30
|
+
@atdata.packable
|
|
31
|
+
class View:
|
|
32
|
+
name: str
|
|
33
|
+
height: float
|
|
34
|
+
|
|
35
|
+
@atdata.lens
|
|
36
|
+
def polite( s: Source ) -> View:
|
|
37
|
+
return View(
|
|
38
|
+
name = s.name,
|
|
39
|
+
height = s.height,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
@polite.putter
|
|
43
|
+
def polite_update( v: View, s: Source ) -> Source:
|
|
44
|
+
return Source(
|
|
45
|
+
name = v.name,
|
|
46
|
+
height = v.height,
|
|
47
|
+
#
|
|
48
|
+
age = s.age,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
# Test with an example sample
|
|
52
|
+
|
|
53
|
+
test_source = Source(
|
|
54
|
+
name = 'Hello World',
|
|
55
|
+
age = 42,
|
|
56
|
+
height = 182.9,
|
|
57
|
+
)
|
|
58
|
+
correct_view = View(
|
|
59
|
+
name = test_source.name,
|
|
60
|
+
height = test_source.height,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
test_view = polite( test_source )
|
|
64
|
+
assert test_view == correct_view, \
|
|
65
|
+
f'Incorrect lens behavior: {test_view}, and not {correct_view}'
|
|
66
|
+
|
|
67
|
+
# This lens should be well-behaved
|
|
68
|
+
|
|
69
|
+
update_view = View(
|
|
70
|
+
name = 'Now Taller',
|
|
71
|
+
height = 192.9,
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
x = polite( polite.put( update_view, test_source ) )
|
|
75
|
+
assert x == update_view, \
|
|
76
|
+
f'Violation of GetPut: {x} =/= {update_view}'
|
|
77
|
+
|
|
78
|
+
y = polite.put( polite( test_source ), test_source )
|
|
79
|
+
assert y == test_source, \
|
|
80
|
+
f'Violation of PutGet: {y} =/= {test_source}'
|
|
81
|
+
|
|
82
|
+
# TODO Test PutPut
|
|
83
|
+
|
|
84
|
+
def test_conversion( tmp_path ):
|
|
85
|
+
"""Test automatic interconversion between sample types"""
|
|
86
|
+
|
|
87
|
+
@dataclass
|
|
88
|
+
class Source( atdata.PackableSample ):
|
|
89
|
+
name: str
|
|
90
|
+
height: float
|
|
91
|
+
favorite_pizza: str
|
|
92
|
+
favorite_image: NDArray
|
|
93
|
+
|
|
94
|
+
@dataclass
|
|
95
|
+
class View( atdata.PackableSample ):
|
|
96
|
+
name: str
|
|
97
|
+
favorite_pizza: str
|
|
98
|
+
favorite_image: NDArray
|
|
99
|
+
|
|
100
|
+
@atdata.lens
|
|
101
|
+
def polite( s: Source ) -> View:
|
|
102
|
+
return View(
|
|
103
|
+
name = s.name,
|
|
104
|
+
favorite_pizza = s.favorite_pizza,
|
|
105
|
+
favorite_image = s.favorite_image,
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
lens_network = atdata.LensNetwork()
|
|
109
|
+
print( lens_network._registry )
|
|
110
|
+
|
|
111
|
+
# Map a test sample through the view
|
|
112
|
+
test_source = Source(
|
|
113
|
+
name = 'Larry',
|
|
114
|
+
height = 42.,
|
|
115
|
+
favorite_pizza = 'pineapple',
|
|
116
|
+
favorite_image = np.random.randn( 224, 224 )
|
|
117
|
+
)
|
|
118
|
+
test_view = polite( test_source )
|
|
119
|
+
|
|
120
|
+
# Create a test dataset
|
|
121
|
+
|
|
122
|
+
k_test = 100
|
|
123
|
+
test_filename = (
|
|
124
|
+
tmp_path
|
|
125
|
+
/ 'test-source.tar'
|
|
126
|
+
).as_posix()
|
|
127
|
+
|
|
128
|
+
with wds.writer.TarWriter( test_filename ) as dest:
|
|
129
|
+
for i in range( k_test ):
|
|
130
|
+
# Create a new copied sample
|
|
131
|
+
cur_sample = Source(
|
|
132
|
+
name = test_source.name,
|
|
133
|
+
height = test_source.height,
|
|
134
|
+
favorite_pizza = test_source.favorite_pizza,
|
|
135
|
+
favorite_image = test_source.favorite_image,
|
|
136
|
+
)
|
|
137
|
+
dest.write( cur_sample.as_wds )
|
|
138
|
+
|
|
139
|
+
# Try reading the test dataset
|
|
140
|
+
|
|
141
|
+
ds = (
|
|
142
|
+
atdata.Dataset[Source]( test_filename )
|
|
143
|
+
.as_type( View )
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
assert ds.sample_type == View, \
|
|
147
|
+
'Auto-mapped'
|
|
148
|
+
|
|
149
|
+
sample: View | None = None
|
|
150
|
+
for sample in ds.ordered( batch_size = None ):
|
|
151
|
+
# Load only the first sample
|
|
152
|
+
break
|
|
153
|
+
|
|
154
|
+
assert sample is not None, \
|
|
155
|
+
'Did not load any samples from `Source` dataset'
|
|
156
|
+
|
|
157
|
+
assert sample.name == test_view.name, \
|
|
158
|
+
f'Divergence on auto-mapped dataset: `name` should be {test_view.name}, but is {sample.name}'
|
|
159
|
+
# assert sample.height == test_view.height, \
|
|
160
|
+
# f'Divergence on auto-mapped dataset: `height` should be {test_view.height}, but is {sample.height}'
|
|
161
|
+
assert sample.favorite_pizza == test_view.favorite_pizza, \
|
|
162
|
+
f'Divergence on auto-mapped dataset: `favorite_pizza` should be {test_view.favorite_pizza}, but is {sample.favorite_pizza}'
|
|
163
|
+
assert np.all( sample.favorite_image == test_view.favorite_image ), \
|
|
164
|
+
f'Divergence on auto-mapped dataset: `favorite_image`'
|
|
165
|
+
|
|
166
|
+
##
|
|
@@ -1,122 +0,0 @@
|
|
|
1
|
-
"""Lenses between typed datasets"""
|
|
2
|
-
|
|
3
|
-
##
|
|
4
|
-
# Imports
|
|
5
|
-
|
|
6
|
-
from .dataset import PackableSample
|
|
7
|
-
|
|
8
|
-
import functools
|
|
9
|
-
import inspect
|
|
10
|
-
|
|
11
|
-
from typing import (
|
|
12
|
-
TypeAlias,
|
|
13
|
-
Type,
|
|
14
|
-
TypeVar,
|
|
15
|
-
Tuple,
|
|
16
|
-
Dict,
|
|
17
|
-
Callable,
|
|
18
|
-
Optional,
|
|
19
|
-
Generic,
|
|
20
|
-
)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
##
|
|
24
|
-
# Typing helpers
|
|
25
|
-
|
|
26
|
-
DatasetType: TypeAlias = Type[PackableSample]
|
|
27
|
-
LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
|
|
28
|
-
|
|
29
|
-
S = TypeVar( 'S', bound = PackableSample )
|
|
30
|
-
V = TypeVar( 'V', bound = PackableSample )
|
|
31
|
-
type LensGetter[S, V] = Callable[[S], V]
|
|
32
|
-
type LensPutter[S, V] = Callable[[V, S], S]
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
##
|
|
36
|
-
# Shortcut decorators
|
|
37
|
-
|
|
38
|
-
class Lens( Generic[S, V] ):
|
|
39
|
-
"""TODO"""
|
|
40
|
-
|
|
41
|
-
def __init__( self, get: LensGetter[S, V],
|
|
42
|
-
put: Optional[LensPutter[S, V]] = None
|
|
43
|
-
) -> None:
|
|
44
|
-
"""TODO"""
|
|
45
|
-
##
|
|
46
|
-
|
|
47
|
-
# Update
|
|
48
|
-
functools.update_wrapper( self, get )
|
|
49
|
-
|
|
50
|
-
# Store the getter
|
|
51
|
-
self._getter = get
|
|
52
|
-
|
|
53
|
-
# Determine and store the putter
|
|
54
|
-
if put is None:
|
|
55
|
-
# Trivial putter does not update the source
|
|
56
|
-
def _trivial_put( v: V, s: S ) -> S:
|
|
57
|
-
return s
|
|
58
|
-
put = _trivial_put
|
|
59
|
-
|
|
60
|
-
self._putter = put
|
|
61
|
-
|
|
62
|
-
# Register this lens for this type signature
|
|
63
|
-
|
|
64
|
-
sig = inspect.signature( get )
|
|
65
|
-
input_types = list( sig.parameters.values() )
|
|
66
|
-
assert len( input_types ) == 1, \
|
|
67
|
-
'Wrong number of input args for lens: should only have one'
|
|
68
|
-
|
|
69
|
-
input_type = input_types[0].annotation
|
|
70
|
-
output_type = sig.return_annotation
|
|
71
|
-
|
|
72
|
-
_registered_lenses[(input_type, output_type)] = self
|
|
73
|
-
print( _registered_lenses )
|
|
74
|
-
|
|
75
|
-
#
|
|
76
|
-
|
|
77
|
-
def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
|
|
78
|
-
"""TODO"""
|
|
79
|
-
##
|
|
80
|
-
self._putter = put
|
|
81
|
-
return put
|
|
82
|
-
|
|
83
|
-
def put( self, v: V, s: S ) -> S:
|
|
84
|
-
"""TODO"""
|
|
85
|
-
return self._putter( v, s )
|
|
86
|
-
|
|
87
|
-
def get( self, s: S ) -> V:
|
|
88
|
-
"""TODO"""
|
|
89
|
-
return self( s )
|
|
90
|
-
|
|
91
|
-
#
|
|
92
|
-
|
|
93
|
-
def __call__( self, s: S ) -> V:
|
|
94
|
-
return self._getter( s )
|
|
95
|
-
|
|
96
|
-
def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
|
|
97
|
-
"""Register the annotated function `f` as the getter of a sample lens"""
|
|
98
|
-
return Lens[S, V]( f )
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
##
|
|
102
|
-
# Global registration of used lenses
|
|
103
|
-
|
|
104
|
-
_registered_lenses: Dict[LensSignature, Lens] = dict()
|
|
105
|
-
"""TODO"""
|
|
106
|
-
|
|
107
|
-
# def lens( f: LensPutter ) -> Lens:
|
|
108
|
-
# """Register the annotated function `f` as a sample lens"""
|
|
109
|
-
# ##
|
|
110
|
-
|
|
111
|
-
# sig = inspect.signature( f )
|
|
112
|
-
|
|
113
|
-
# input_types = list( sig.parameters.values() )
|
|
114
|
-
# output_type = sig.return_annotation
|
|
115
|
-
|
|
116
|
-
# _registered_lenses[]
|
|
117
|
-
|
|
118
|
-
# f.lens = Lens(
|
|
119
|
-
|
|
120
|
-
# )
|
|
121
|
-
|
|
122
|
-
# return f
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|