atdata 0.1.1a3__tar.gz → 0.1.2a3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: atdata
3
- Version: 0.1.1a3
3
+ Version: 0.1.2a3
4
4
  Summary: A loose federation of distributed, typed datasets
5
5
  Author-email: Maxine Levesque <hello@maxine.science>
6
6
  License-File: LICENSE
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "atdata"
3
- version = "0.1.1a3"
3
+ version = "0.1.2a3"
4
4
  description = "A loose federation of distributed, typed datasets"
5
5
  readme = "README.md"
6
6
  authors = [
@@ -7,6 +7,7 @@ from .dataset import (
7
7
  PackableSample,
8
8
  SampleBatch,
9
9
  Dataset,
10
+ packable,
10
11
  )
11
12
 
12
13
 
@@ -5,6 +5,7 @@
5
5
 
6
6
  import webdataset as wds
7
7
 
8
+ import functools
8
9
  from dataclasses import dataclass
9
10
  import uuid
10
11
 
@@ -96,7 +97,8 @@ def _make_packable( x ):
96
97
  class PackableSample( ABC ):
97
98
  """A sample that can be packed and unpacked with msgpack"""
98
99
 
99
- def __post_init__( self ):
100
+ def _ensure_good( self ):
101
+ """TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""
100
102
 
101
103
  # Auto-convert known types when annotated
102
104
  for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
@@ -120,12 +122,17 @@ class PackableSample( ABC ):
120
122
  elif isinstance( var_cur_value, bytes ):
121
123
  setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
122
124
 
125
+ def __post_init__( self ):
126
+ self._ensure_good()
127
+
123
128
  ##
124
129
 
125
130
  @classmethod
126
131
  def from_data( cls, data: MsgpackRawSample ) -> Self:
127
132
  """Create a sample instance from unpacked msgpack data"""
128
- return cls( **data )
133
+ ret = cls( **data )
134
+ ret._ensure_good()
135
+ return ret
129
136
 
130
137
  @classmethod
131
138
  def from_bytes( cls, bs: bytes ) -> Self:
@@ -415,4 +422,42 @@ class Dataset( Generic[ST] ):
415
422
  # This default implementation simply creates a list one sample at a time
416
423
  # """
417
424
  # assert cls.batch_class is not None, 'No batch class specified'
418
- # return cls.batch_class( **batch )
425
+ # return cls.batch_class( **batch )
426
+
427
+
428
+ ##
429
+ # Shortcut decorators
430
+
431
+ # def packable( cls ):
432
+ # """TODO"""
433
+
434
+ # def decorator( cls ):
435
+ # # Create a new class dynamically
436
+ # # The new class inherits from the new_parent_class first, then the original cls
437
+ # new_bases = (PackableSample,) + cls.__bases__
438
+ # new_cls = type(cls.__name__, new_bases, dict(cls.__dict__))
439
+
440
+ # # Optionally, update __module__ and __qualname__ for better introspection
441
+ # new_cls.__module__ = cls.__module__
442
+ # new_cls.__qualname__ = cls.__qualname__
443
+
444
+ # return new_cls
445
+ # return decorator
446
+
447
+ def packable( cls ):
448
+ """TODO"""
449
+
450
+ ##
451
+
452
+ as_dataclass = dataclass( cls )
453
+
454
+ class as_packable( PackableSample, as_dataclass ):
455
+ def __post_init__( self ):
456
+ return PackableSample.__post_init__( self )
457
+
458
+ as_packable.__name__ = cls.__name__
459
+ as_packable.__annotations__ = cls.__annotations__
460
+
461
+ ##
462
+
463
+ return as_packable
@@ -39,6 +39,17 @@ class NumpyTestSample( atdata.PackableSample ):
39
39
  label: int
40
40
  image: NDArray
41
41
 
42
+ @atdata.packable
43
+ class BasicTestSampleDecorated:
44
+ name: str
45
+ position: int
46
+ value: float
47
+
48
+ @atdata.packable
49
+ class NumpyTestSampleDecorated:
50
+ label: int
51
+ image: NDArray
52
+
42
53
  test_cases = [
43
54
  {
44
55
  'SampleType': BasicTestSample,
@@ -51,13 +62,31 @@ test_cases = [
51
62
  },
52
63
  {
53
64
  'SampleType': NumpyTestSample,
54
- 'sample_data':
65
+ 'sample_data':
55
66
  {
56
67
  'label': 9_001,
57
68
  'image': np.random.randn( 1024, 1024 ),
58
69
  },
59
70
  'sample_wds_stem': 'numpy_test',
60
71
  },
72
+ {
73
+ 'SampleType': BasicTestSampleDecorated,
74
+ 'sample_data': {
75
+ 'name': 'Hello, world!',
76
+ 'position': 42,
77
+ 'value': 1024.768,
78
+ },
79
+ 'sample_wds_stem': 'basic_test_decorated',
80
+ },
81
+ {
82
+ 'SampleType': NumpyTestSampleDecorated,
83
+ 'sample_data':
84
+ {
85
+ 'label': 9_001,
86
+ 'image': np.random.randn( 1024, 1024 ),
87
+ },
88
+ 'sample_wds_stem': 'numpy_test_decorated',
89
+ },
61
90
  ]
62
91
 
63
92
 
@@ -89,6 +118,35 @@ def test_create_sample(
89
118
 
90
119
  #
91
120
 
121
+ # def test_decorator_syntax():
122
+ # """Test use of decorator syntax for sample types"""
123
+
124
+ # @atdata.packable
125
+ # class BasicTestSampleDecorated:
126
+ # name: str
127
+ # position: int
128
+ # value: float
129
+
130
+ # @atdata.packable
131
+ # class NumpyTestSampleDecorated:
132
+ # label: int
133
+ # image: NDArray
134
+
135
+ # ##
136
+
137
+ # test_create_sample( BasicTestSampleDecorated, {
138
+ # 'name': 'Hello, world!',
139
+ # 'position': 42,
140
+ # 'value': 1024.768,
141
+ # } )
142
+
143
+ # test_create_sample( NumpyTestSampleDecorated, {
144
+ # 'label': 9_001,
145
+ # 'image': np.random.randn( 1024, 1024 ),
146
+ # } )
147
+
148
+ #
149
+
92
150
  @pytest.mark.parametrize(
93
151
  ('SampleType', 'sample_data', 'sample_wds_stem'),
94
152
  [ (case['SampleType'], case['sample_data'], case['sample_wds_stem'])
@@ -109,7 +167,6 @@ def test_wds(
109
167
  batch_size = 4
110
168
  n_iterate = 10
111
169
 
112
-
113
170
  ## Write sharded dataset
114
171
 
115
172
  file_pattern = (
@@ -140,7 +197,7 @@ def test_wds(
140
197
 
141
198
  iterations_run = 0
142
199
  for i_iterate, cur_sample in enumerate( dataset.ordered( batch_size = None ) ):
143
-
200
+
144
201
  assert isinstance( cur_sample, SampleType ), \
145
202
  f'Single sample for {SampleType} written to `wds` is of wrong type'
146
203
 
@@ -152,7 +209,7 @@ def test_wds(
152
209
  else:
153
210
  is_correct = getattr( cur_sample, k ) == v
154
211
  assert is_correct, \
155
- f'{SampleType}: Incorrect sample value found for {k}'
212
+ f'{SampleType}: Incorrect sample value found for {k} - {type( getattr( cur_sample, k ) )}'
156
213
 
157
214
  iterations_run += 1
158
215
  if iterations_run >= n_iterate:
@@ -166,7 +223,6 @@ def test_wds(
166
223
  start_id = f'{0:06d}'
167
224
  end_id = f'{9:06d}'
168
225
  first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' )
169
- print( first_filename )
170
226
  dataset = atdata.Dataset[SampleType]( first_filename )
171
227
 
172
228
  iterations_run = 0
@@ -241,7 +297,6 @@ def test_wds(
241
297
  start_id = f'{0:06d}'
242
298
  end_id = f'{9:06d}'
243
299
  first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' )
244
- print( first_filename )
245
300
  dataset = atdata.Dataset[SampleType]( first_filename )
246
301
 
247
302
  iterations_run = 0
File without changes
File without changes
File without changes
File without changes