atdata 0.1.2a4__tar.gz → 0.1.2b1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,16 @@
  Metadata-Version: 2.4
  Name: atdata
- Version: 0.1.2a4
+ Version: 0.1.2b1
  Summary: A loose federation of distributed, typed datasets
  Author-email: Maxine Levesque <hello@maxine.science>
  License-File: LICENSE
  Requires-Python: >=3.12
+ Requires-Dist: fastparquet>=2024.11.0
  Requires-Dist: msgpack>=1.1.2
  Requires-Dist: numpy>=2.3.4
  Requires-Dist: ormsgpack>=1.11.0
+ Requires-Dist: pandas>=2.3.3
+ Requires-Dist: tqdm>=4.67.1
  Requires-Dist: webdataset>=1.0.2
  Description-Content-Type: text/markdown
 
@@ -1,6 +1,6 @@
  [project]
  name = "atdata"
- version = "0.1.2a4"
+ version = "0.1.2b1"
  description = "A loose federation of distributed, typed datasets"
  readme = "README.md"
  authors = [
@@ -8,9 +8,12 @@ authors = [
  ]
  requires-python = ">=3.12"
  dependencies = [
+     "fastparquet>=2024.11.0",
      "msgpack>=1.1.2",
      "numpy>=2.3.4",
      "ormsgpack>=1.11.0",
+     "pandas>=2.3.3",
+     "tqdm>=4.67.1",
      "webdataset>=1.0.2",
  ]
 
@@ -10,5 +10,10 @@ from .dataset import (
      packable,
  )
 
+ from .lens import (
+     Lens,
+     lens,
+ )
+
 
  #
@@ -5,21 +5,29 @@
 
  import webdataset as wds
 
- import functools
- from dataclasses import dataclass
+ from pathlib import Path
  import uuid
-
- import numpy as np
-
+ import functools
+ from dataclasses import (
+     dataclass,
+     asdict,
+ )
  from abc import (
      ABC,
      abstractmethod,
  )
+
+ from tqdm import tqdm
+ import numpy as np
+ import pandas as pd
+
  from typing import (
      Any,
      Optional,
      Dict,
      Sequence,
+     Iterable,
+     Callable,
      #
      Self,
      Generic,
@@ -45,9 +53,14 @@ from . import _helpers as eh
  ##
  # Typing help
 
+ Pathlike = str | Path
+
  WDSRawSample: TypeAlias = Dict[str, Any]
  WDSRawBatch: TypeAlias = Dict[str, Any]
 
+ SampleExportRow: TypeAlias = Dict[str, Any]
+ SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
+
 
  ##
  # Main base classes
@@ -94,6 +107,7 @@ def _make_packable( x ):
          return eh.array_to_bytes( x )
      return x
 
+ @dataclass
  class PackableSample( ABC ):
      """A sample that can be packed and unpacked with msgpack"""
 
@@ -235,6 +249,7 @@ class Dataset( Generic[ST] ):
      @property
      def sample_type( self ) -> Type:
          """The type of each returned sample from this `Dataset`'s iterator"""
+         # TODO Figure out why linting fails here
          return self.__orig_class__.__args__[0]
      @property
      def batch_type( self ) -> Type:
@@ -286,7 +301,7 @@ class Dataset( Generic[ST] ):
 
      def ordered( self,
              batch_size: int | None = 1,
-         ) -> wds.DataPipeline:
+         ) -> Iterable[ST]:
          """Iterate over the dataset in order
 
          Args:
@@ -325,7 +340,7 @@ class Dataset( Generic[ST] ):
              buffer_shards: int = 100,
              buffer_samples: int = 10_000,
              batch_size: int | None = 1,
-         ) -> wds.DataPipeline:
+         ) -> Iterable[ST]:
          """Iterate over the dataset in random order
 
          Args:
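
Note: with the return type of `ordered` and `shuffled` loosened from `wds.DataPipeline` to `Iterable[ST]`, callers can treat either result as a plain iterator of typed samples. A minimal sketch, assuming a hypothetical `MySample` packable type and shard path (constructor usage mirrors the tests further below):

    dataset = atdata.Dataset[MySample]( 'train-shards.tar' )
    for sample in dataset.ordered( batch_size = None ):
        ...  # each `sample` is a `MySample`, not a raw wds dict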
@@ -366,6 +381,64 @@ class Dataset( Generic[ST] ):
              wds.batched( batch_size ),
              wds.map( self.wrap_batch ),
          )
+
+     # TODO Rewrite to eliminate the `pandas` dependency by calling
+     # `fastparquet` directly
+     def to_parquet( self, path: Pathlike,
+             sample_map: Optional[SampleExportMap] = None,
+             maxcount: Optional[int] = None,
+             **kwargs,
+         ):
+         """Save dataset contents to a `parquet` file at `path`
+
+         Remaining `kwargs` are forwarded to `pandas.DataFrame.to_parquet`
+         """
+         ##
+
+         # Normalize args
+         path = Path( path )
+         if sample_map is None:
+             sample_map = asdict
+
+         # Consume `verbose` here so it is not forwarded to pandas
+         verbose = kwargs.pop( 'verbose', False )
+
+         it = self.ordered( batch_size = None )
+         if verbose:
+             it = tqdm( it )
+
+         #
+
+         if maxcount is None:
+             # Load and save the full dataset in one shot
+             df = pd.DataFrame( [ sample_map( x ) for x in it ] )
+             df.to_parquet( path, **kwargs )
+
+         else:
+             # Load and save the dataset in segments of size `maxcount`
+
+             cur_segment = 0
+             cur_buffer = []
+             path_template = (path.parent / f'{path.stem}-{{:06d}}{path.suffix}').as_posix()
+
+             for x in it:
+                 cur_buffer.append( sample_map( x ) )
+
+                 if len( cur_buffer ) >= maxcount:
+                     # Write the current segment
+                     cur_path = path_template.format( cur_segment )
+                     df = pd.DataFrame( cur_buffer )
+                     df.to_parquet( cur_path, **kwargs )
+
+                     cur_segment += 1
+                     cur_buffer = []
+
+             if len( cur_buffer ) > 0:
+                 # Write one last segment with the remainder
+                 cur_path = path_template.format( cur_segment )
+                 df = pd.DataFrame( cur_buffer )
+                 df.to_parquet( cur_path, **kwargs )
+
 
      # Implemented by specific subclasses
 
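
Note: a sketch of how the new `to_parquet` export could be invoked; the `MySample` type and its fields are hypothetical. With `maxcount` set, output is segmented into files named like `data-000000.parquet` via `path_template` above, and `verbose` is consumed locally to wrap iteration in `tqdm` rather than being forwarded to pandas:

    def row( s: MySample ) -> SampleExportRow:
        # Custom row mapping; the default is `dataclasses.asdict`
        return { 'name': s.name, 'value': s.value }

    dataset.to_parquet( 'export/data.parquet',
        sample_map = row,
        maxcount = 10_000,
        verbose = True,
    )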
@@ -390,18 +463,18 @@ class Dataset( Generic[ST] ):
 
          return self.sample_type.from_bytes( sample['msgpack'] )
 
-         try:
-             assert type( sample ) == dict
-             return cls.sample_class( **{
-                 k: v
-                 for k, v in sample.items() if k != '__key__'
-             } )
+         # try:
+         #     assert type( sample ) == dict
+         #     return cls.sample_class( **{
+         #         k: v
+         #         for k, v in sample.items() if k != '__key__'
+         #     } )
 
-         except Exception as e:
-             # Sample constructor failed -- revert to default
-             return AnySample(
-                 value = sample,
-             )
+         # except Exception as e:
+         #     # Sample constructor failed -- revert to default
+         #     return AnySample(
+         #         value = sample,
+         #     )
 
      def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
          """Wrap a `batch` of samples into the appropriate dataset-specific type
@@ -449,6 +522,9 @@ def packable( cls ):
 
      ##
 
+     class_name = cls.__name__
+     class_annotations = cls.__annotations__
+
      # Add in dataclass niceness to original class
      as_dataclass = dataclass( cls )
 
@@ -458,8 +534,9 @@ def packable( cls ):
          def __post_init__( self ):
              return PackableSample.__post_init__( self )
 
-     as_packable.__name__ = cls.__name__
-     as_packable.__annotations__ = cls.__annotations__
+     # TODO This doesn't properly carry over the original
+     as_packable.__name__ = class_name
+     as_packable.__annotations__ = class_annotations
 
      ##
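
Note: capturing `class_name` and `class_annotations` before wrapping is what lets the decorated class keep its identity. A minimal usage sketch of `@packable` (the field names are illustrative; construction matches the `test_lens` classes below):

    @atdata.packable
    class Point:
        x: float
        y: float

    p = Point( x = 1.0, y = 2.0 )  # dataclass-style construction, msgpack-packable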
@@ -0,0 +1,122 @@
+ """Lenses between typed datasets"""
+
+ ##
+ # Imports
+
+ from .dataset import PackableSample
+
+ import functools
+ import inspect
+
+ from typing import (
+     TypeAlias,
+     Type,
+     TypeVar,
+     Tuple,
+     Dict,
+     Callable,
+     Optional,
+     Generic,
+ )
+
+
+ ##
+ # Typing helpers
+
+ DatasetType: TypeAlias = Type[PackableSample]
+ LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
+
+ S = TypeVar( 'S', bound = PackableSample )
+ V = TypeVar( 'V', bound = PackableSample )
+ type LensGetter[S, V] = Callable[[S], V]
+ type LensPutter[S, V] = Callable[[V, S], S]
+
+
+ ##
+ # Shortcut decorators
+
+ class Lens( Generic[S, V] ):
+     """TODO"""
+
+     def __init__( self, get: LensGetter[S, V],
+             put: Optional[LensPutter[S, V]] = None
+         ) -> None:
+         """TODO"""
+         ##
+
+         # Update
+         functools.update_wrapper( self, get )
+
+         # Store the getter
+         self._getter = get
+
+         # Determine and store the putter
+         if put is None:
+             # Trivial putter does not update the source
+             def _trivial_put( v: V, s: S ) -> S:
+                 return s
+             put = _trivial_put
+
+         self._putter = put
+
+         # Register this lens for this type signature
+
+         sig = inspect.signature( get )
+         input_types = list( sig.parameters.values() )
+         assert len( input_types ) == 1, \
+             'Wrong number of input args for lens: should only have one'
+
+         input_type = input_types[0].annotation
+         output_type = sig.return_annotation
+
+         _registered_lenses[(input_type, output_type)] = self
+         print( _registered_lenses )
+
+     #
+
+     def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
+         """TODO"""
+         ##
+         self._putter = put
+         return put
+
+     def put( self, v: V, s: S ) -> S:
+         """TODO"""
+         return self._putter( v, s )
+
+     def get( self, s: S ) -> V:
+         """TODO"""
+         return self( s )
+
+     #
+
+     def __call__( self, s: S ) -> V:
+         return self._getter( s )
+
+ def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
+     """Register the annotated function `f` as the getter of a sample lens"""
+     return Lens[S, V]( f )
+
+
+ ##
+ # Global registration of used lenses
+
+ _registered_lenses: Dict[LensSignature, Lens] = dict()
+ """TODO"""
+
+ # def lens( f: LensPutter ) -> Lens:
+ #     """Register the annotated function `f` as a sample lens"""
+ #     ##
+
+ #     sig = inspect.signature( f )
+
+ #     input_types = list( sig.parameters.values() )
+ #     output_type = sig.return_annotation
+
+ #     _registered_lenses[]
+
+ #     f.lens = Lens(
+
+ #     )
+
+ #     return f
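
Note: each `Lens` self-registers in `_registered_lenses`, keyed by the `(input_type, output_type)` pair read off the getter's annotations, so there is one lens per type signature. Schematically, a well-behaved lens satisfies the round-trip laws that `test_lens` below asserts:

    # For a lens `ln` with source `s` and view `v` (schematic):
    assert ln( ln.put( v, s ) ) == v   # GetPut
    assert ln.put( ln( s ), s ) == s   # PutGet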
@@ -59,6 +59,7 @@ test_cases = [
              'value': 1024.768,
          },
          'sample_wds_stem': 'basic_test',
+         'test_parquet': True,
      },
      {
          'SampleType': NumpyTestSample,
@@ -68,6 +69,7 @@ test_cases = [
              'image': np.random.randn( 1024, 1024 ),
          },
          'sample_wds_stem': 'numpy_test',
+         'test_parquet': False,
      },
      {
          'SampleType': BasicTestSampleDecorated,
@@ -77,6 +79,7 @@ test_cases = [
              'value': 1024.768,
          },
          'sample_wds_stem': 'basic_test_decorated',
+         'test_parquet': True,
      },
      {
          'SampleType': NumpyTestSampleDecorated,
@@ -86,6 +89,7 @@ test_cases = [
              'image': np.random.randn( 1024, 1024 ),
          },
          'sample_wds_stem': 'numpy_test_decorated',
+         'test_parquet': False,
      },
  ]
 
@@ -323,5 +327,117 @@ def test_wds(
      assert iterations_run == n_iterate, \
          "Only found {iterations_run} samples, not {n_iterate}"
 
+ #
+
+ @pytest.mark.parametrize(
+     ('SampleType', 'sample_data', 'sample_wds_stem', 'test_parquet'),
+     [ (
+         case['SampleType'],
+         case['sample_data'],
+         case['sample_wds_stem'],
+         case['test_parquet'],
+     )
+     for case in test_cases ]
+ )
+ def test_to_parquet(
+         SampleType: Type[atdata.PackableSample],
+         sample_data: atds.MsgpackRawSample,
+         sample_wds_stem: str,
+         test_parquet: bool,
+         tmp_path,
+     ):
+     """Test our ability to export a dataset to `parquet` format"""
+
+     # Skip irrelevant test cases
+     if not test_parquet:
+         return
+
+     ## Testing hyperparameters
+
+     n_copies_dataset = 1_000
+     n_per_file = 100
+
+     ## Start out by writing a tar dataset
+
+     wds_filename = (tmp_path / f'{sample_wds_stem}.tar').as_posix()
+     with wds.TarWriter( wds_filename ) as sink:
+         for _ in range( n_copies_dataset ):
+             new_sample = SampleType.from_data( sample_data )
+             sink.write( new_sample.as_wds )
+
+     ## Now export to `parquet`
+
+     dataset = atdata.Dataset[SampleType]( wds_filename )
+     parquet_filename = tmp_path / f'{sample_wds_stem}.parquet'
+     dataset.to_parquet( parquet_filename )
+
+     ## Double-check our `parquet` export
+
+     # TODO
+
+ def test_lens():
+     """Test a lens between sample types"""
+
+     # Set up the lens scenario
+
+     @atdata.packable
+     class Source:
+         name: str
+         age: int
+         height: float
+
+     @atdata.packable
+     class View:
+         name: str
+         height: float
+
+     @atdata.lens
+     def polite( s: Source ) -> View:
+         return View(
+             name = s.name,
+             height = s.height,
+         )
+
+     @polite.putter
+     def polite_update( v: View, s: Source ) -> Source:
+         return Source(
+             name = v.name,
+             height = v.height,
+             #
+             age = s.age,
+         )
+
+     # Test with an example sample
+
+     test_source = Source(
+         name = 'Hello World',
+         age = 42,
+         height = 182.9,
+     )
+     correct_view = View(
+         name = test_source.name,
+         height = test_source.height,
+     )
+
+     test_view = polite( test_source )
+     assert test_view == correct_view, \
+         f'Incorrect lens behavior: {test_view}, and not {correct_view}'
+
+     # This lens should be well-behaved
+
+     update_view = View(
+         name = 'Now Taller',
+         height = 192.9,
+     )
+
+     x = polite( polite.put( update_view, test_source ) )
+     assert x == update_view, \
+         f'Violation of GetPut: {x} =/= {update_view}'
+
+     y = polite.put( polite( test_source ), test_source )
+     assert y == test_source, \
+         f'Violation of PutGet: {y} =/= {test_source}'
+
+     # TODO Test PutPut
 
  ##
4 files without changes