atdata 0.1.3a3__tar.gz → 0.1.3b1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {atdata-0.1.3a3 → atdata-0.1.3b1}/PKG-INFO +1 -1
- {atdata-0.1.3a3 → atdata-0.1.3b1}/pyproject.toml +1 -1
- {atdata-0.1.3a3 → atdata-0.1.3b1}/src/atdata/dataset.py +48 -19
- {atdata-0.1.3a3 → atdata-0.1.3b1}/tests/test_dataset.py +28 -0
- {atdata-0.1.3a3 → atdata-0.1.3b1}/.github/workflows/uv-publish-pypi.yml +0 -0
- {atdata-0.1.3a3 → atdata-0.1.3b1}/.github/workflows/uv-test.yml +0 -0
- {atdata-0.1.3a3 → atdata-0.1.3b1}/.gitignore +0 -0
- {atdata-0.1.3a3 → atdata-0.1.3b1}/.python-version +0 -0
- {atdata-0.1.3a3 → atdata-0.1.3b1}/LICENSE +0 -0
- {atdata-0.1.3a3 → atdata-0.1.3b1}/README.md +0 -0
- {atdata-0.1.3a3 → atdata-0.1.3b1}/src/atdata/__init__.py +0 -0
- {atdata-0.1.3a3 → atdata-0.1.3b1}/src/atdata/_helpers.py +0 -0
- {atdata-0.1.3a3 → atdata-0.1.3b1}/src/atdata/lens.py +0 -0
- {atdata-0.1.3a3 → atdata-0.1.3b1}/tests/test_lens.py +0 -0
|
@@ -8,6 +8,9 @@ import webdataset as wds
|
|
|
8
8
|
from pathlib import Path
|
|
9
9
|
import uuid
|
|
10
10
|
import functools
|
|
11
|
+
|
|
12
|
+
import dataclasses
|
|
13
|
+
import types
|
|
11
14
|
from dataclasses import (
|
|
12
15
|
dataclass,
|
|
13
16
|
asdict,
|
|
@@ -21,6 +24,7 @@ from tqdm import tqdm
|
|
|
21
24
|
import numpy as np
|
|
22
25
|
import pandas as pd
|
|
23
26
|
|
|
27
|
+
import typing
|
|
24
28
|
from typing import (
|
|
25
29
|
Any,
|
|
26
30
|
Optional,
|
|
@@ -28,6 +32,7 @@ from typing import (
|
|
|
28
32
|
Sequence,
|
|
29
33
|
Iterable,
|
|
30
34
|
Callable,
|
|
35
|
+
Union,
|
|
31
36
|
#
|
|
32
37
|
Self,
|
|
33
38
|
Generic,
|
|
@@ -108,6 +113,24 @@ def _make_packable( x ):
|
|
|
108
113
|
return eh.array_to_bytes( x )
|
|
109
114
|
return x
|
|
110
115
|
|
|
116
|
+
def _is_possibly_ndarray_type( t ):
|
|
117
|
+
"""Checks if a type annotation is possibly an NDArray."""
|
|
118
|
+
|
|
119
|
+
# Directly an NDArray
|
|
120
|
+
if t == NDArray:
|
|
121
|
+
print( 'is an NDArray' )
|
|
122
|
+
return True
|
|
123
|
+
|
|
124
|
+
# Check for Optionals (i.e., NDArray | None)
|
|
125
|
+
if isinstance( t, types.UnionType ):
|
|
126
|
+
t_parts = t.__args__
|
|
127
|
+
if any( x == NDArray
|
|
128
|
+
for x in t_parts ):
|
|
129
|
+
return True
|
|
130
|
+
|
|
131
|
+
# Not an NDArray
|
|
132
|
+
return False
|
|
133
|
+
|
|
111
134
|
@dataclass
|
|
112
135
|
class PackableSample( ABC ):
|
|
113
136
|
"""A sample that can be packed and unpacked with msgpack"""
|
|
@@ -116,10 +139,13 @@ class PackableSample( ABC ):
|
|
|
116
139
|
"""TODO Stupid kludge because of __post_init__ nonsense for wrapped classes"""
|
|
117
140
|
|
|
118
141
|
# Auto-convert known types when annotated
|
|
119
|
-
for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
|
|
142
|
+
# for var_name, var_type in vars( self.__class__ )['__annotations__'].items():
|
|
143
|
+
for field in dataclasses.fields( self ):
|
|
144
|
+
var_name = field.name
|
|
145
|
+
var_type = field.type
|
|
120
146
|
|
|
121
147
|
# Annotation for this variable is to be an NDArray
|
|
122
|
-
if var_type
|
|
148
|
+
if _is_possibly_ndarray_type( var_type ):
|
|
123
149
|
# ... so, we'll always auto-convert to numpy
|
|
124
150
|
|
|
125
151
|
var_cur_value = getattr( self, var_name )
|
|
@@ -135,6 +161,9 @@ class PackableSample( ABC ):
|
|
|
135
161
|
# setattr( self, var_name, var_cur_value.to_numpy )
|
|
136
162
|
|
|
137
163
|
elif isinstance( var_cur_value, bytes ):
|
|
164
|
+
# TODO This does create a constraint that serialized bytes
|
|
165
|
+
# in a field that might be an NDArray are always interpreted
|
|
166
|
+
# as being the NDArray interpretation
|
|
138
167
|
setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
|
|
139
168
|
|
|
140
169
|
def __post_init__( self ):
|
|
@@ -204,7 +233,7 @@ class SampleBatch( Generic[DT] ):
|
|
|
204
233
|
@property
|
|
205
234
|
def sample_type( self ) -> Type:
|
|
206
235
|
"""The type of each sample in this batch"""
|
|
207
|
-
return self.__orig_class__
|
|
236
|
+
return typing.get_args( self.__orig_class__)[0]
|
|
208
237
|
|
|
209
238
|
def __getattr__( self, name ):
|
|
210
239
|
# Aggregate named params of sample type
|
|
@@ -253,7 +282,7 @@ class Dataset( Generic[ST] ):
|
|
|
253
282
|
def sample_type( self ) -> Type:
|
|
254
283
|
"""The type of each returned sample from this `Dataset`'s iterator"""
|
|
255
284
|
# TODO Figure out why linting fails here
|
|
256
|
-
return self.__orig_class__
|
|
285
|
+
return typing.get_args( self.__orig_class__ )[0]
|
|
257
286
|
@property
|
|
258
287
|
def batch_type( self ) -> Type:
|
|
259
288
|
"""The type of a batch built from `sample_class`"""
|
|
@@ -371,29 +400,29 @@ class Dataset( Generic[ST] ):
|
|
|
371
400
|
|
|
372
401
|
if batch_size is None:
|
|
373
402
|
# TODO Duplication here
|
|
374
|
-
return wds.DataPipeline(
|
|
375
|
-
wds.SimpleShardList( self.url ),
|
|
376
|
-
wds.shuffle( buffer_shards ),
|
|
377
|
-
wds.split_by_worker,
|
|
403
|
+
return wds.pipeline.DataPipeline(
|
|
404
|
+
wds.shardlists.SimpleShardList( self.url ),
|
|
405
|
+
wds.filters.shuffle( buffer_shards ),
|
|
406
|
+
wds.shardlists.split_by_worker,
|
|
378
407
|
#
|
|
379
|
-
wds.tarfile_to_samples(),
|
|
408
|
+
wds.tariterators.tarfile_to_samples(),
|
|
380
409
|
# wds.shuffle( buffer_samples ),
|
|
381
410
|
# wds.map( self.preprocess ),
|
|
382
|
-
wds.shuffle( buffer_samples ),
|
|
383
|
-
wds.map( self.wrap ),
|
|
411
|
+
wds.filters.shuffle( buffer_samples ),
|
|
412
|
+
wds.filters.map( self.wrap ),
|
|
384
413
|
)
|
|
385
414
|
|
|
386
|
-
return wds.DataPipeline(
|
|
387
|
-
wds.SimpleShardList( self.url ),
|
|
388
|
-
wds.shuffle( buffer_shards ),
|
|
389
|
-
wds.split_by_worker,
|
|
415
|
+
return wds.pipeline.DataPipeline(
|
|
416
|
+
wds.shardlists.SimpleShardList( self.url ),
|
|
417
|
+
wds.filters.shuffle( buffer_shards ),
|
|
418
|
+
wds.shardlists.split_by_worker,
|
|
390
419
|
#
|
|
391
|
-
wds.tarfile_to_samples(),
|
|
420
|
+
wds.tariterators.tarfile_to_samples(),
|
|
392
421
|
# wds.shuffle( buffer_samples ),
|
|
393
422
|
# wds.map( self.preprocess ),
|
|
394
|
-
wds.shuffle( buffer_samples ),
|
|
395
|
-
wds.batched( batch_size ),
|
|
396
|
-
wds.map( self.wrap_batch ),
|
|
423
|
+
wds.filters.shuffle( buffer_samples ),
|
|
424
|
+
wds.filters.batched( batch_size ),
|
|
425
|
+
wds.filters.map( self.wrap_batch ),
|
|
397
426
|
)
|
|
398
427
|
|
|
399
428
|
# TODO Rewrite to eliminate `pandas` dependency directly calling
|
|
@@ -50,6 +50,12 @@ class NumpyTestSampleDecorated:
|
|
|
50
50
|
label: int
|
|
51
51
|
image: NDArray
|
|
52
52
|
|
|
53
|
+
@atdata.packable
class NumpyOptionalSampleDecorated:
    # Test sample type pairing required fields with an optional array field.
    label: int
    image: NDArray
    # Optional array field, defaulting to None. The annotation is written
    # with PEP 604 `X | None` syntax; presumably the packing machinery
    # inspects this annotation to decide whether to auto-convert the field
    # to/from a numpy array, so the exact spelling matters — confirm before
    # changing it (e.g. to `Optional[NDArray]`).
    embeddings: NDArray | None = None
|
|
58
|
+
|
|
53
59
|
test_cases = [
|
|
54
60
|
{
|
|
55
61
|
'SampleType': BasicTestSample,
|
|
@@ -91,6 +97,28 @@ test_cases = [
|
|
|
91
97
|
'sample_wds_stem': 'numpy_test_decorated',
|
|
92
98
|
'test_parquet': False,
|
|
93
99
|
},
|
|
100
|
+
{
|
|
101
|
+
'SampleType': NumpyOptionalSampleDecorated,
|
|
102
|
+
'sample_data':
|
|
103
|
+
{
|
|
104
|
+
'label': 9_001,
|
|
105
|
+
'image': np.random.randn( 1024, 1024 ),
|
|
106
|
+
'embeddings': np.random.randn( 512 ),
|
|
107
|
+
},
|
|
108
|
+
'sample_wds_stem': 'numpy_optional_decorated',
|
|
109
|
+
'test_parquet': False,
|
|
110
|
+
},
|
|
111
|
+
{
|
|
112
|
+
'SampleType': NumpyOptionalSampleDecorated,
|
|
113
|
+
'sample_data':
|
|
114
|
+
{
|
|
115
|
+
'label': 9_001,
|
|
116
|
+
'image': np.random.randn( 1024, 1024 ),
|
|
117
|
+
'embeddings': None,
|
|
118
|
+
},
|
|
119
|
+
'sample_wds_stem': 'numpy_optional_decorated_none',
|
|
120
|
+
'test_parquet': False,
|
|
121
|
+
},
|
|
94
122
|
]
|
|
95
123
|
|
|
96
124
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|