atdata 0.1.3a3__tar.gz → 0.1.3b3__tar.gz

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: atdata
- Version: 0.1.3a3
+ Version: 0.1.3b3
  Summary: A loose federation of distributed, typed datasets
  Author-email: Maxine Levesque <hello@maxine.science>
  License-File: LICENSE
@@ -1,6 +1,6 @@
  [project]
  name = "atdata"
- version = "0.1.3a3"
+ version = "0.1.3b3"
  description = "A loose federation of distributed, typed datasets"
  readme = "README.md"
  authors = [
@@ -8,6 +8,9 @@ import webdataset as wds
  from pathlib import Path
  import uuid
  import functools
+
+ import dataclasses
+ import types
  from dataclasses import (
  dataclass,
  asdict,
@@ -21,6 +24,7 @@ from tqdm import tqdm
  import numpy as np
  import pandas as pd

+ import typing
  from typing import (
  Any,
  Optional,
@@ -28,6 +32,7 @@ from typing import (
  Sequence,
  Iterable,
  Callable,
+ Union,
  #
  Self,
  Generic,
@@ -108,6 +113,24 @@ def _make_packable( x ):
  return eh.array_to_bytes( x )
  return x

+ def _is_possibly_ndarray_type( t ):
+ """Checks if a type annotation is possibly an NDArray."""
+
+ # Directly an NDArray
+ if t == NDArray:
+ # print( 'is an NDArray' )
+ return True
+
+ # Check for Optionals (i.e., NDArray | None)
+ if isinstance( t, types.UnionType ):
+ t_parts = t.__args__
+ if any( x == NDArray
+ for x in t_parts ):
+ return True
+
+ # Not an NDArray
+ return False
+
  @dataclass
  class PackableSample( ABC ):
  """A sample that can be packed and unpacked with msgpack"""
@@ -116,10 +139,13 @@ class PackableSample( ABC ):
  """TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""

  # Auto-convert known types when annotated
- for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
+ # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
+ for field in dataclasses.fields( self ):
+ var_name = field.name
+ var_type = field.type

  # Annotation for this variable is to be an NDArray
- if var_type == NDArray:
+ if _is_possibly_ndarray_type( var_type ):
  # ... so, we'll always auto-convert to numpy

  var_cur_value = getattr( self, var_name )
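
A note on the new _is_possibly_ndarray_type helper used above: it accepts both a bare NDArray annotation and a PEP 604 union such as NDArray | None, which on Python 3.10+ is a types.UnionType whose __args__ tuple lists the union members. A standalone sketch of the same check, assuming NDArray resolves to numpy.typing.NDArray (the module's actual import is not visible in this diff):

    import types
    from numpy.typing import NDArray

    def is_possibly_ndarray_type( t ):
        # Bare annotation, e.g. image: NDArray
        if t == NDArray:
            return True
        # PEP 604 union, e.g. embeddings: NDArray | None (Python 3.10+)
        if isinstance( t, types.UnionType ):
            return any( x == NDArray for x in t.__args__ )
        return False

    assert is_possibly_ndarray_type( NDArray )
    assert is_possibly_ndarray_type( NDArray | None )
    assert not is_possibly_ndarray_type( int )
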
@@ -135,6 +161,9 @@ class PackableSample( ABC ):
  # setattr( self, var_name, var_cur_value.to_numpy )

  elif isinstance( var_cur_value, bytes ):
+ # TODO This does create a constraint that serialized bytes
+ # in a field that might be an NDArray are always interpreted
+ # as being the NDArray interpretation
  setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )

  def __post_init__( self ):
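
The hunks above also switch field discovery from vars( self.__class__ )['__annotations__'] to dataclasses.fields( self ). A small sketch (illustrative class names only) of one observable difference: fields() also reports fields inherited from a dataclass base, while a class's own __annotations__ dict does not.

    import dataclasses
    from dataclasses import dataclass

    @dataclass
    class Base:
        label: int

    @dataclass
    class Child( Base ):
        score: float

    print( list( vars( Child )['__annotations__'] ) )          # ['score']
    print( [ f.name for f in dataclasses.fields( Child ) ] )   # ['label', 'score']
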
@@ -204,7 +233,7 @@ class SampleBatch( Generic[DT] ):
  @property
  def sample_type( self ) -> Type:
  """The type of each sample in this batch"""
- return self.__orig_class__.__args__[0]
+ return typing.get_args( self.__orig_class__)[0]

  def __getattr__( self, name ):
  # Aggregate named params of sample type
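
Here and in the Dataset accessor below, the type parameter is now read with typing.get_args() instead of reaching into the private __args__ attribute of __orig_class__. A minimal sketch of the mechanism, using a hypothetical Batch class rather than atdata's own:

    import typing
    from typing import Generic, TypeVar

    T = TypeVar( 'T' )

    class Batch( Generic[T] ):
        @property
        def sample_type( self ):
            # typing sets __orig_class__ on instances created from a
            # parametrized alias such as Batch[int]()
            return typing.get_args( self.__orig_class__ )[0]

    assert Batch[int]().sample_type is int
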
@@ -253,7 +282,7 @@ class Dataset( Generic[ST] ):
  def sample_type( self ) -> Type:
  """The type of each returned sample from this `Dataset`'s iterator"""
  # TODO Figure out why linting fails here
- return self.__orig_class__.__args__[0]
+ return typing.get_args( self.__orig_class__ )[0]
  @property
  def batch_type( self ) -> Type:
  """The type of a batch built from `sample_class`"""
@@ -371,29 +400,29 @@ class Dataset( Generic[ST] ):

  if batch_size is None:
  # TODO Duplication here
- return wds.DataPipeline(
- wds.SimpleShardList( self.url ),
- wds.shuffle( buffer_shards ),
- wds.split_by_worker,
+ return wds.pipeline.DataPipeline(
+ wds.shardlists.SimpleShardList( self.url ),
+ wds.filters.shuffle( buffer_shards ),
+ wds.shardlists.split_by_worker,
  #
- wds.tarfile_to_samples(),
+ wds.tariterators.tarfile_to_samples(),
  # wds.shuffle( buffer_samples ),
  # wds.map( self.preprocess ),
- wds.shuffle( buffer_samples ),
- wds.map( self.wrap ),
+ wds.filters.shuffle( buffer_samples ),
+ wds.filters.map( self.wrap ),
  )

- return wds.DataPipeline(
- wds.SimpleShardList( self.url ),
- wds.shuffle( buffer_shards ),
- wds.split_by_worker,
+ return wds.pipeline.DataPipeline(
+ wds.shardlists.SimpleShardList( self.url ),
+ wds.filters.shuffle( buffer_shards ),
+ wds.shardlists.split_by_worker,
  #
- wds.tarfile_to_samples(),
+ wds.tariterators.tarfile_to_samples(),
  # wds.shuffle( buffer_samples ),
  # wds.map( self.preprocess ),
- wds.shuffle( buffer_samples ),
- wds.batched( batch_size ),
- wds.map( self.wrap_batch ),
+ wds.filters.shuffle( buffer_samples ),
+ wds.filters.batched( batch_size ),
+ wds.filters.map( self.wrap_batch ),
  )

  # TODO Rewrite to eliminate `pandas` dependency directly calling
@@ -165,7 +165,7 @@ class LensNetwork:
  # output_type = sig.return_annotation

  # self._registry[input_type, output_type] = _lens
- print( _lens.source_type )
+ # print( _lens.source_type )
  self._registry[_lens.source_type, _lens.view_type] = _lens

  def transform( self, source: DatasetType, view: DatasetType ) -> Lens:
@@ -50,6 +50,12 @@ class NumpyTestSampleDecorated:
  label: int
  image: NDArray

+ @atdata.packable
+ class NumpyOptionalSampleDecorated:
+ label: int
+ image: NDArray
+ embeddings: NDArray | None = None
+
  test_cases = [
  {
  'SampleType': BasicTestSample,
@@ -91,6 +97,28 @@ test_cases = [
  'sample_wds_stem': 'numpy_test_decorated',
  'test_parquet': False,
  },
+ {
+ 'SampleType': NumpyOptionalSampleDecorated,
+ 'sample_data':
+ {
+ 'label': 9_001,
+ 'image': np.random.randn( 1024, 1024 ),
+ 'embeddings': np.random.randn( 512 ),
+ },
+ 'sample_wds_stem': 'numpy_optional_decorated',
+ 'test_parquet': False,
+ },
+ {
+ 'SampleType': NumpyOptionalSampleDecorated,
+ 'sample_data':
+ {
+ 'label': 9_001,
+ 'image': np.random.randn( 1024, 1024 ),
+ 'embeddings': None,
+ },
+ 'sample_wds_stem': 'numpy_optional_decorated_none',
+ 'test_parquet': False,
+ },
  ]

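A usage sketch for the new optional-array sample, mirroring the two test cases above; it assumes NDArray is numpy.typing.NDArray and that atdata.packable yields the usual keyword constructor, neither of which is shown in this diff:

    import atdata
    import numpy as np
    from numpy.typing import NDArray

    @atdata.packable
    class NumpyOptionalSampleDecorated:
        label: int
        image: NDArray
        embeddings: NDArray | None = None

    # One sample with embeddings, one leaving the optional field as None.
    with_embeddings = NumpyOptionalSampleDecorated(
        label = 9_001,
        image = np.random.randn( 1024, 1024 ),
        embeddings = np.random.randn( 512 ),
    )
    without_embeddings = NumpyOptionalSampleDecorated(
        label = 9_001,
        image = np.random.randn( 1024, 1024 ),
    )
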
5 files without changes