atdata 0.1.2a4__tar.gz → 0.1.3b3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atdata-0.1.2a4 → atdata-0.1.3b3}/PKG-INFO +4 -1
- {atdata-0.1.2a4 → atdata-0.1.3b3}/pyproject.toml +4 -1
- {atdata-0.1.2a4 → atdata-0.1.3b3}/src/atdata/__init__.py +6 -0
- {atdata-0.1.2a4 → atdata-0.1.3b3}/src/atdata/dataset.py +187 -56
- atdata-0.1.3b3/src/atdata/lens.py +200 -0
- {atdata-0.1.2a4 → atdata-0.1.3b3}/tests/test_dataset.py +84 -1
- atdata-0.1.3b3/tests/test_lens.py +166 -0
- {atdata-0.1.2a4 → atdata-0.1.3b3}/.github/workflows/uv-publish-pypi.yml +0 -0
- {atdata-0.1.2a4 → atdata-0.1.3b3}/.github/workflows/uv-test.yml +0 -0
- {atdata-0.1.2a4 → atdata-0.1.3b3}/.gitignore +0 -0
- {atdata-0.1.2a4 → atdata-0.1.3b3}/.python-version +0 -0
- {atdata-0.1.2a4 → atdata-0.1.3b3}/LICENSE +0 -0
- {atdata-0.1.2a4 → atdata-0.1.3b3}/README.md +0 -0
- {atdata-0.1.2a4 → atdata-0.1.3b3}/src/atdata/_helpers.py +0 -0
{atdata-0.1.2a4 → atdata-0.1.3b3}/PKG-INFO

@@ -1,13 +1,16 @@
 Metadata-Version: 2.4
 Name: atdata
-Version: 0.1.2a4
+Version: 0.1.3b3
 Summary: A loose federation of distributed, typed datasets
 Author-email: Maxine Levesque <hello@maxine.science>
 License-File: LICENSE
 Requires-Python: >=3.12
+Requires-Dist: fastparquet>=2024.11.0
 Requires-Dist: msgpack>=1.1.2
 Requires-Dist: numpy>=2.3.4
 Requires-Dist: ormsgpack>=1.11.0
+Requires-Dist: pandas>=2.3.3
+Requires-Dist: tqdm>=4.67.1
 Requires-Dist: webdataset>=1.0.2
 Description-Content-Type: text/markdown
 
{atdata-0.1.2a4 → atdata-0.1.3b3}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "atdata"
-version = "0.1.2a4"
+version = "0.1.3b3"
 description = "A loose federation of distributed, typed datasets"
 readme = "README.md"
 authors = [
@@ -8,9 +8,12 @@ authors = [
 ]
 requires-python = ">=3.12"
 dependencies = [
+    "fastparquet>=2024.11.0",
     "msgpack>=1.1.2",
     "numpy>=2.3.4",
     "ormsgpack>=1.11.0",
+    "pandas>=2.3.3",
+    "tqdm>=4.67.1",
     "webdataset>=1.0.2",
 ]
 
{atdata-0.1.2a4 → atdata-0.1.3b3}/src/atdata/dataset.py

@@ -5,21 +5,34 @@
 
 import webdataset as wds
 
-import
-from dataclasses import dataclass
+from pathlib import Path
 import uuid
+import functools
 
-import
-
+import dataclasses
+import types
+from dataclasses import (
+    dataclass,
+    asdict,
+)
 from abc import (
     ABC,
     abstractmethod,
 )
+
+from tqdm import tqdm
+import numpy as np
+import pandas as pd
+
+import typing
 from typing import (
     Any,
     Optional,
     Dict,
     Sequence,
+    Iterable,
+    Callable,
+    Union,
     #
     Self,
     Generic,
@@ -40,14 +53,20 @@ from numpy.typing import (
 import msgpack
 import ormsgpack
 from . import _helpers as eh
+from .lens import Lens, LensNetwork
 
 
 ##
 # Typing help
 
+Pathlike = str | Path
+
 WDSRawSample: TypeAlias = Dict[str, Any]
 WDSRawBatch: TypeAlias = Dict[str, Any]
 
+SampleExportRow: TypeAlias = Dict[str, Any]
+SampleExportMap: TypeAlias = Callable[['PackableSample'], SampleExportRow]
+
 
 ##
 # Main base classes
@@ -94,6 +113,25 @@ def _make_packable( x ):
         return eh.array_to_bytes( x )
     return x
 
+def _is_possibly_ndarray_type( t ):
+    """Checks if a type annotation is possibly an NDArray."""
+
+    # Directly an NDArray
+    if t == NDArray:
+        # print( 'is an NDArray' )
+        return True
+
+    # Check for Optionals (i.e., NDArray | None)
+    if isinstance( t, types.UnionType ):
+        t_parts = t.__args__
+        if any( x == NDArray
+                for x in t_parts ):
+            return True
+
+    # Not an NDArray
+    return False
+
+@dataclass
 class PackableSample( ABC ):
     """A sample that can be packed and unpacked with msgpack"""
 
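For orientation, the union branch above is what lets a field annotated `NDArray | None` (as in the new `NumpyOptionalSampleDecorated` test case below) pass the check. A quick illustration of the predicate; these calls are for reference only, not part of the diff:

    from numpy.typing import NDArray

    _is_possibly_ndarray_type( NDArray )         # True: direct annotation
    _is_possibly_ndarray_type( NDArray | None )  # True: union containing NDArray
    _is_possibly_ndarray_type( int )             # False: unrelated type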
@@ -101,10 +139,13 @@ class PackableSample( ABC ):
         """TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""
 
         # Auto-convert known types when annotated
-        for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
+        # for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
+        for field in dataclasses.fields( self ):
+            var_name = field.name
+            var_type = field.type
 
             # Annotation for this variable is to be an NDArray
-            if var_type
+            if _is_possibly_ndarray_type( var_type ):
                 # ... so, we'll always auto-convert to numpy
 
                 var_cur_value = getattr( self, var_name )
@@ -120,6 +161,9 @@ class PackableSample( ABC ):
                 # setattr( self, var_name, var_cur_value.to_numpy )
 
                 elif isinstance( var_cur_value, bytes ):
+                    # TODO This does create a constraint that serialized bytes
+                    # in a field that might be an NDArray are always interpreted
+                    # as being the NDArray interpretation
                     setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
 
     def __post_init__( self ):
@@ -189,7 +233,7 @@ class SampleBatch( Generic[DT] ):
     @property
     def sample_type( self ) -> Type:
         """The type of each sample in this batch"""
-        return self.__orig_class__
+        return typing.get_args( self.__orig_class__)[0]
 
     def __getattr__( self, name ):
         # Aggregate named params of sample type
@@ -217,6 +261,8 @@ class SampleBatch( Generic[DT] ):
 ST = TypeVar( 'ST', bound = PackableSample )
 # BT = TypeVar( 'BT' )
 
+RT = TypeVar( 'RT', bound = PackableSample )
+
 # TODO For python 3.13
 # BT = TypeVar( 'BT', default = None )
 # IT = TypeVar( 'IT', default = Any )
@@ -235,7 +281,8 @@ class Dataset( Generic[ST] ):
     @property
     def sample_type( self ) -> Type:
         """The type of each returned sample from this `Dataset`'s iterator"""
-        return self.__orig_class__
+        # TODO Figure out why linting fails here
+        return typing.get_args( self.__orig_class__ )[0]
     @property
     def batch_type( self ) -> Type:
         """The type of a batch built from `sample_class`"""
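Both `sample_type` properties now unwrap the concrete type parameter from `__orig_class__` instead of returning the parameterized alias itself. A minimal standalone illustration of the mechanism, with a made-up class name:

    from typing import Generic, TypeVar, get_args

    T = TypeVar( 'T' )

    class Box( Generic[T] ):
        pass

    b = Box[int]()
    # __orig_class__ records the parameterized generic on the instance;
    # get_args() recovers its type arguments
    assert get_args( b.__orig_class__ )[0] is int

Note that `__orig_class__` is only set on the instance after `__init__` returns, which may be related to the linting TODO above.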
@@ -253,6 +300,17 @@ class Dataset( Generic[ST] ):
         super().__init__()
         self.url = url
 
+        # Allow addition of automatic transformation of raw underlying data
+        self._output_lens: Lens | None = None
+
+    def as_type( self, other: Type[RT] ) -> 'Dataset[RT]':
+        """TODO"""
+        ret = Dataset[other]( self.url )
+        # Get the singleton lens registry
+        lenses = LensNetwork()
+        ret._output_lens = lenses.transform( self.sample_type, ret.sample_type )
+        return ret
+
     # @classmethod
     # def register( cls, uri: str,
     #         sample_class: Type,
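`as_type` produces a re-typed handle on the same underlying shards: it looks up a registered lens from the current sample type to the requested one and stashes it as `_output_lens`, so samples are decoded in the source type and then passed through the lens. A sketch of intended use, with type names assumed rather than taken from the package:

    # Assumes a lens from RawSample to CleanSample was registered
    # beforehand with the @atdata.lens decorator (see lens.py below)
    ds = atdata.Dataset[RawSample]( 'shards-{000000..000009}.tar' )
    clean = ds.as_type( CleanSample )   # behaves as Dataset[CleanSample]

    # LensNetwork.transform raises ValueError if no lens is registered
    # for the (RawSample, CleanSample) pair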
@@ -278,15 +336,15 @@ class Dataset( Generic[ST] ):
         A full (non-lazy) list of the individual ``tar`` files within the
         source WebDataset.
         """
-        pipe = wds.DataPipeline(
-            wds.SimpleShardList( self.url ),
-            wds.map( lambda x: x['url'] )
+        pipe = wds.pipeline.DataPipeline(
+            wds.shardlists.SimpleShardList( self.url ),
+            wds.filters.map( lambda x: x['url'] )
         )
         return list( pipe )
 
     def ordered( self,
             batch_size: int | None = 1,
-            ) ->
+            ) -> Iterable[ST]:
         """Iterate over the dataset in order
 
         Args:
@@ -302,30 +360,30 @@ class Dataset( Generic[ST] ):
 
         if batch_size is None:
             # TODO Duplication here
-            return wds.DataPipeline(
-                wds.SimpleShardList( self.url ),
-                wds.split_by_worker,
+            return wds.pipeline.DataPipeline(
+                wds.shardlists.SimpleShardList( self.url ),
+                wds.shardlists.split_by_worker,
                 #
-                wds.tarfile_to_samples(),
+                wds.tariterators.tarfile_to_samples(),
                 # wds.map( self.preprocess ),
-                wds.map( self.wrap ),
+                wds.filters.map( self.wrap ),
             )
 
-        return wds.DataPipeline(
-            wds.SimpleShardList( self.url ),
-            wds.split_by_worker,
+        return wds.pipeline.DataPipeline(
+            wds.shardlists.SimpleShardList( self.url ),
+            wds.shardlists.split_by_worker,
             #
-            wds.tarfile_to_samples(),
+            wds.tariterators.tarfile_to_samples(),
             # wds.map( self.preprocess ),
-            wds.batched( batch_size ),
-            wds.map( self.wrap_batch ),
+            wds.filters.batched( batch_size ),
+            wds.filters.map( self.wrap_batch ),
         )
 
     def shuffled( self,
             buffer_shards: int = 100,
             buffer_samples: int = 10_000,
             batch_size: int | None = 1,
-            ) ->
+            ) -> Iterable[ST]:
         """Iterate over the dataset in random order
 
         Args:
@@ -342,30 +400,88 @@ class Dataset( Generic[ST] ):
 
         if batch_size is None:
             # TODO Duplication here
-            return wds.DataPipeline(
-                wds.SimpleShardList( self.url ),
-                wds.shuffle( buffer_shards ),
-                wds.split_by_worker,
+            return wds.pipeline.DataPipeline(
+                wds.shardlists.SimpleShardList( self.url ),
+                wds.filters.shuffle( buffer_shards ),
+                wds.shardlists.split_by_worker,
                 #
-                wds.tarfile_to_samples(),
+                wds.tariterators.tarfile_to_samples(),
                 # wds.shuffle( buffer_samples ),
                 # wds.map( self.preprocess ),
-                wds.shuffle( buffer_samples ),
-                wds.map( self.wrap ),
+                wds.filters.shuffle( buffer_samples ),
+                wds.filters.map( self.wrap ),
             )
 
-        return wds.DataPipeline(
-            wds.SimpleShardList( self.url ),
-            wds.shuffle( buffer_shards ),
-            wds.split_by_worker,
+        return wds.pipeline.DataPipeline(
+            wds.shardlists.SimpleShardList( self.url ),
+            wds.filters.shuffle( buffer_shards ),
+            wds.shardlists.split_by_worker,
             #
-            wds.tarfile_to_samples(),
+            wds.tariterators.tarfile_to_samples(),
             # wds.shuffle( buffer_samples ),
             # wds.map( self.preprocess ),
-            wds.shuffle( buffer_samples ),
-            wds.batched( batch_size ),
-            wds.map( self.wrap_batch ),
+            wds.filters.shuffle( buffer_samples ),
+            wds.filters.batched( batch_size ),
+            wds.filters.map( self.wrap_batch ),
         )
+
+    # TODO Rewrite to eliminate `pandas` dependency directly calling
+    # `fastparquet`
+    def to_parquet( self, path: Pathlike,
+            sample_map: Optional[SampleExportMap] = None,
+            maxcount: Optional[int] = None,
+            **kwargs,
+            ):
+        """Save dataset contents to a `parquet` file at `path`
+
+        `kwargs` sent to `pandas.to_parquet`
+        """
+        ##
+
+        # Normalize args
+        path = Path( path )
+        if sample_map is None:
+            sample_map = asdict
+
+        verbose = kwargs.get( 'verbose', False )
+
+        it = self.ordered( batch_size = None )
+        if verbose:
+            it = tqdm( it )
+
+        #
+
+        if maxcount is None:
+            # Load and save full dataset
+            df = pd.DataFrame( [ sample_map( x )
+                                 for x in self.ordered( batch_size = None ) ] )
+            df.to_parquet( path, **kwargs )
+
+        else:
+            # Load and save dataset in segments of size `maxcount`
+
+            cur_segment = 0
+            cur_buffer = []
+            path_template = (path.parent / f'{path.stem}-%06d.{path.suffix}').as_posix()
+
+            for x in self.ordered( batch_size = None ):
+                cur_buffer.append( sample_map( x ) )
+
+                if len( cur_buffer ) >= maxcount:
+                    # Write current segment
+                    cur_path = path_template.format( cur_segment )
+                    df = pd.DataFrame( cur_buffer )
+                    df.to_parquet( cur_path, **kwargs )
+
+                    cur_segment += 1
+                    cur_buffer = []
+
+            if len( cur_buffer ) > 0:
+                # Write one last segment with remainder
+                cur_path = path_template.format( cur_segment )
+                df = pd.DataFrame( cur_buffer )
+                df.to_parquet( cur_path, **kwargs )
+
 
     # Implemented by specific subclasses
 
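A sketch of how the new export entry point might be called; paths and type names here are assumed, not taken from the package:

    ds = atdata.Dataset[MySample]( 'shards-{000000..000009}.tar' )

    # Single file: each sample becomes one row via dataclasses.asdict
    ds.to_parquet( 'export/full.parquet' )

    # Segmented: flush a new file after every 10_000 samples
    ds.to_parquet( 'export/by-segment.parquet', maxcount = 10_000 )

    # A custom row mapping can flatten or drop fields first
    ds.to_parquet( 'export/names-only.parquet',
                   sample_map = lambda s: { 'name': s.name } )

Two details worth noting in the added code: the `verbose`/`tqdm` iterator `it` is built but never consumed (both branches iterate `self.ordered(...)` directly), and the segment template mixes a `%06d` placeholder with `str.format`, so segmented file names may not come out numbered as intended.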
@@ -388,20 +504,24 @@ class Dataset( Generic[ST] ):
         assert 'msgpack' in sample
         assert type( sample['msgpack'] ) == bytes
 
-        return self.sample_type.from_bytes( sample['msgpack'] )
+        if self._output_lens is None:
+            return self.sample_type.from_bytes( sample['msgpack'] )
+
+        source_sample = self._output_lens.source_type.from_bytes( sample['msgpack'] )
+        return self._output_lens( source_sample )
 
-        try:
-            assert type( sample ) == dict
-            return cls.sample_class( **{
-                k: v
-                for k, v in sample.items() if k != '__key__'
-            } )
+        # try:
+        #     assert type( sample ) == dict
+        #     return cls.sample_class( **{
+        #         k: v
+        #         for k, v in sample.items() if k != '__key__'
+        #     } )
 
-        except Exception as e:
-            # Sample constructor failed -- revert to default
-            return AnySample(
-                value = sample,
-            )
+        # except Exception as e:
+        #     # Sample constructor failed -- revert to default
+        #     return AnySample(
+        #         value = sample,
+        #     )
 
     def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
         """Wrap a `batch` of samples into the appropriate dataset-specific type
@@ -410,10 +530,17 @@ class Dataset( Generic[ST] ):
         """
 
         assert 'msgpack' in batch
-        batch_unpacked = [ self.sample_type.from_bytes( bs )
-                           for bs in batch['msgpack'] ]
-        return SampleBatch[self.sample_type]( batch_unpacked )
 
+        if self._output_lens is None:
+            batch_unpacked = [ self.sample_type.from_bytes( bs )
+                               for bs in batch['msgpack'] ]
+            return SampleBatch[self.sample_type]( batch_unpacked )
+
+        batch_source = [ self._output_lens.source_type.from_bytes( bs )
+                         for bs in batch['msgpack'] ]
+        batch_view = [ self._output_lens( s )
+                       for s in batch_source ]
+        return SampleBatch[self.sample_type]( batch_view )
 
     # # @classmethod
     # def wrap_batch( self, batch: WDSRawBatch ) -> BT:
@@ -449,6 +576,9 @@ def packable( cls ):
 
     ##
 
+    class_name = cls.__name__
+    class_annotations = cls.__annotations__
+
     # Add in dataclass niceness to original class
     as_dataclass = dataclass( cls )
 
@@ -458,8 +588,9 @@ def packable( cls ):
         def __post_init__( self ):
             return PackableSample.__post_init__( self )
 
-
-    as_packable.
+    # TODO This doesn't properly carry over the original
+    as_packable.__name__ = class_name
+    as_packable.__annotations__ = class_annotations
 
     ##
 
atdata-0.1.3b3/src/atdata/lens.py

@@ -0,0 +1,200 @@
+"""Lenses between typed datasets"""
+
+##
+# Imports
+
+import functools
+import inspect
+
+from typing import (
+    TypeAlias,
+    Type,
+    TypeVar,
+    Tuple,
+    Dict,
+    Callable,
+    Optional,
+    Generic,
+    #
+    TYPE_CHECKING
+)
+
+if TYPE_CHECKING:
+    from .dataset import PackableSample
+
+
+##
+# Typing helpers
+
+DatasetType: TypeAlias = Type['PackableSample']
+LensSignature: TypeAlias = Tuple[DatasetType, DatasetType]
+
+S = TypeVar( 'S', bound = 'PackableSample' )
+V = TypeVar( 'V', bound = 'PackableSample' )
+type LensGetter[S, V] = Callable[[S], V]
+type LensPutter[S, V] = Callable[[V, S], S]
+
+
+##
+# Shortcut decorators
+
+class Lens( Generic[S, V] ):
+    """TODO"""
+
+    # @property
+    # def source_type( self ) -> Type[S]:
+    #     """The source type (S) for the lens; what is put to"""
+    #     # TODO Figure out why linting fails here
+    #     return self.__orig_class__.__args__[0]
+
+    # @property
+    # def view_type( self ) -> Type[V]:
+    #     """The view type (V) for the lens; what is get'd from"""
+    #     # TODO FIgure out why linting fails here
+    #     return self.__orig_class__.__args__[1]
+
+    def __init__( self, get: LensGetter[S, V],
+            put: Optional[LensPutter[S, V]] = None
+            ) -> None:
+        """TODO"""
+        ##
+
+        # Check argument validity
+
+        sig = inspect.signature( get )
+        input_types = list( sig.parameters.values() )
+        assert len( input_types ) == 1, \
+            'Wrong number of input args for lens: should only have one'
+
+        # Update function details for this object as returned by annotation
+        functools.update_wrapper( self, get )
+
+        self.source_type: Type[PackableSample] = input_types[0].annotation
+        self.view_type = sig.return_annotation
+
+        # Store the getter
+        self._getter = get
+
+        # Determine and store the putter
+        if put is None:
+            # Trivial putter does not update the source
+            def _trivial_put( v: V, s: S ) -> S:
+                return s
+            put = _trivial_put
+        self._putter = put
+
+    #
+
+    def putter( self, put: LensPutter[S, V] ) -> LensPutter[S, V]:
+        """TODO"""
+        ##
+        self._putter = put
+        return put
+
+    # Methods to actually execute transformations
+
+    def put( self, v: V, s: S ) -> S:
+        """TODO"""
+        return self._putter( v, s )
+
+    def get( self, s: S ) -> V:
+        """TODO"""
+        return self( s )
+
+    # Convenience to enable calling the lens as its getter
+
+    def __call__( self, s: S ) -> V:
+        return self._getter( s )
+
+# TODO Figure out how to properly parameterize this
+# def _lens_factory[S, V]( register: bool = True ):
+#     """Register the annotated function `f` as the getter of a sample lens"""
+
+#     # The actual lens decorator taking a lens getter function to a lens object
+#     def _decorator( f: LensGetter[S, V] ) -> Lens[S, V]:
+#         ret = Lens[S, V]( f )
+#         if register:
+#             _network.register( ret )
+#         return ret
+
+#     # Return the lens decorator
+#     return _decorator
+
+# # For convenience
+# lens = _lens_factory
+
+def lens( f: LensGetter[S, V] ) -> Lens[S, V]:
+    ret = Lens[S, V]( f )
+    _network.register( ret )
+    return ret
+
+
+##
+# Global registry of used lenses
+
+# _registered_lenses: Dict[LensSignature, Lens] = dict()
+# """TODO"""
+
+class LensNetwork:
+    """TODO"""
+
+    _instance = None
+    """The singleton instance"""
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            # If no instance exists, create a new one
+            cls._instance = super().__new__(cls)
+        return cls._instance  # Return the existing (or newly created) instance
+
+    def __init__(self):
+        if not hasattr(self, '_initialized'):  # Check if already initialized
+            self._registry: Dict[LensSignature, Lens] = dict()
+            self._initialized = True
+
+    def register( self, _lens: Lens ):
+        """Set `lens` as the canonical view between its source and view types"""
+
+        # sig = inspect.signature( _lens.get )
+        # input_types = list( sig.parameters.values() )
+        # assert len( input_types ) == 1, \
+        #     'Wrong number of input args for lens: should only have one'
+
+        # input_type = input_types[0].annotation
+        # print( input_type )
+        # output_type = sig.return_annotation
+
+        # self._registry[input_type, output_type] = _lens
+        # print( _lens.source_type )
+        self._registry[_lens.source_type, _lens.view_type] = _lens
+
+    def transform( self, source: DatasetType, view: DatasetType ) -> Lens:
+        """TODO"""
+
+        # TODO Handle compositional closure
+        ret = self._registry.get( (source, view), None )
+        if ret is None:
+            raise ValueError( f'No registered lens from source {source} to view {view}' )
+
+        return ret
+
+
+# Create global singleton registry instance
+_network = LensNetwork()
+
+# def lens( f: LensPutter ) -> Lens:
+#     """Register the annotated function `f` as a sample lens"""
+#     ##
+
+#     sig = inspect.signature( f )
+
+#     input_types = list( sig.parameters.values() )
+#     output_type = sig.return_annotation
+
+#     _registered_lenses[]
+
+#     f.lens = Lens(
+
+#     )
+
+#     return f
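End to end, the new module is used roughly as follows; a sketch with assumed type names, mirroring the tests below:

    import atdata

    @atdata.packable
    class Source:
        name: str
        age: int

    @atdata.packable
    class View:
        name: str

    # Registers a Source -> View lens in the LensNetwork singleton
    @atdata.lens
    def forget_age( s: Source ) -> View:
        return View( name = s.name )

    # Optional putter; without one, put( v, s ) just returns s unchanged
    @forget_age.putter
    def remember_age( v: View, s: Source ) -> Source:
        return Source( name = v.name, age = s.age )

    # Datasets can then be re-typed through the registry
    ds = atdata.Dataset[Source]( 'source.tar' ).as_type( View )

The lens's source and view types are read from the getter's annotations in `Lens.__init__`, so the getter must be fully annotated; `transform` currently does exact-pair lookup only, with compositional closure left as a TODO.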
{atdata-0.1.2a4 → atdata-0.1.3b3}/tests/test_dataset.py

@@ -50,6 +50,12 @@ class NumpyTestSampleDecorated:
     label: int
     image: NDArray
 
+@atdata.packable
+class NumpyOptionalSampleDecorated:
+    label: int
+    image: NDArray
+    embeddings: NDArray | None = None
+
 test_cases = [
     {
         'SampleType': BasicTestSample,
@@ -59,6 +65,7 @@ test_cases = [
             'value': 1024.768,
         },
         'sample_wds_stem': 'basic_test',
+        'test_parquet': True,
     },
     {
         'SampleType': NumpyTestSample,
@@ -68,6 +75,7 @@ test_cases = [
             'image': np.random.randn( 1024, 1024 ),
         },
         'sample_wds_stem': 'numpy_test',
+        'test_parquet': False,
     },
     {
         'SampleType': BasicTestSampleDecorated,
@@ -77,6 +85,7 @@ test_cases = [
             'value': 1024.768,
         },
         'sample_wds_stem': 'basic_test_decorated',
+        'test_parquet': True,
     },
     {
         'SampleType': NumpyTestSampleDecorated,
@@ -86,6 +95,29 @@ test_cases = [
            'image': np.random.randn( 1024, 1024 ),
         },
         'sample_wds_stem': 'numpy_test_decorated',
+        'test_parquet': False,
+    },
+    {
+        'SampleType': NumpyOptionalSampleDecorated,
+        'sample_data':
+        {
+            'label': 9_001,
+            'image': np.random.randn( 1024, 1024 ),
+            'embeddings': np.random.randn( 512 ),
+        },
+        'sample_wds_stem': 'numpy_optional_decorated',
+        'test_parquet': False,
+    },
+    {
+        'SampleType': NumpyOptionalSampleDecorated,
+        'sample_data':
+        {
+            'label': 9_001,
+            'image': np.random.randn( 1024, 1024 ),
+            'embeddings': None,
+        },
+        'sample_wds_stem': 'numpy_optional_decorated_none',
+        'test_parquet': False,
     },
 ]
 
@@ -175,7 +207,7 @@ def test_wds(
     ).as_posix()
     file_wds_pattern = file_pattern.format( shard_id = '%06d' )
 
-    with wds.ShardWriter(
+    with wds.writer.ShardWriter(
             pattern = file_wds_pattern,
             maxcount = shard_maxcount,
             ) as sink:
@@ -323,5 +355,56 @@ def test_wds(
     assert iterations_run == n_iterate, \
         "Only found {iterations_run} samples, not {n_iterate}"
 
+#
+
+@pytest.mark.parametrize(
+    ('SampleType', 'sample_data', 'sample_wds_stem', 'test_parquet'),
+    [ (
+        case['SampleType'],
+        case['sample_data'],
+        case['sample_wds_stem'],
+        case['test_parquet']
+    )
+    for case in test_cases ]
+)
+def test_parquet_export(
+        SampleType: Type[atdata.PackableSample],
+        sample_data: atds.MsgpackRawSample,
+        sample_wds_stem: str,
+        test_parquet: bool,
+        tmp_path
+        ):
+    """Test our ability to export a dataset to `parquet` format"""
+
+    # Skip irrelevant test cases
+    if not test_parquet:
+        return
+
+    ## Testing hyperparameters
+
+    n_copies_dataset = 1_000
+    n_per_file = 100
+
+    ## Start out by writing tar dataset
+
+    wds_filename = (tmp_path / f'{sample_wds_stem}.tar').as_posix()
+    with wds.writer.TarWriter( wds_filename ) as sink:
+        for _ in range( n_copies_dataset ):
+            new_sample = SampleType.from_data( sample_data )
+            sink.write( new_sample.as_wds )
+
+    ## Now export to `parquet`
+
+    dataset = atdata.Dataset[SampleType]( wds_filename )
+    parquet_filename = tmp_path / f'{sample_wds_stem}.parquet'
+    dataset.to_parquet( parquet_filename )
+
+    parquet_filename = tmp_path / f'{sample_wds_stem}-segments.parquet'
+    dataset.to_parquet( parquet_filename, maxcount = n_per_file )
+
+    ## Double-check our `parquet` export
+
+    # TODO
+
 
 ##
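The final `## Double-check` step is still a TODO in the test; a minimal sketch of what the single-file verification might look like, assuming pandas is used to read the export back:

    import pandas as pd

    df = pd.read_parquet( tmp_path / f'{sample_wds_stem}.parquet' )
    assert len( df ) == n_copies_dataset, \
        f'Expected {n_copies_dataset} rows, found {len( df )}'
    # Every exported row should match the repeated source sample
    for key, value in sample_data.items():
        assert df.iloc[0][key] == value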
atdata-0.1.3b3/tests/test_lens.py

@@ -0,0 +1,166 @@
+"""Test lens functionality."""
+
+##
+# Imports
+
+import pytest
+
+from dataclasses import dataclass
+import webdataset as wds
+import atdata
+
+import numpy as np
+from numpy.typing import NDArray
+
+
+##
+# Tests
+
+def test_lens():
+    """Test a lens between sample types"""
+
+    # Set up the lens scenario
+
+    @atdata.packable
+    class Source:
+        name: str
+        age: int
+        height: float
+
+    @atdata.packable
+    class View:
+        name: str
+        height: float
+
+    @atdata.lens
+    def polite( s: Source ) -> View:
+        return View(
+            name = s.name,
+            height = s.height,
+        )
+
+    @polite.putter
+    def polite_update( v: View, s: Source ) -> Source:
+        return Source(
+            name = v.name,
+            height = v.height,
+            #
+            age = s.age,
+        )
+
+    # Test with an example sample
+
+    test_source = Source(
+        name = 'Hello World',
+        age = 42,
+        height = 182.9,
+    )
+    correct_view = View(
+        name = test_source.name,
+        height = test_source.height,
+    )
+
+    test_view = polite( test_source )
+    assert test_view == correct_view, \
+        f'Incorrect lens behavior: {test_view}, and not {correct_view}'
+
+    # This lens should be well-behaved
+
+    update_view = View(
+        name = 'Now Taller',
+        height = 192.9,
+    )
+
+    x = polite( polite.put( update_view, test_source ) )
+    assert x == update_view, \
+        f'Violation of GetPut: {x} =/= {update_view}'
+
+    y = polite.put( polite( test_source ), test_source )
+    assert y == test_source, \
+        f'Violation of PutGet: {y} =/= {test_source}'
+
+    # TODO Test PutPut
+
+def test_conversion( tmp_path ):
+    """Test automatic interconversion between sample types"""
+
+    @dataclass
+    class Source( atdata.PackableSample ):
+        name: str
+        height: float
+        favorite_pizza: str
+        favorite_image: NDArray
+
+    @dataclass
+    class View( atdata.PackableSample ):
+        name: str
+        favorite_pizza: str
+        favorite_image: NDArray
+
+    @atdata.lens
+    def polite( s: Source ) -> View:
+        return View(
+            name = s.name,
+            favorite_pizza = s.favorite_pizza,
+            favorite_image = s.favorite_image,
+        )
+
+    lens_network = atdata.LensNetwork()
+    print( lens_network._registry )
+
+    # Map a test sample through the view
+    test_source = Source(
+        name = 'Larry',
+        height = 42.,
+        favorite_pizza = 'pineapple',
+        favorite_image = np.random.randn( 224, 224 )
+    )
+    test_view = polite( test_source )
+
+    # Create a test dataset
+
+    k_test = 100
+    test_filename = (
+        tmp_path
+        / 'test-source.tar'
+    ).as_posix()
+
+    with wds.writer.TarWriter( test_filename ) as dest:
+        for i in range( k_test ):
+            # Create a new copied sample
+            cur_sample = Source(
+                name = test_source.name,
+                height = test_source.height,
+                favorite_pizza = test_source.favorite_pizza,
+                favorite_image = test_source.favorite_image,
+            )
+            dest.write( cur_sample.as_wds )
+
+    # Try reading the test dataset
+
+    ds = (
+        atdata.Dataset[Source]( test_filename )
+        .as_type( View )
+    )
+
+    assert ds.sample_type == View, \
+        'Auto-mapped'
+
+    sample: View | None = None
+    for sample in ds.ordered( batch_size = None ):
+        # Load only the first sample
+        break
+
+    assert sample is not None, \
+        'Did not load any samples from `Source` dataset'
+
+    assert sample.name == test_view.name, \
+        f'Divergence on auto-mapped dataset: `name` should be {test_view.name}, but is {sample.name}'
+    # assert sample.height == test_view.height, \
+    #     f'Divergence on auto-mapped dataset: `height` should be {test_view.height}, but is {sample.height}'
+    assert sample.favorite_pizza == test_view.favorite_pizza, \
+        f'Divergence on auto-mapped dataset: `favorite_pizza` should be {test_view.favorite_pizza}, but is {sample.favorite_pizza}'
+    assert np.all( sample.favorite_image == test_view.favorite_image ), \
+        f'Divergence on auto-mapped dataset: `favorite_image`'
+
+##
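`test_lens` leaves the PutPut law as a TODO; the missing check could reuse the fixtures above, along these lines (the extra view value is assumed):

    # PutPut: a second put fully overwrites the first
    final_view = View( name = 'Final Form', height = 200.0 )
    z = polite.put( final_view, polite.put( update_view, test_source ) )
    assert z == polite.put( final_view, test_source ), \
        f'Violation of PutPut: {z}'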