atdata 0.1.2a1__tar.gz → 0.1.2a4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atdata-0.1.2a1 → atdata-0.1.2a4}/PKG-INFO +1 -1
- {atdata-0.1.2a1 → atdata-0.1.2a4}/pyproject.toml +1 -1
- {atdata-0.1.2a1 → atdata-0.1.2a4}/src/atdata/dataset.py +14 -3
- {atdata-0.1.2a1 → atdata-0.1.2a4}/tests/test_dataset.py +53 -27
- {atdata-0.1.2a1 → atdata-0.1.2a4}/.github/workflows/uv-publish-pypi.yml +0 -0
- {atdata-0.1.2a1 → atdata-0.1.2a4}/.github/workflows/uv-test.yml +0 -0
- {atdata-0.1.2a1 → atdata-0.1.2a4}/.gitignore +0 -0
- {atdata-0.1.2a1 → atdata-0.1.2a4}/.python-version +0 -0
- {atdata-0.1.2a1 → atdata-0.1.2a4}/LICENSE +0 -0
- {atdata-0.1.2a1 → atdata-0.1.2a4}/README.md +0 -0
- {atdata-0.1.2a1 → atdata-0.1.2a4}/src/atdata/__init__.py +0 -0
- {atdata-0.1.2a1 → atdata-0.1.2a4}/src/atdata/_helpers.py +0 -0
|
@@ -97,7 +97,8 @@ def _make_packable( x ):
|
|
|
97
97
|
class PackableSample( ABC ):
|
|
98
98
|
"""A sample that can be packed and unpacked with msgpack"""
|
|
99
99
|
|
|
100
|
-
def
|
|
100
|
+
def _ensure_good( self ):
|
|
101
|
+
"""TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""
|
|
101
102
|
|
|
102
103
|
# Auto-convert known types when annotated
|
|
103
104
|
for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
|
|
@@ -121,12 +122,17 @@ class PackableSample( ABC ):
|
|
|
121
122
|
elif isinstance( var_cur_value, bytes ):
|
|
122
123
|
setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
|
|
123
124
|
|
|
125
|
+
def __post_init__( self ):
|
|
126
|
+
self._ensure_good()
|
|
127
|
+
|
|
124
128
|
##
|
|
125
129
|
|
|
126
130
|
@classmethod
|
|
127
131
|
def from_data( cls, data: MsgpackRawSample ) -> Self:
|
|
128
132
|
"""Create a sample instance from unpacked msgpack data"""
|
|
129
|
-
|
|
133
|
+
ret = cls( **data )
|
|
134
|
+
ret._ensure_good()
|
|
135
|
+
return ret
|
|
130
136
|
|
|
131
137
|
@classmethod
|
|
132
138
|
def from_bytes( cls, bs: bytes ) -> Self:
|
|
@@ -443,12 +449,17 @@ def packable( cls ):
|
|
|
443
449
|
|
|
444
450
|
##
|
|
445
451
|
|
|
452
|
+
# Add in dataclass niceness to original class
|
|
446
453
|
as_dataclass = dataclass( cls )
|
|
447
454
|
|
|
455
|
+
# This triggers a bunch of behind-the-scenes stuff for the newly annotated class
|
|
456
|
+
@dataclass
|
|
448
457
|
class as_packable( as_dataclass, PackableSample ):
|
|
449
|
-
|
|
458
|
+
def __post_init__( self ):
|
|
459
|
+
return PackableSample.__post_init__( self )
|
|
450
460
|
|
|
451
461
|
as_packable.__name__ = cls.__name__
|
|
462
|
+
as_packable.__annotations__ = cls.__annotations__
|
|
452
463
|
|
|
453
464
|
##
|
|
454
465
|
|
|
@@ -39,6 +39,17 @@ class NumpyTestSample( atdata.PackableSample ):
|
|
|
39
39
|
label: int
|
|
40
40
|
image: NDArray
|
|
41
41
|
|
|
42
|
+
@atdata.packable
|
|
43
|
+
class BasicTestSampleDecorated:
|
|
44
|
+
name: str
|
|
45
|
+
position: int
|
|
46
|
+
value: float
|
|
47
|
+
|
|
48
|
+
@atdata.packable
|
|
49
|
+
class NumpyTestSampleDecorated:
|
|
50
|
+
label: int
|
|
51
|
+
image: NDArray
|
|
52
|
+
|
|
42
53
|
test_cases = [
|
|
43
54
|
{
|
|
44
55
|
'SampleType': BasicTestSample,
|
|
@@ -58,6 +69,24 @@ test_cases = [
|
|
|
58
69
|
},
|
|
59
70
|
'sample_wds_stem': 'numpy_test',
|
|
60
71
|
},
|
|
72
|
+
{
|
|
73
|
+
'SampleType': BasicTestSampleDecorated,
|
|
74
|
+
'sample_data': {
|
|
75
|
+
'name': 'Hello, world!',
|
|
76
|
+
'position': 42,
|
|
77
|
+
'value': 1024.768,
|
|
78
|
+
},
|
|
79
|
+
'sample_wds_stem': 'basic_test_decorated',
|
|
80
|
+
},
|
|
81
|
+
{
|
|
82
|
+
'SampleType': NumpyTestSampleDecorated,
|
|
83
|
+
'sample_data':
|
|
84
|
+
{
|
|
85
|
+
'label': 9_001,
|
|
86
|
+
'image': np.random.randn( 1024, 1024 ),
|
|
87
|
+
},
|
|
88
|
+
'sample_wds_stem': 'numpy_test_decorated',
|
|
89
|
+
},
|
|
61
90
|
]
|
|
62
91
|
|
|
63
92
|
|
|
@@ -89,32 +118,32 @@ def test_create_sample(
|
|
|
89
118
|
|
|
90
119
|
#
|
|
91
120
|
|
|
92
|
-
def test_decorator_syntax():
|
|
93
|
-
|
|
121
|
+
# def test_decorator_syntax():
|
|
122
|
+
# """Test use of decorator syntax for sample types"""
|
|
94
123
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
124
|
+
# @atdata.packable
|
|
125
|
+
# class BasicTestSampleDecorated:
|
|
126
|
+
# name: str
|
|
127
|
+
# position: int
|
|
128
|
+
# value: float
|
|
129
|
+
|
|
130
|
+
# @atdata.packable
|
|
131
|
+
# class NumpyTestSampleDecorated:
|
|
132
|
+
# label: int
|
|
133
|
+
# image: NDArray
|
|
105
134
|
|
|
106
|
-
|
|
135
|
+
# ##
|
|
107
136
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
137
|
+
# test_create_sample( BasicTestSampleDecorated, {
|
|
138
|
+
# 'name': 'Hello, world!',
|
|
139
|
+
# 'position': 42,
|
|
140
|
+
# 'value': 1024.768,
|
|
141
|
+
# } )
|
|
113
142
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
143
|
+
# test_create_sample( NumpyTestSampleDecorated, {
|
|
144
|
+
# 'label': 9_001,
|
|
145
|
+
# 'image': np.random.randn( 1024, 1024 ),
|
|
146
|
+
# } )
|
|
118
147
|
|
|
119
148
|
#
|
|
120
149
|
|
|
@@ -138,7 +167,6 @@ def test_wds(
|
|
|
138
167
|
batch_size = 4
|
|
139
168
|
n_iterate = 10
|
|
140
169
|
|
|
141
|
-
|
|
142
170
|
## Write sharded dataset
|
|
143
171
|
|
|
144
172
|
file_pattern = (
|
|
@@ -169,7 +197,7 @@ def test_wds(
|
|
|
169
197
|
|
|
170
198
|
iterations_run = 0
|
|
171
199
|
for i_iterate, cur_sample in enumerate( dataset.ordered( batch_size = None ) ):
|
|
172
|
-
|
|
200
|
+
|
|
173
201
|
assert isinstance( cur_sample, SampleType ), \
|
|
174
202
|
f'Single sample for {SampleType} written to `wds` is of wrong type'
|
|
175
203
|
|
|
@@ -181,7 +209,7 @@ def test_wds(
|
|
|
181
209
|
else:
|
|
182
210
|
is_correct = getattr( cur_sample, k ) == v
|
|
183
211
|
assert is_correct, \
|
|
184
|
-
f'{SampleType}: Incorrect sample value found for {k}'
|
|
212
|
+
f'{SampleType}: Incorrect sample value found for {k} - {type( getattr( cur_sample, k ) )}'
|
|
185
213
|
|
|
186
214
|
iterations_run += 1
|
|
187
215
|
if iterations_run >= n_iterate:
|
|
@@ -195,7 +223,6 @@ def test_wds(
|
|
|
195
223
|
start_id = f'{0:06d}'
|
|
196
224
|
end_id = f'{9:06d}'
|
|
197
225
|
first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' )
|
|
198
|
-
print( first_filename )
|
|
199
226
|
dataset = atdata.Dataset[SampleType]( first_filename )
|
|
200
227
|
|
|
201
228
|
iterations_run = 0
|
|
@@ -270,7 +297,6 @@ def test_wds(
|
|
|
270
297
|
start_id = f'{0:06d}'
|
|
271
298
|
end_id = f'{9:06d}'
|
|
272
299
|
first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' )
|
|
273
|
-
print( first_filename )
|
|
274
300
|
dataset = atdata.Dataset[SampleType]( first_filename )
|
|
275
301
|
|
|
276
302
|
iterations_run = 0
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|