atdata 0.1.2a1__tar.gz → 0.1.2a4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: atdata
3
- Version: 0.1.2a1
3
+ Version: 0.1.2a4
4
4
  Summary: A loose federation of distributed, typed datasets
5
5
  Author-email: Maxine Levesque <hello@maxine.science>
6
6
  License-File: LICENSE
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "atdata"
3
- version = "0.1.2a1"
3
+ version = "0.1.2a4"
4
4
  description = "A loose federation of distributed, typed datasets"
5
5
  readme = "README.md"
6
6
  authors = [
@@ -97,7 +97,8 @@ def _make_packable( x ):
97
97
  class PackableSample( ABC ):
98
98
  """A sample that can be packed and unpacked with msgpack"""
99
99
 
100
- def __post_init__( self ):
100
+ def _ensure_good( self ):
101
+ """TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""
101
102
 
102
103
  # Auto-convert known types when annotated
103
104
  for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
@@ -121,12 +122,17 @@ class PackableSample( ABC ):
121
122
  elif isinstance( var_cur_value, bytes ):
122
123
  setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
123
124
 
125
+ def __post_init__( self ):
126
+ self._ensure_good()
127
+
124
128
  ##
125
129
 
126
130
  @classmethod
127
131
  def from_data( cls, data: MsgpackRawSample ) -> Self:
128
132
  """Create a sample instance from unpacked msgpack data"""
129
- return cls( **data )
133
+ ret = cls( **data )
134
+ ret._ensure_good()
135
+ return ret
130
136
 
131
137
  @classmethod
132
138
  def from_bytes( cls, bs: bytes ) -> Self:
@@ -443,12 +449,17 @@ def packable( cls ):
443
449
 
444
450
  ##
445
451
 
452
+ # Add in dataclass niceness to original class
446
453
  as_dataclass = dataclass( cls )
447
454
 
455
+ # This triggers a bunch of behind-the-scenes stuff for the newly annotated class
456
+ @dataclass
448
457
  class as_packable( as_dataclass, PackableSample ):
449
- pass
458
+ def __post_init__( self ):
459
+ return PackableSample.__post_init__( self )
450
460
 
451
461
  as_packable.__name__ = cls.__name__
462
+ as_packable.__annotations__ = cls.__annotations__
452
463
 
453
464
  ##
454
465
 
@@ -39,6 +39,17 @@ class NumpyTestSample( atdata.PackableSample ):
39
39
  label: int
40
40
  image: NDArray
41
41
 
42
+ @atdata.packable
43
+ class BasicTestSampleDecorated:
44
+ name: str
45
+ position: int
46
+ value: float
47
+
48
+ @atdata.packable
49
+ class NumpyTestSampleDecorated:
50
+ label: int
51
+ image: NDArray
52
+
42
53
  test_cases = [
43
54
  {
44
55
  'SampleType': BasicTestSample,
@@ -58,6 +69,24 @@ test_cases = [
58
69
  },
59
70
  'sample_wds_stem': 'numpy_test',
60
71
  },
72
+ {
73
+ 'SampleType': BasicTestSampleDecorated,
74
+ 'sample_data': {
75
+ 'name': 'Hello, world!',
76
+ 'position': 42,
77
+ 'value': 1024.768,
78
+ },
79
+ 'sample_wds_stem': 'basic_test_decorated',
80
+ },
81
+ {
82
+ 'SampleType': NumpyTestSampleDecorated,
83
+ 'sample_data':
84
+ {
85
+ 'label': 9_001,
86
+ 'image': np.random.randn( 1024, 1024 ),
87
+ },
88
+ 'sample_wds_stem': 'numpy_test_decorated',
89
+ },
61
90
  ]
62
91
 
63
92
 
@@ -89,32 +118,32 @@ def test_create_sample(
89
118
 
90
119
  #
91
120
 
92
- def test_decorator_syntax():
93
- """Test use of decorator syntax for sample types"""
121
+ # def test_decorator_syntax():
122
+ # """Test use of decorator syntax for sample types"""
94
123
 
95
- @atdata.packable
96
- class BasicTestSampleDecorated:
97
- name: str
98
- position: int
99
- value: float
100
-
101
- @atdata.packable
102
- class NumpyTestSampleDecorated:
103
- label: int
104
- image: NDArray
124
+ # @atdata.packable
125
+ # class BasicTestSampleDecorated:
126
+ # name: str
127
+ # position: int
128
+ # value: float
129
+
130
+ # @atdata.packable
131
+ # class NumpyTestSampleDecorated:
132
+ # label: int
133
+ # image: NDArray
105
134
 
106
- ##
135
+ # ##
107
136
 
108
- test_create_sample( BasicTestSampleDecorated, {
109
- 'name': 'Hello, world!',
110
- 'position': 42,
111
- 'value': 1024.768,
112
- } )
137
+ # test_create_sample( BasicTestSampleDecorated, {
138
+ # 'name': 'Hello, world!',
139
+ # 'position': 42,
140
+ # 'value': 1024.768,
141
+ # } )
113
142
 
114
- test_create_sample( NumpyTestSampleDecorated, {
115
- 'label': 9_001,
116
- 'image': np.random.randn( 1024, 1024 ),
117
- } )
143
+ # test_create_sample( NumpyTestSampleDecorated, {
144
+ # 'label': 9_001,
145
+ # 'image': np.random.randn( 1024, 1024 ),
146
+ # } )
118
147
 
119
148
  #
120
149
 
@@ -138,7 +167,6 @@ def test_wds(
138
167
  batch_size = 4
139
168
  n_iterate = 10
140
169
 
141
-
142
170
  ## Write sharded dataset
143
171
 
144
172
  file_pattern = (
@@ -169,7 +197,7 @@ def test_wds(
169
197
 
170
198
  iterations_run = 0
171
199
  for i_iterate, cur_sample in enumerate( dataset.ordered( batch_size = None ) ):
172
-
200
+
173
201
  assert isinstance( cur_sample, SampleType ), \
174
202
  f'Single sample for {SampleType} written to `wds` is of wrong type'
175
203
 
@@ -181,7 +209,7 @@ def test_wds(
181
209
  else:
182
210
  is_correct = getattr( cur_sample, k ) == v
183
211
  assert is_correct, \
184
- f'{SampleType}: Incorrect sample value found for {k}'
212
+ f'{SampleType}: Incorrect sample value found for {k} - {type( getattr( cur_sample, k ) )}'
185
213
 
186
214
  iterations_run += 1
187
215
  if iterations_run >= n_iterate:
@@ -195,7 +223,6 @@ def test_wds(
195
223
  start_id = f'{0:06d}'
196
224
  end_id = f'{9:06d}'
197
225
  first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' )
198
- print( first_filename )
199
226
  dataset = atdata.Dataset[SampleType]( first_filename )
200
227
 
201
228
  iterations_run = 0
@@ -270,7 +297,6 @@ def test_wds(
270
297
  start_id = f'{0:06d}'
271
298
  end_id = f'{9:06d}'
272
299
  first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' )
273
- print( first_filename )
274
300
  dataset = atdata.Dataset[SampleType]( first_filename )
275
301
 
276
302
  iterations_run = 0
File without changes
File without changes
File without changes
File without changes