atdata 0.1.1a3__tar.gz → 0.1.2a3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atdata-0.1.1a3 → atdata-0.1.2a3}/PKG-INFO +1 -1
- {atdata-0.1.1a3 → atdata-0.1.2a3}/pyproject.toml +1 -1
- {atdata-0.1.1a3 → atdata-0.1.2a3}/src/atdata/__init__.py +1 -0
- {atdata-0.1.1a3 → atdata-0.1.2a3}/src/atdata/dataset.py +48 -3
- {atdata-0.1.1a3 → atdata-0.1.2a3}/tests/test_dataset.py +61 -6
- {atdata-0.1.1a3 → atdata-0.1.2a3}/.github/workflows/uv-publish-pypi.yml +0 -0
- {atdata-0.1.1a3 → atdata-0.1.2a3}/.github/workflows/uv-test.yml +0 -0
- {atdata-0.1.1a3 → atdata-0.1.2a3}/.gitignore +0 -0
- {atdata-0.1.1a3 → atdata-0.1.2a3}/.python-version +0 -0
- {atdata-0.1.1a3 → atdata-0.1.2a3}/LICENSE +0 -0
- {atdata-0.1.1a3 → atdata-0.1.2a3}/README.md +0 -0
- {atdata-0.1.1a3 → atdata-0.1.2a3}/src/atdata/_helpers.py +0 -0
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
|
|
6
6
|
import webdataset as wds
|
|
7
7
|
|
|
8
|
+
import functools
|
|
8
9
|
from dataclasses import dataclass
|
|
9
10
|
import uuid
|
|
10
11
|
|
|
@@ -96,7 +97,8 @@ def _make_packable( x ):
|
|
|
96
97
|
class PackableSample( ABC ):
|
|
97
98
|
"""A sample that can be packed and unpacked with msgpack"""
|
|
98
99
|
|
|
99
|
-
def
|
|
100
|
+
def _ensure_good( self ):
|
|
101
|
+
"""TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""
|
|
100
102
|
|
|
101
103
|
# Auto-convert known types when annotated
|
|
102
104
|
for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
|
|
@@ -120,12 +122,17 @@ class PackableSample( ABC ):
|
|
|
120
122
|
elif isinstance( var_cur_value, bytes ):
|
|
121
123
|
setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
|
|
122
124
|
|
|
125
|
+
def __post_init__( self ):
|
|
126
|
+
self._ensure_good()
|
|
127
|
+
|
|
123
128
|
##
|
|
124
129
|
|
|
125
130
|
@classmethod
|
|
126
131
|
def from_data( cls, data: MsgpackRawSample ) -> Self:
|
|
127
132
|
"""Create a sample instance from unpacked msgpack data"""
|
|
128
|
-
|
|
133
|
+
ret = cls( **data )
|
|
134
|
+
ret._ensure_good()
|
|
135
|
+
return ret
|
|
129
136
|
|
|
130
137
|
@classmethod
|
|
131
138
|
def from_bytes( cls, bs: bytes ) -> Self:
|
|
@@ -415,4 +422,42 @@ class Dataset( Generic[ST] ):
|
|
|
415
422
|
# This default implementation simply creates a list one sample at a time
|
|
416
423
|
# """
|
|
417
424
|
# assert cls.batch_class is not None, 'No batch class specified'
|
|
418
|
-
# return cls.batch_class( **batch )
|
|
425
|
+
# return cls.batch_class( **batch )
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
##
|
|
429
|
+
# Shortcut decorators
|
|
430
|
+
|
|
431
|
+
# def packable( cls ):
|
|
432
|
+
# """TODO"""
|
|
433
|
+
|
|
434
|
+
# def decorator( cls ):
|
|
435
|
+
# # Create a new class dynamically
|
|
436
|
+
# # The new class inherits from the new_parent_class first, then the original cls
|
|
437
|
+
# new_bases = (PackableSample,) + cls.__bases__
|
|
438
|
+
# new_cls = type(cls.__name__, new_bases, dict(cls.__dict__))
|
|
439
|
+
|
|
440
|
+
# # Optionally, update __module__ and __qualname__ for better introspection
|
|
441
|
+
# new_cls.__module__ = cls.__module__
|
|
442
|
+
# new_cls.__qualname__ = cls.__qualname__
|
|
443
|
+
|
|
444
|
+
# return new_cls
|
|
445
|
+
# return decorator
|
|
446
|
+
|
|
447
|
+
def packable( cls ):
|
|
448
|
+
"""TODO"""
|
|
449
|
+
|
|
450
|
+
##
|
|
451
|
+
|
|
452
|
+
as_dataclass = dataclass( cls )
|
|
453
|
+
|
|
454
|
+
class as_packable( PackableSample, as_dataclass ):
|
|
455
|
+
def __post_init__( self ):
|
|
456
|
+
return PackableSample.__post_init__( self )
|
|
457
|
+
|
|
458
|
+
as_packable.__name__ = cls.__name__
|
|
459
|
+
as_packable.__annotations__ = cls.__annotations__
|
|
460
|
+
|
|
461
|
+
##
|
|
462
|
+
|
|
463
|
+
return as_packable
|
|
@@ -39,6 +39,17 @@ class NumpyTestSample( atdata.PackableSample ):
|
|
|
39
39
|
label: int
|
|
40
40
|
image: NDArray
|
|
41
41
|
|
|
42
|
+
@atdata.packable
|
|
43
|
+
class BasicTestSampleDecorated:
|
|
44
|
+
name: str
|
|
45
|
+
position: int
|
|
46
|
+
value: float
|
|
47
|
+
|
|
48
|
+
@atdata.packable
|
|
49
|
+
class NumpyTestSampleDecorated:
|
|
50
|
+
label: int
|
|
51
|
+
image: NDArray
|
|
52
|
+
|
|
42
53
|
test_cases = [
|
|
43
54
|
{
|
|
44
55
|
'SampleType': BasicTestSample,
|
|
@@ -51,13 +62,31 @@ test_cases = [
|
|
|
51
62
|
},
|
|
52
63
|
{
|
|
53
64
|
'SampleType': NumpyTestSample,
|
|
54
|
-
'sample_data':
|
|
65
|
+
'sample_data':
|
|
55
66
|
{
|
|
56
67
|
'label': 9_001,
|
|
57
68
|
'image': np.random.randn( 1024, 1024 ),
|
|
58
69
|
},
|
|
59
70
|
'sample_wds_stem': 'numpy_test',
|
|
60
71
|
},
|
|
72
|
+
{
|
|
73
|
+
'SampleType': BasicTestSampleDecorated,
|
|
74
|
+
'sample_data': {
|
|
75
|
+
'name': 'Hello, world!',
|
|
76
|
+
'position': 42,
|
|
77
|
+
'value': 1024.768,
|
|
78
|
+
},
|
|
79
|
+
'sample_wds_stem': 'basic_test_decorated',
|
|
80
|
+
},
|
|
81
|
+
{
|
|
82
|
+
'SampleType': NumpyTestSampleDecorated,
|
|
83
|
+
'sample_data':
|
|
84
|
+
{
|
|
85
|
+
'label': 9_001,
|
|
86
|
+
'image': np.random.randn( 1024, 1024 ),
|
|
87
|
+
},
|
|
88
|
+
'sample_wds_stem': 'numpy_test_decorated',
|
|
89
|
+
},
|
|
61
90
|
]
|
|
62
91
|
|
|
63
92
|
|
|
@@ -89,6 +118,35 @@ def test_create_sample(
|
|
|
89
118
|
|
|
90
119
|
#
|
|
91
120
|
|
|
121
|
+
# def test_decorator_syntax():
|
|
122
|
+
# """Test use of decorator syntax for sample types"""
|
|
123
|
+
|
|
124
|
+
# @atdata.packable
|
|
125
|
+
# class BasicTestSampleDecorated:
|
|
126
|
+
# name: str
|
|
127
|
+
# position: int
|
|
128
|
+
# value: float
|
|
129
|
+
|
|
130
|
+
# @atdata.packable
|
|
131
|
+
# class NumpyTestSampleDecorated:
|
|
132
|
+
# label: int
|
|
133
|
+
# image: NDArray
|
|
134
|
+
|
|
135
|
+
# ##
|
|
136
|
+
|
|
137
|
+
# test_create_sample( BasicTestSampleDecorated, {
|
|
138
|
+
# 'name': 'Hello, world!',
|
|
139
|
+
# 'position': 42,
|
|
140
|
+
# 'value': 1024.768,
|
|
141
|
+
# } )
|
|
142
|
+
|
|
143
|
+
# test_create_sample( NumpyTestSampleDecorated, {
|
|
144
|
+
# 'label': 9_001,
|
|
145
|
+
# 'image': np.random.randn( 1024, 1024 ),
|
|
146
|
+
# } )
|
|
147
|
+
|
|
148
|
+
#
|
|
149
|
+
|
|
92
150
|
@pytest.mark.parametrize(
|
|
93
151
|
('SampleType', 'sample_data', 'sample_wds_stem'),
|
|
94
152
|
[ (case['SampleType'], case['sample_data'], case['sample_wds_stem'])
|
|
@@ -109,7 +167,6 @@ def test_wds(
|
|
|
109
167
|
batch_size = 4
|
|
110
168
|
n_iterate = 10
|
|
111
169
|
|
|
112
|
-
|
|
113
170
|
## Write sharded dataset
|
|
114
171
|
|
|
115
172
|
file_pattern = (
|
|
@@ -140,7 +197,7 @@ def test_wds(
|
|
|
140
197
|
|
|
141
198
|
iterations_run = 0
|
|
142
199
|
for i_iterate, cur_sample in enumerate( dataset.ordered( batch_size = None ) ):
|
|
143
|
-
|
|
200
|
+
|
|
144
201
|
assert isinstance( cur_sample, SampleType ), \
|
|
145
202
|
f'Single sample for {SampleType} written to `wds` is of wrong type'
|
|
146
203
|
|
|
@@ -152,7 +209,7 @@ def test_wds(
|
|
|
152
209
|
else:
|
|
153
210
|
is_correct = getattr( cur_sample, k ) == v
|
|
154
211
|
assert is_correct, \
|
|
155
|
-
f'{SampleType}: Incorrect sample value found for {k}'
|
|
212
|
+
f'{SampleType}: Incorrect sample value found for {k} - {type( getattr( cur_sample, k ) )}'
|
|
156
213
|
|
|
157
214
|
iterations_run += 1
|
|
158
215
|
if iterations_run >= n_iterate:
|
|
@@ -166,7 +223,6 @@ def test_wds(
|
|
|
166
223
|
start_id = f'{0:06d}'
|
|
167
224
|
end_id = f'{9:06d}'
|
|
168
225
|
first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' )
|
|
169
|
-
print( first_filename )
|
|
170
226
|
dataset = atdata.Dataset[SampleType]( first_filename )
|
|
171
227
|
|
|
172
228
|
iterations_run = 0
|
|
@@ -241,7 +297,6 @@ def test_wds(
|
|
|
241
297
|
start_id = f'{0:06d}'
|
|
242
298
|
end_id = f'{9:06d}'
|
|
243
299
|
first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' )
|
|
244
|
-
print( first_filename )
|
|
245
300
|
dataset = atdata.Dataset[SampleType]( first_filename )
|
|
246
301
|
|
|
247
302
|
iterations_run = 0
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|