atdata 0.1.2b1__py3-none-any.whl → 0.1.3a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +1 -0
- atdata/dataset.py +33 -15
- atdata/lens.py +105 -27
- {atdata-0.1.2b1.dist-info → atdata-0.1.3a2.dist-info}/METADATA +1 -1
- atdata-0.1.3a2.dist-info/RECORD +9 -0
- atdata-0.1.2b1.dist-info/RECORD +0 -9
- {atdata-0.1.2b1.dist-info → atdata-0.1.3a2.dist-info}/WHEEL +0 -0
- {atdata-0.1.2b1.dist-info → atdata-0.1.3a2.dist-info}/entry_points.txt +0 -0
- {atdata-0.1.2b1.dist-info → atdata-0.1.3a2.dist-info}/licenses/LICENSE +0 -0
atdata/__init__.py
CHANGED
atdata/dataset.py
CHANGED
|
@@ -48,6 +48,7 @@ from numpy.typing import (
|
|
|
48
48
|
import msgpack
|
|
49
49
|
import ormsgpack
|
|
50
50
|
from . import _helpers as eh
|
|
51
|
+
from .lens import Lens, LensNetwork
|
|
51
52
|
|
|
52
53
|
|
|
53
54
|
##
|
|
@@ -231,6 +232,8 @@ class SampleBatch( Generic[DT] ):
|
|
|
231
232
|
ST = TypeVar( 'ST', bound = PackableSample )
|
|
232
233
|
# BT = TypeVar( 'BT' )
|
|
233
234
|
|
|
235
|
+
RT = TypeVar( 'RT', bound = PackableSample )
|
|
236
|
+
|
|
234
237
|
# TODO For python 3.13
|
|
235
238
|
# BT = TypeVar( 'BT', default = None )
|
|
236
239
|
# IT = TypeVar( 'IT', default = Any )
|
|
@@ -268,6 +271,17 @@ class Dataset( Generic[ST] ):
|
|
|
268
271
|
super().__init__()
|
|
269
272
|
self.url = url
|
|
270
273
|
|
|
274
|
+
# Allow addition of automatic transformation of raw underlying data
|
|
275
|
+
self._output_lens: Lens | None = None
|
|
276
|
+
|
|
277
|
+
def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
|
|
278
|
+
"""TODO"""
|
|
279
|
+
ret = Dataset[other]( self.url )
|
|
280
|
+
# Get the singleton lens registry
|
|
281
|
+
lenses = LensNetwork()
|
|
282
|
+
ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
|
|
283
|
+
return ret
|
|
284
|
+
|
|
271
285
|
# @classmethod
|
|
272
286
|
# def register( cls, uri: str,
|
|
273
287
|
# sample_class: Type,
|
|
@@ -293,9 +307,9 @@ class Dataset( Generic[ST] ):
|
|
|
293
307
|
A full (non-lazy) list of the individual ``tar`` files within the
|
|
294
308
|
source WebDataset.
|
|
295
309
|
"""
|
|
296
|
-
pipe = wds.DataPipeline(
|
|
297
|
-
wds.SimpleShardList( self.url ),
|
|
298
|
-
wds.map( lambda x: x['url'] )
|
|
310
|
+
pipe = wds.pipeline.DataPipeline(
|
|
311
|
+
wds.shardlists.SimpleShardList( self.url ),
|
|
312
|
+
wds.filters.map( lambda x: x['url'] )
|
|
299
313
|
)
|
|
300
314
|
return list( pipe )
|
|
301
315
|
|
|
@@ -317,23 +331,23 @@ class Dataset( Generic[ST] ):
|
|
|
317
331
|
|
|
318
332
|
if batch_size is None:
|
|
319
333
|
# TODO Duplication here
|
|
320
|
-
return wds.DataPipeline(
|
|
321
|
-
wds.SimpleShardList( self.url ),
|
|
322
|
-
wds.split_by_worker,
|
|
334
|
+
return wds.pipeline.DataPipeline(
|
|
335
|
+
wds.shardlists.SimpleShardList( self.url ),
|
|
336
|
+
wds.shardlists.split_by_worker,
|
|
323
337
|
#
|
|
324
|
-
wds.tarfile_to_samples(),
|
|
338
|
+
wds.tariterators.tarfile_to_samples(),
|
|
325
339
|
# wds.map( self.preprocess ),
|
|
326
|
-
wds.map( self.wrap ),
|
|
340
|
+
wds.filters.map( self.wrap ),
|
|
327
341
|
)
|
|
328
342
|
|
|
329
|
-
return wds.DataPipeline(
|
|
330
|
-
wds.SimpleShardList( self.url ),
|
|
331
|
-
wds.split_by_worker,
|
|
343
|
+
return wds.pipeline.DataPipeline(
|
|
344
|
+
wds.shardlists.SimpleShardList( self.url ),
|
|
345
|
+
wds.shardlists.split_by_worker,
|
|
332
346
|
#
|
|
333
|
-
wds.tarfile_to_samples(),
|
|
347
|
+
wds.tariterators.tarfile_to_samples(),
|
|
334
348
|
# wds.map( self.preprocess ),
|
|
335
|
-
wds.batched( batch_size ),
|
|
336
|
-
wds.map( self.wrap_batch ),
|
|
349
|
+
wds.filters.batched( batch_size ),
|
|
350
|
+
wds.filters.map( self.wrap_batch ),
|
|
337
351
|
)
|
|
338
352
|
|
|
339
353
|
def shuffled( self,
|
|
@@ -461,7 +475,11 @@ class Dataset( Generic[ST] ):
|
|
|
461
475
|
assert 'msgpack' in sample
|
|
462
476
|
assert type( sample['msgpack'] ) == bytes
|
|
463
477
|
|
|
464
|
-
|
|
478
|
+
if self._output_lens is None:
|
|
479
|
+
return self.sample_type.from_bytes( sample['msgpack'] )
|
|
480
|
+
|
|
481
|
+
source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
|
|
482
|
+
return self._output_lens( source_sample )
|
|
465
483
|
|
|
466
484
|
# try:
|
|
467
485
|
# assert type( sample ) == dict
|
atdata/lens.py
CHANGED
|
@@ -3,8 +3,6 @@
|
|
|
3
3
|
##
|
|
4
4
|
# Imports
|
|
5
5
|
|
|
6
|
-
from .dataset import PackableSample
|
|
7
|
-
|
|
8
6
|
import functools
|
|
9
7
|
import inspect
|
|
10
8
|
|
|
@@ -17,17 +15,22 @@ from typing import (
|
|
|
17
15
|
Callable,
|
|
18
16
|
Optional,
|
|
19
17
|
Generic,
|
|
18
|
+
#
|
|
19
|
+
TYPE_CHECKING
|
|
20
20
|
)
|
|
21
21
|
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from .dataset import PackableSample
|
|
24
|
+
|
|
22
25
|
|
|
23
26
|
##
|
|
24
27
|
# Typing helpers
|
|
25
28
|
|
|
26
|
-
DatasetType: TypeAlias = Type[PackableSample]
|
|
29
|
+
DatasetType: TypeAlias = Type['PackableSample']
|
|
27
30
|
LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
|
|
28
31
|
|
|
29
|
-
S = TypeVar( 'S', bound = PackableSample )
|
|
30
|
-
V = TypeVar( 'V', bound = PackableSample )
|
|
32
|
+
S = TypeVar( 'S', bound = 'PackableSample' )
|
|
33
|
+
V = TypeVar( 'V', bound = 'PackableSample' )
|
|
31
34
|
type LensGetter[S, V] = Callable[[S], V]
|
|
32
35
|
type LensPutter[S, V] = Callable[[V, S], S]
|
|
33
36
|
|
|
@@ -38,15 +41,37 @@ type LensPutter[S, V] = Callable[[V, S], S]
|
|
|
38
41
|
class Lens( Generic[S, V] ):
|
|
39
42
|
"""TODO"""
|
|
40
43
|
|
|
44
|
+
# @property
|
|
45
|
+
# def source_type( self ) -> Type[S]:
|
|
46
|
+
# """The source type (S) for the lens; what is put to"""
|
|
47
|
+
# # TODO Figure out why linting fails here
|
|
48
|
+
# return self.__orig_class__.__args__[0]
|
|
49
|
+
|
|
50
|
+
# @property
|
|
51
|
+
# def view_type( self ) -> Type[V]:
|
|
52
|
+
# """The view type (V) for the lens; what is get'd from"""
|
|
53
|
+
# # TODO FIgure out why linting fails here
|
|
54
|
+
# return self.__orig_class__.__args__[1]
|
|
55
|
+
|
|
41
56
|
def __init__( self, get: LensGetter[S, V],
|
|
42
57
|
put: Optional[LensPutter[S, V]] = None
|
|
43
58
|
) -> None:
|
|
44
59
|
"""TODO"""
|
|
45
60
|
##
|
|
46
61
|
|
|
47
|
-
#
|
|
62
|
+
# Check argument validity
|
|
63
|
+
|
|
64
|
+
sig = inspect.signature( get )
|
|
65
|
+
input_types = list( sig.parameters.values() )
|
|
66
|
+
assert len( input_types ) == 1, \
|
|
67
|
+
'Wrong number of input args for lens: should only have one'
|
|
68
|
+
|
|
69
|
+
# Update function details for this object as returned by annotation
|
|
48
70
|
functools.update_wrapper( self, get )
|
|
49
71
|
|
|
72
|
+
self.source_type: Type[PackableSample] = input_types[0].annotation
|
|
73
|
+
self.view_type = sig.return_annotation
|
|
74
|
+
|
|
50
75
|
# Store the getter
|
|
51
76
|
self._getter = get
|
|
52
77
|
|
|
@@ -56,21 +81,7 @@ class Lens( Generic[S, V] ):
|
|
|
56
81
|
def _trivial_put( v: V, s: S ) -> S:
|
|
57
82
|
return s
|
|
58
83
|
put = _trivial_put
|
|
59
|
-
|
|
60
84
|
self._putter = put
|
|
61
|
-
|
|
62
|
-
# Register this lens for this type signature
|
|
63
|
-
|
|
64
|
-
sig = inspect.signature( get )
|
|
65
|
-
input_types = list( sig.parameters.values() )
|
|
66
|
-
assert len( input_types ) == 1, \
|
|
67
|
-
'Wrong number of input args for lens: should only have one'
|
|
68
|
-
|
|
69
|
-
input_type = input_types[0].annotation
|
|
70
|
-
output_type = sig.return_annotation
|
|
71
|
-
|
|
72
|
-
_registered_lenses[(input_type, output_type)] = self
|
|
73
|
-
print( _registered_lenses )
|
|
74
85
|
|
|
75
86
|
#
|
|
76
87
|
|
|
@@ -80,6 +91,8 @@ class Lens( Generic[S, V] ):
|
|
|
80
91
|
self._putter = put
|
|
81
92
|
return put
|
|
82
93
|
|
|
94
|
+
# Methods to actually execute transformations
|
|
95
|
+
|
|
83
96
|
def put( self, v: V, s: S ) -> S:
|
|
84
97
|
"""TODO"""
|
|
85
98
|
return self._putter( v, s )
|
|
@@ -88,21 +101,86 @@ class Lens( Generic[S, V] ):
|
|
|
88
101
|
"""TODO"""
|
|
89
102
|
return self( s )
|
|
90
103
|
|
|
91
|
-
#
|
|
104
|
+
# Convenience to enable calling the lens as its getter
|
|
92
105
|
|
|
93
106
|
def __call__( self, s: S ) -> V:
|
|
94
107
|
return self._getter( s )
|
|
95
108
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
109
|
+
# TODO Figure out how to properly parameterize this
|
|
110
|
+
# def _lens_factory[S, V]( register: bool = True ):
|
|
111
|
+
# """Register the annotated function `f` as the getter of a sample lens"""
|
|
112
|
+
|
|
113
|
+
# # The actual lens decorator taking a lens getter function to a lens object
|
|
114
|
+
# def _decorator( f: LensGetter[S, V] ) -> Lens[S, V]:
|
|
115
|
+
# ret = Lens[S, V]( f )
|
|
116
|
+
# if register:
|
|
117
|
+
# _network.register( ret )
|
|
118
|
+
# return ret
|
|
119
|
+
|
|
120
|
+
# # Return the lens decorator
|
|
121
|
+
# return _decorator
|
|
122
|
+
|
|
123
|
+
# # For convenience
|
|
124
|
+
# lens = _lens_factory
|
|
125
|
+
|
|
126
|
+
def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
|
|
127
|
+
ret = Lens[S, V]( f )
|
|
128
|
+
_network.register( ret )
|
|
129
|
+
return ret
|
|
99
130
|
|
|
100
131
|
|
|
101
132
|
##
|
|
102
|
-
# Global
|
|
133
|
+
# Global registry of used lenses
|
|
134
|
+
|
|
135
|
+
# _registered_lenses: Dict[LensSignature, Lens] = dict()
|
|
136
|
+
# """TODO"""
|
|
137
|
+
|
|
138
|
+
class LensNetwork:
|
|
139
|
+
"""TODO"""
|
|
140
|
+
|
|
141
|
+
_instance = None
|
|
142
|
+
"""The singleton instance"""
|
|
143
|
+
|
|
144
|
+
def __new__(cls, *args, **kwargs):
|
|
145
|
+
if cls._instance is None:
|
|
146
|
+
# If no instance exists, create a new one
|
|
147
|
+
cls._instance = super().__new__(cls)
|
|
148
|
+
return cls._instance # Return the existing (or newly created) instance
|
|
149
|
+
|
|
150
|
+
def __init__(self):
|
|
151
|
+
if not hasattr(self, '_initialized'): # Check if already initialized
|
|
152
|
+
self._registry: Dict[LensSignature, Lens] = dict()
|
|
153
|
+
self._initialized = True
|
|
154
|
+
|
|
155
|
+
def register( self, _lens: Lens ):
|
|
156
|
+
"""Set `lens` as the canonical view between its source and view types"""
|
|
157
|
+
|
|
158
|
+
# sig = inspect.signature( _lens.get )
|
|
159
|
+
# input_types = list( sig.parameters.values() )
|
|
160
|
+
# assert len( input_types ) == 1, \
|
|
161
|
+
# 'Wrong number of input args for lens: should only have one'
|
|
162
|
+
|
|
163
|
+
# input_type = input_types[0].annotation
|
|
164
|
+
# print( input_type )
|
|
165
|
+
# output_type = sig.return_annotation
|
|
166
|
+
|
|
167
|
+
# self._registry[input_type, output_type] = _lens
|
|
168
|
+
print( _lens.source_type )
|
|
169
|
+
self._registry[_lens.source_type, _lens.view_type] = _lens
|
|
170
|
+
|
|
171
|
+
def transform( self, source: DatasetType, view: DatasetType ) -> Lens:
|
|
172
|
+
"""TODO"""
|
|
173
|
+
|
|
174
|
+
# TODO Handle compositional closure
|
|
175
|
+
ret = self._registry.get( (source, view), None )
|
|
176
|
+
if ret is None:
|
|
177
|
+
raise ValueError( f'No registered lens from source {source} to view {view}' )
|
|
178
|
+
|
|
179
|
+
return ret
|
|
180
|
+
|
|
103
181
|
|
|
104
|
-
|
|
105
|
-
|
|
182
|
+
# Create global singleton registry instance
|
|
183
|
+
_network = LensNetwork()
|
|
106
184
|
|
|
107
185
|
# def lens( f: LensPutter ) -> Lens:
|
|
108
186
|
# """Register the annotated function `f` as a sample lens"""
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
atdata/__init__.py,sha256=V2qBg7i2mfCNG9nww6Gi_fDp7iwolDMrNzhmNO6VA7M,233
|
|
2
|
+
atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
|
|
3
|
+
atdata/dataset.py,sha256=brNKGMkA_au2nLF5oUmjwub1E08DVwBKl9PnzPV6rPM,16722
|
|
4
|
+
atdata/lens.py,sha256=bGlxQ6PEnLj5poQ41DHj1LfpsmI5fELnjnUf4qXOsCo,5304
|
|
5
|
+
atdata-0.1.3a2.dist-info/METADATA,sha256=pYLvssfCzaiqer4yZQM3vh7nRlP8ZcEutXA8po7GZC0,529
|
|
6
|
+
atdata-0.1.3a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
atdata-0.1.3a2.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
|
|
8
|
+
atdata-0.1.3a2.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
|
|
9
|
+
atdata-0.1.3a2.dist-info/RECORD,,
|
atdata-0.1.2b1.dist-info/RECORD
DELETED
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
atdata/__init__.py,sha256=YnlohxQwTUK6V84XHm2gdeCQH5sIrTHVLSApB-nt_z8,216
|
|
2
|
-
atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
|
|
3
|
-
atdata/dataset.py,sha256=pBaND2D33JiJoiL9CtCTBa3octtifa21P19K076sW3Q,15905
|
|
4
|
-
atdata/lens.py,sha256=ikExMWdGP3QH-bEuUDNAYO_ZjeaKJTfL9lpaN9CrRB4,2624
|
|
5
|
-
atdata-0.1.2b1.dist-info/METADATA,sha256=WtHM3N0kMxJKwPXl5cluRNnvk49WU_e72lVhmcInraY,529
|
|
6
|
-
atdata-0.1.2b1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
atdata-0.1.2b1.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
|
|
8
|
-
atdata-0.1.2b1.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
|
|
9
|
-
atdata-0.1.2b1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|