atdata 0.1.2a4__tar.gz → 0.1.3b3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,16 @@
 Metadata-Version: 2.4
 Name: atdata
-Version: 0.1.2a4
+Version: 0.1.3b3
 Summary: A loose federation of distributed, typed datasets
 Author-email: Maxine Levesque <hello@maxine.science>
 License-File: LICENSE
 Requires-Python: >=3.12
+Requires-Dist: fastparquet>=2024.11.0
 Requires-Dist: msgpack>=1.1.2
 Requires-Dist: numpy>=2.3.4
 Requires-Dist: ormsgpack>=1.11.0
+Requires-Dist: pandas>=2.3.3
+Requires-Dist: tqdm>=4.67.1
 Requires-Dist: webdataset>=1.0.2
 Description-Content-Type: text/markdown
 
@@ -1,6 +1,6 @@
 [project]
 name = "atdata"
-version = "0.1.2a4"
+version = "0.1.3b3"
 description = "A loose federation of distributed, typed datasets"
 readme = "README.md"
 authors = [
@@ -8,9 +8,12 @@ authors = [
 ]
 requires-python = ">=3.12"
 dependencies = [
+    "fastparquet>=2024.11.0",
     "msgpack>=1.1.2",
     "numpy>=2.3.4",
     "ormsgpack>=1.11.0",
+    "pandas>=2.3.3",
+    "tqdm>=4.67.1",
     "webdataset>=1.0.2",
 ]
 
@@ -10,5 +10,11 @@ from .dataset import (
     packable,
 )
 
+from .lens import (
+    Lens,
+    LensNetwork,
+    lens,
+)
+
 
 #
@@ -5,21 +5,34 @@
 
 import webdataset as wds
 
-import functools
-from dataclasses import dataclass
+from pathlib import Path
 import uuid
+import functools
 
-import numpy as np
-
+import dataclasses
+import types
+from dataclasses import (
+    dataclass,
+    asdict,
+)
 from abc import (
     ABC,
     abstractmethod,
 )
+
+from tqdm import tqdm
+import numpy as np
+import pandas as pd
+
+import typing
 from typing import (
     Any,
     Optional,
     Dict,
     Sequence,
+    Iterable,
+    Callable,
+    Union,
     #
     Self,
     Generic,
@@ -40,14 +53,20 @@ from numpy.typing import (
 import msgpack
 import ormsgpack
 from . import _helpers as eh
+from .lens import Lens, LensNetwork
 
 
 ##
 # Typing help
 
+Pathlike = str | Path
+
 WDSRawSample: TypeAlias = Dict[str, Any]
 WDSRawBatch: TypeAlias = Dict[str, Any]
 
+SampleExportRow: TypeAlias = Dict[str, Any]
+SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
+
 
 ##
 # Main base classes
@@ -94,6 +113,25 @@ def _make_packable( x ):
         return eh.array_to_bytes( x )
     return x
 
+def _is_possibly_ndarray_type( t ):
+    """Checks if a type annotation is possibly an NDArray."""
+
+    # Directly an NDArray
+    if t == NDArray:
+        # print( 'is an NDArray' )
+        return True
+
+    # Check for Optionals (i.e., NDArray | None)
+    if isinstance( t, types.UnionType ):
+        t_parts = t.__args__
+        if any( x == NDArray
+                for x in t_parts ):
+            return True
+
+    # Not an NDArray
+    return False
+
+@dataclass
 class PackableSample( ABC ):
     """A sample that can be packed and unpacked with msgpack"""
 
@@ -101,10 +139,13 @@ class PackableSample( ABC ):
         """TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""
 
         # Auto-convert known types when annotated
-        for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
+        # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
+        for field in dataclasses.fields( self ):
+            var_name = field.name
+            var_type = field.type
 
             # Annotation for this variable is to be an NDArray
-            if var_type == NDArray:
+            if _is_possibly_ndarray_type( var_type ):
                 # ... so, we'll always auto-convert to numpy
 
                 var_cur_value = getattr( self, var_name )
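For orientation (not part of the diff): on Python 3.12, an optional-array annotation such as `NDArray | None` is a `types.UnionType` whose `__args__` contain the `NDArray` alias, which is exactly what the new helper walks. A minimal sketch:

```python
import types
from numpy.typing import NDArray

anno = NDArray | None                              # e.g. an optional array field
assert isinstance( anno, types.UnionType )         # union built with `|`
assert any( a == NDArray for a in anno.__args__ )  # so the helper returns True
```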
@@ -189,7 +233,7 @@ class SampleBatch( Generic[DT] ):
     @property
     def sample_type( self ) -> Type:
         """The type of each sample in this batch"""
-        return self.__orig_class__.__args__[0]
+        return typing.get_args( self.__orig_class__ )[0]
 
     def __getattr__( self, name ):
         # Aggregate named params of sample type
@@ -217,6 +261,8 @@ class SampleBatch( Generic[DT] ):
 ST = TypeVar( 'ST', bound = PackableSample )
 # BT = TypeVar( 'BT' )
 
+RT = TypeVar( 'RT', bound = PackableSample )
+
 # TODO For python 3.13
 # BT = TypeVar( 'BT', default = None )
 # IT = TypeVar( 'IT', default = Any )
@@ -235,7 +281,8 @@ class Dataset( Generic[ST] ):
     @property
     def sample_type( self ) -> Type:
         """The type of each returned sample from this `Dataset`'s iterator"""
-        return self.__orig_class__.__args__[0]
+        # TODO Figure out why linting fails here
+        return typing.get_args( self.__orig_class__ )[0]
     @property
     def batch_type( self ) -> Type:
         """The type of a batch built from `sample_class`"""
@@ -253,6 +300,17 @@ class Dataset( Generic[ST] ):
         super().__init__()
         self.url = url
 
+        # Allow addition of automatic transformation of raw underlying data
+        self._output_lens: Lens | None = None
+
+    def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
+        """View this dataset as sample type `other`, converting through a registered lens"""
+        ret = Dataset[other]( self.url )
+        # Get the singleton lens registry
+        lenses = LensNetwork()
+        ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
+        return ret
+
     # @classmethod
     # def register( cls, uri: str,
     #         sample_class: Type,
@@ -278,15 +336,15 @@ class Dataset( Generic[ST] ):
         A full (non-lazy) list of the individual ``tar`` files within the
         source WebDataset.
         """
-        pipe = wds.DataPipeline(
-            wds.SimpleShardList( self.url ),
-            wds.map( lambda x: x['url'] )
+        pipe = wds.pipeline.DataPipeline(
+            wds.shardlists.SimpleShardList( self.url ),
+            wds.filters.map( lambda x: x['url'] )
         )
         return list( pipe )
 
     def ordered( self,
             batch_size: int | None = 1,
-    ) -> wds.DataPipeline:
+    ) -> Iterable[ST]:
         """Iterate over the dataset in order
 
         Args:
@@ -302,30 +360,30 @@ class Dataset( Generic[ST] ):
 
         if batch_size is None:
             # TODO Duplication here
-            return wds.DataPipeline(
-                wds.SimpleShardList( self.url ),
-                wds.split_by_worker,
+            return wds.pipeline.DataPipeline(
+                wds.shardlists.SimpleShardList( self.url ),
+                wds.shardlists.split_by_worker,
                 #
-                wds.tarfile_to_samples(),
+                wds.tariterators.tarfile_to_samples(),
                 # wds.map( self.preprocess ),
-                wds.map( self.wrap ),
+                wds.filters.map( self.wrap ),
             )
 
-        return wds.DataPipeline(
-            wds.SimpleShardList( self.url ),
-            wds.split_by_worker,
+        return wds.pipeline.DataPipeline(
+            wds.shardlists.SimpleShardList( self.url ),
+            wds.shardlists.split_by_worker,
             #
-            wds.tarfile_to_samples(),
+            wds.tariterators.tarfile_to_samples(),
             # wds.map( self.preprocess ),
-            wds.batched( batch_size ),
-            wds.map( self.wrap_batch ),
+            wds.filters.batched( batch_size ),
+            wds.filters.map( self.wrap_batch ),
         )
 
     def shuffled( self,
             buffer_shards: int = 100,
             buffer_samples: int = 10_000,
             batch_size: int | None = 1,
-    ) -> wds.DataPipeline:
+    ) -> Iterable[ST]:
         """Iterate over the dataset in random order
 
         Args:
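At runtime these methods still return `wds.DataPipeline` objects; the `Iterable[ST]` annotation just lets type checkers see the element type. A hedged sketch, reusing the hypothetical `RawSample` from the earlier example:

```python
ds = atdata.Dataset[RawSample]( 'shards-{000000..000009}.tar' )

for batch in ds.ordered( batch_size = 32 ):      # SampleBatch[RawSample] per iteration
    ...

for sample in ds.shuffled( batch_size = None ):  # one RawSample at a time, shuffled
    ...
```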
@@ -342,30 +400,88 @@ class Dataset( Generic[ST] ):
 
         if batch_size is None:
             # TODO Duplication here
-            return wds.DataPipeline(
-                wds.SimpleShardList( self.url ),
-                wds.shuffle( buffer_shards ),
-                wds.split_by_worker,
+            return wds.pipeline.DataPipeline(
+                wds.shardlists.SimpleShardList( self.url ),
+                wds.filters.shuffle( buffer_shards ),
+                wds.shardlists.split_by_worker,
                 #
-                wds.tarfile_to_samples(),
+                wds.tariterators.tarfile_to_samples(),
                 # wds.shuffle( buffer_samples ),
                 # wds.map( self.preprocess ),
-                wds.shuffle( buffer_samples ),
-                wds.map( self.wrap ),
+                wds.filters.shuffle( buffer_samples ),
+                wds.filters.map( self.wrap ),
             )
 
-        return wds.DataPipeline(
-            wds.SimpleShardList( self.url ),
-            wds.shuffle( buffer_shards ),
-            wds.split_by_worker,
+        return wds.pipeline.DataPipeline(
+            wds.shardlists.SimpleShardList( self.url ),
+            wds.filters.shuffle( buffer_shards ),
+            wds.shardlists.split_by_worker,
             #
-            wds.tarfile_to_samples(),
+            wds.tariterators.tarfile_to_samples(),
             # wds.shuffle( buffer_samples ),
             # wds.map( self.preprocess ),
-            wds.shuffle( buffer_samples ),
-            wds.batched( batch_size ),
-            wds.map( self.wrap_batch ),
+            wds.filters.shuffle( buffer_samples ),
+            wds.filters.batched( batch_size ),
+            wds.filters.map( self.wrap_batch ),
         )
+
+    # TODO Rewrite to eliminate `pandas` dependency directly calling
+    # `fastparquet`
+    def to_parquet( self, path: Pathlike,
+            sample_map: Optional[SampleExportMap] = None,
+            maxcount: Optional[int] = None,
+            **kwargs,
+    ):
+        """Save dataset contents to a `parquet` file at `path`
+
+        Remaining `kwargs` sent to `pandas.to_parquet`
+        """
+        ##
+
+        # Normalize args
+        path = Path( path )
+        if sample_map is None:
+            sample_map = asdict
+
+        verbose = kwargs.pop( 'verbose', False )  # popped so it is not forwarded to pandas
+
+        it = self.ordered( batch_size = None )
+        if verbose:
+            it = tqdm( it )
+
+        #
+
+        if maxcount is None:
+            # Load and save full dataset
+            df = pd.DataFrame( [ sample_map( x )
+                                 for x in it ] )
+            df.to_parquet( path, **kwargs )
+
+        else:
+            # Load and save dataset in segments of size `maxcount`
+
+            cur_segment = 0
+            cur_buffer = []
+            path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix()
+
+            for x in it:
+                cur_buffer.append( sample_map( x ) )
+
+                if len( cur_buffer ) >= maxcount:
+                    # Write current segment
+                    cur_path = path_template.format( cur_segment )
+                    df = pd.DataFrame( cur_buffer )
+                    df.to_parquet( cur_path, **kwargs )
+
+                    cur_segment += 1
+                    cur_buffer = []
+
+            if len( cur_buffer ) > 0:
+                # Write one last segment with remainder
+                cur_path = path_template.format( cur_segment )
+                df = pd.DataFrame( cur_buffer )
+                df.to_parquet( cur_path, **kwargs )
+
 
     # Implemented by specific subclasses
@@ -388,20 +504,24 @@ class Dataset( Generic[ST] ):
         assert 'msgpack' in sample
         assert type( sample['msgpack'] ) == bytes
 
-        return self.sample_type.from_bytes( sample['msgpack'] )
+        if self._output_lens is None:
+            return self.sample_type.from_bytes( sample['msgpack'] )
+
+        source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
+        return self._output_lens( source_sample )
 
-        try:
-            assert type( sample ) == dict
-            return cls.sample_class( **{
-                k: v
-                for k, v in sample.items() if k != '__key__'
-            } )
+        # try:
+        #     assert type( sample ) == dict
+        #     return cls.sample_class( **{
+        #         k: v
+        #         for k, v in sample.items() if k != '__key__'
+        #     } )
 
-        except Exception as e:
-            # Sample constructor failed -- revert to default
-            return AnySample(
-                value = sample,
-            )
+        # except Exception as e:
+        #     # Sample constructor failed -- revert to default
+        #     return AnySample(
+        #         value = sample,
+        #     )
 
     def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
         """Wrap a `batch` of samples into the appropriate dataset-specific type
@@ -410,10 +530,17 @@ class Dataset( Generic[ST] ):
         """
 
         assert 'msgpack' in batch
-        batch_unpacked = [ self.sample_type.from_bytes( bs )
-                           for bs in batch['msgpack'] ]
-        return SampleBatch[self.sample_type]( batch_unpacked )
 
+        if self._output_lens is None:
+            batch_unpacked = [ self.sample_type.from_bytes( bs )
+                               for bs in batch['msgpack'] ]
+            return SampleBatch[self.sample_type]( batch_unpacked )
+
+        batch_source = [ self._output_lens.source_type.from_bytes( bs )
+                         for bs in batch['msgpack'] ]
+        batch_view = [ self._output_lens( s )
+                       for s in batch_source ]
+        return SampleBatch[self.sample_type]( batch_view )
 
     # # @classmethod
     # def wrap_batch( self, batch: WDSRawBatch ) -> BT:
@@ -449,6 +576,9 @@ def packable( cls ):
 
     ##
 
+    class_name = cls.__name__
+    class_annotations = cls.__annotations__
+
     # Add in dataclass niceness to original class
     as_dataclass = dataclass( cls )
 
@@ -458,8 +588,9 @@ def packable( cls ):
         def __post_init__( self ):
             return PackableSample.__post_init__( self )
 
-    as_packable.__name__ = cls.__name__
-    as_packable.__annotations__ = cls.__annotations__
+    # TODO This doesn't properly carry over the original
+    as_packable.__name__ = class_name
+    as_packable.__annotations__ = class_annotations
 
     ##
 
@@ -0,0 +1,200 @@
+"""Lenses between typed datasets"""
+
+##
+# Imports
+
+import functools
+import inspect
+
+from typing import (
+    TypeAlias,
+    Type,
+    TypeVar,
+    Tuple,
+    Dict,
+    Callable,
+    Optional,
+    Generic,
+    #
+    TYPE_CHECKING
+)
+
+if TYPE_CHECKING:
+    from .dataset import PackableSample
+
+
+##
+# Typing helpers
+
+DatasetType: TypeAlias = Type['PackableSample']
+LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
+
+S = TypeVar( 'S', bound = 'PackableSample' )
+V = TypeVar( 'V', bound = 'PackableSample' )
+type LensGetter[S, V] = Callable[[S], V]
+type LensPutter[S, V] = Callable[[V, S], S]
+
+
+##
+# Shortcut decorators
+
+class Lens( Generic[S, V] ):
+    """A lens from a source sample type (S) to a view sample type (V)"""
+
+    # @property
+    # def source_type( self ) -> Type[S]:
+    #     """The source type (S) for the lens; what is put to"""
+    #     # TODO Figure out why linting fails here
+    #     return self.__orig_class__.__args__[0]
+
+    # @property
+    # def view_type( self ) -> Type[V]:
+    #     """The view type (V) for the lens; what is get'd from"""
+    #     # TODO Figure out why linting fails here
+    #     return self.__orig_class__.__args__[1]
+
+    def __init__( self, get: LensGetter[S, V],
+            put: Optional[LensPutter[S, V]] = None
+    ) -> None:
+        """Build a lens from `get`, inferring source and view types from its annotations"""
+        ##
+
+        # Check argument validity
+
+        sig = inspect.signature( get )
+        input_types = list( sig.parameters.values() )
+        assert len( input_types ) == 1, \
+            'Wrong number of input args for lens: should only have one'
+
+        # Update function details for this object as returned by annotation
+        functools.update_wrapper( self, get )
+
+        self.source_type: Type[PackableSample] = input_types[0].annotation
+        self.view_type = sig.return_annotation
+
+        # Store the getter
+        self._getter = get
+
+        # Determine and store the putter
+        if put is None:
+            # Trivial putter does not update the source
+            def _trivial_put( v: V, s: S ) -> S:
+                return s
+            put = _trivial_put
+        self._putter = put
+
+        #
+
+    def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
+        """Register `put` as this lens's putter; usable as a decorator"""
+        ##
+        self._putter = put
+        return put
+
+    # Methods to actually execute transformations
+
+    def put( self, v: V, s: S ) -> S:
+        """Write the view `v` back into the source `s`"""
+        return self._putter( v, s )
+
+    def get( self, s: S ) -> V:
+        """Extract this lens's view from the source `s`"""
+        return self( s )
+
+    # Convenience to enable calling the lens as its getter
+
+    def __call__( self, s: S ) -> V:
+        return self._getter( s )
+
+# TODO Figure out how to properly parameterize this
+# def _lens_factory[S, V]( register: bool = True ):
+#     """Register the annotated function `f` as the getter of a sample lens"""
+
+#     # The actual lens decorator taking a lens getter function to a lens object
+#     def _decorator( f: LensGetter[S, V] ) -> Lens[S, V]:
+#         ret = Lens[S, V]( f )
+#         if register:
+#             _network.register( ret )
+#         return ret
+
+#     # Return the lens decorator
+#     return _decorator
+
+# # For convenience
+# lens = _lens_factory
+
+def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
+    ret = Lens[S, V]( f )
+    _network.register( ret )
+    return ret
+
+
+##
+# Global registry of used lenses
+
+# _registered_lenses: Dict[LensSignature, Lens] = dict()
+# """TODO"""
+
+class LensNetwork:
+    """Process-wide singleton registry of the lenses between sample types"""
+
+    _instance = None
+    """The singleton instance"""
+
+    def __new__( cls, *args, **kwargs ):
+        if cls._instance is None:
+            # If no instance exists, create a new one
+            cls._instance = super().__new__( cls )
+        return cls._instance  # Return the existing (or newly created) instance
+
+    def __init__( self ):
+        if not hasattr( self, '_initialized' ):  # Check if already initialized
+            self._registry: Dict[LensSignature, Lens] = dict()
+            self._initialized = True
+
+    def register( self, _lens: Lens ):
+        """Set `_lens` as the canonical view between its source and view types"""
+
+        # sig = inspect.signature( _lens.get )
+        # input_types = list( sig.parameters.values() )
+        # assert len( input_types ) == 1, \
+        #     'Wrong number of input args for lens: should only have one'
+
+        # input_type = input_types[0].annotation
+        # print( input_type )
+        # output_type = sig.return_annotation
+
+        # self._registry[input_type, output_type] = _lens
+        # print( _lens.source_type )
+        self._registry[_lens.source_type, _lens.view_type] = _lens
+
+    def transform( self, source: DatasetType, view: DatasetType ) -> Lens:
+        """Look up the registered lens taking `source` samples to `view` samples"""
+
+        # TODO Handle compositional closure
+        ret = self._registry.get( (source, view), None )
+        if ret is None:
+            raise ValueError( f'No registered lens from source {source} to view {view}' )
+
+        return ret
+
+
+# Create global singleton registry instance
+_network = LensNetwork()
+
+# def lens( f: LensPutter ) -> Lens:
+#     """Register the annotated function `f` as a sample lens"""
+#     ##
+
+#     sig = inspect.signature( f )
+
+#     input_types = list( sig.parameters.values() )
+#     output_type = sig.return_annotation
+
+#     _registered_lenses[]
+
+#     f.lens = Lens(
+
+#     )
+
+#     return f
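A small sketch of the singleton behavior above: every `LensNetwork()` construction returns the same instance, so `@lens` registrations made anywhere in the process are visible to `Dataset.as_type`:

```python
from atdata import LensNetwork

a = LensNetwork()
b = LensNetwork()
assert a is b  # one process-wide registry shared by all constructions
```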
@@ -50,6 +50,12 @@ class NumpyTestSampleDecorated:
     label: int
     image: NDArray
 
+@atdata.packable
+class NumpyOptionalSampleDecorated:
+    label: int
+    image: NDArray
+    embeddings: NDArray | None = None
+
 test_cases = [
     {
         'SampleType': BasicTestSample,
@@ -59,6 +65,7 @@ test_cases = [
             'value': 1024.768,
         },
         'sample_wds_stem': 'basic_test',
+        'test_parquet': True,
     },
     {
         'SampleType': NumpyTestSample,
@@ -68,6 +75,7 @@ test_cases = [
             'image': np.random.randn( 1024, 1024 ),
         },
         'sample_wds_stem': 'numpy_test',
+        'test_parquet': False,
     },
     {
         'SampleType': BasicTestSampleDecorated,
@@ -77,6 +85,7 @@ test_cases = [
             'value': 1024.768,
         },
         'sample_wds_stem': 'basic_test_decorated',
+        'test_parquet': True,
     },
     {
         'SampleType': NumpyTestSampleDecorated,
@@ -86,6 +95,29 @@ test_cases = [
             'image': np.random.randn( 1024, 1024 ),
         },
         'sample_wds_stem': 'numpy_test_decorated',
+        'test_parquet': False,
+    },
+    {
+        'SampleType': NumpyOptionalSampleDecorated,
+        'sample_data':
+        {
+            'label': 9_001,
+            'image': np.random.randn( 1024, 1024 ),
+            'embeddings': np.random.randn( 512 ),
+        },
+        'sample_wds_stem': 'numpy_optional_decorated',
+        'test_parquet': False,
+    },
+    {
+        'SampleType': NumpyOptionalSampleDecorated,
+        'sample_data':
+        {
+            'label': 9_001,
+            'image': np.random.randn( 1024, 1024 ),
+            'embeddings': None,
+        },
+        'sample_wds_stem': 'numpy_optional_decorated_none',
+        'test_parquet': False,
     },
 ]
 
@@ -175,7 +207,7 @@ def test_wds(
     ).as_posix()
     file_wds_pattern = file_pattern.format( shard_id = '%06d' )
 
-    with wds.ShardWriter(
+    with wds.writer.ShardWriter(
         pattern = file_wds_pattern,
         maxcount = shard_maxcount,
     ) as sink:
@@ -323,5 +355,56 @@ def test_wds(
     assert iterations_run == n_iterate, \
         "Only found {iterations_run} samples, not {n_iterate}"
 
+    #
+
+@pytest.mark.parametrize(
+    ('SampleType', 'sample_data', 'sample_wds_stem', 'test_parquet'),
+    [ (
+        case['SampleType'],
+        case['sample_data'],
+        case['sample_wds_stem'],
+        case['test_parquet']
+    )
+    for case in test_cases ]
+)
+def test_parquet_export(
+        SampleType: Type[atdata.PackableSample],
+        sample_data: atds.MsgpackRawSample,
+        sample_wds_stem: str,
+        test_parquet: bool,
+        tmp_path
+    ):
+    """Test our ability to export a dataset to `parquet` format"""
+
+    # Skip irrelevant test cases
+    if not test_parquet:
+        return
+
+    ## Testing hyperparameters
+
+    n_copies_dataset = 1_000
+    n_per_file = 100
+
+    ## Start out by writing tar dataset
+
+    wds_filename = (tmp_path / f'{sample_wds_stem}.tar').as_posix()
+    with wds.writer.TarWriter( wds_filename ) as sink:
+        for _ in range( n_copies_dataset ):
+            new_sample = SampleType.from_data( sample_data )
+            sink.write( new_sample.as_wds )
+
+    ## Now export to `parquet`
+
+    dataset = atdata.Dataset[SampleType]( wds_filename )
+    parquet_filename = tmp_path / f'{sample_wds_stem}.parquet'
+    dataset.to_parquet( parquet_filename )
+
+    parquet_filename = tmp_path / f'{sample_wds_stem}-segments.parquet'
+    dataset.to_parquet( parquet_filename, maxcount = n_per_file )
+
+    ## Double-check our `parquet` export
+
+    # TODO
+
 
 ##
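The verification step is still a TODO in this test; a sketch of what it might check, assuming the `asdict` default export maps each sample field to a column (not part of the package):

```python
import pandas as pd

df = pd.read_parquet( tmp_path / f'{sample_wds_stem}.parquet' )
assert len( df ) == n_copies_dataset
assert set( sample_data.keys() ) <= set( df.columns )
```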
@@ -0,0 +1,166 @@
+"""Test lens functionality."""
+
+##
+# Imports
+
+import pytest
+
+from dataclasses import dataclass
+import webdataset as wds
+import atdata
+
+import numpy as np
+from numpy.typing import NDArray
+
+
+##
+# Tests
+
+def test_lens():
+    """Test a lens between sample types"""
+
+    # Set up the lens scenario
+
+    @atdata.packable
+    class Source:
+        name: str
+        age: int
+        height: float
+
+    @atdata.packable
+    class View:
+        name: str
+        height: float
+
+    @atdata.lens
+    def polite( s: Source ) -> View:
+        return View(
+            name = s.name,
+            height = s.height,
+        )
+
+    @polite.putter
+    def polite_update( v: View, s: Source ) -> Source:
+        return Source(
+            name = v.name,
+            height = v.height,
+            #
+            age = s.age,
+        )
+
+    # Test with an example sample
+
+    test_source = Source(
+        name = 'Hello World',
+        age = 42,
+        height = 182.9,
+    )
+    correct_view = View(
+        name = test_source.name,
+        height = test_source.height,
+    )
+
+    test_view = polite( test_source )
+    assert test_view == correct_view, \
+        f'Incorrect lens behavior: {test_view}, and not {correct_view}'
+
+    # This lens should be well-behaved
+
+    update_view = View(
+        name = 'Now Taller',
+        height = 192.9,
+    )
+
+    x = polite( polite.put( update_view, test_source ) )
+    assert x == update_view, \
+        f'Violation of PutGet: {x} =/= {update_view}'
+
+    y = polite.put( polite( test_source ), test_source )
+    assert y == test_source, \
+        f'Violation of GetPut: {y} =/= {test_source}'
+
+    # TODO Test PutPut
+
+def test_conversion( tmp_path ):
+    """Test automatic interconversion between sample types"""
+
+    @dataclass
+    class Source( atdata.PackableSample ):
+        name: str
+        height: float
+        favorite_pizza: str
+        favorite_image: NDArray
+
+    @dataclass
+    class View( atdata.PackableSample ):
+        name: str
+        favorite_pizza: str
+        favorite_image: NDArray
+
+    @atdata.lens
+    def polite( s: Source ) -> View:
+        return View(
+            name = s.name,
+            favorite_pizza = s.favorite_pizza,
+            favorite_image = s.favorite_image,
+        )
+
+    lens_network = atdata.LensNetwork()
+    print( lens_network._registry )
+
+    # Map a test sample through the view
+    test_source = Source(
+        name = 'Larry',
+        height = 42.,
+        favorite_pizza = 'pineapple',
+        favorite_image = np.random.randn( 224, 224 )
+    )
+    test_view = polite( test_source )
+
+    # Create a test dataset
+
+    k_test = 100
+    test_filename = (
+        tmp_path
+        / 'test-source.tar'
+    ).as_posix()
+
+    with wds.writer.TarWriter( test_filename ) as dest:
+        for i in range( k_test ):
+            # Create a new copied sample
+            cur_sample = Source(
+                name = test_source.name,
+                height = test_source.height,
+                favorite_pizza = test_source.favorite_pizza,
+                favorite_image = test_source.favorite_image,
+            )
+            dest.write( cur_sample.as_wds )
+
+    # Try reading the test dataset
+
+    ds = (
+        atdata.Dataset[Source]( test_filename )
+        .as_type( View )
+    )
+
+    assert ds.sample_type == View, \
+        'as_type did not map the dataset to the `View` sample type'
+
+    sample: View | None = None
+    for sample in ds.ordered( batch_size = None ):
+        # Load only the first sample
+        break
+
+    assert sample is not None, \
+        'Did not load any samples from `Source` dataset'
+
+    assert sample.name == test_view.name, \
+        f'Divergence on auto-mapped dataset: `name` should be {test_view.name}, but is {sample.name}'
+    # assert sample.height == test_view.height, \
+    #     f'Divergence on auto-mapped dataset: `height` should be {test_view.height}, but is {sample.height}'
+    assert sample.favorite_pizza == test_view.favorite_pizza, \
+        f'Divergence on auto-mapped dataset: `favorite_pizza` should be {test_view.favorite_pizza}, but is {sample.favorite_pizza}'
+    assert np.all( sample.favorite_image == test_view.favorite_image ), \
+        'Divergence on auto-mapped dataset: `favorite_image`'
+
+##
4 files in the package are unchanged between these versions and are omitted from this diff.