atdata 0.1.1a2__py3-none-any.whl → 0.1.2a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata/__init__.py +14 -2
- atdata/_helpers.py +8 -16
- atdata/dataset.py +70 -33
- {atdata-0.1.1a2.dist-info → atdata-0.1.2a1.dist-info}/METADATA +1 -1
- atdata-0.1.2a1.dist-info/RECORD +8 -0
- atdata-0.1.2a1.dist-info/entry_points.txt +2 -0
- atdata-0.1.1a2.dist-info/RECORD +0 -8
- atdata-0.1.1a2.dist-info/entry_points.txt +0 -2
- {atdata-0.1.1a2.dist-info → atdata-0.1.2a1.dist-info}/WHEEL +0 -0
- {atdata-0.1.1a2.dist-info → atdata-0.1.2a1.dist-info}/licenses/LICENSE +0 -0
atdata/__init__.py
CHANGED
atdata/_helpers.py
CHANGED
|
@@ -1,30 +1,22 @@
|
|
|
1
|
-
"""Assorted helper methods for `
|
|
1
|
+
"""Assorted helper methods for `atdata`"""
|
|
2
2
|
|
|
3
3
|
##
|
|
4
4
|
# Imports
|
|
5
5
|
|
|
6
6
|
from io import BytesIO
|
|
7
|
-
import ormsgpack as omp
|
|
8
7
|
|
|
9
8
|
import numpy as np
|
|
10
9
|
|
|
11
10
|
|
|
12
11
|
##
|
|
13
|
-
#
|
|
14
12
|
|
|
15
|
-
def
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def unpack( bs: bytes ):
|
|
19
|
-
return omp.unpackb( bs )
|
|
20
|
-
|
|
21
|
-
##
|
|
22
|
-
|
|
23
|
-
def array_to_bytes(x: np.ndarray) -> bytes:
|
|
13
|
+
def array_to_bytes( x: np.ndarray ) -> bytes:
|
|
14
|
+
"""Convert `numpy` array to a format suitable for packing"""
|
|
24
15
|
np_bytes = BytesIO()
|
|
25
|
-
np.save(np_bytes, x, allow_pickle=True)
|
|
16
|
+
np.save( np_bytes, x, allow_pickle = True )
|
|
26
17
|
return np_bytes.getvalue()
|
|
27
18
|
|
|
28
|
-
def bytes_to_array(b: bytes) -> np.ndarray:
|
|
29
|
-
|
|
30
|
-
|
|
19
|
+
def bytes_to_array( b: bytes ) -> np.ndarray:
|
|
20
|
+
"""Convert packed bytes back to a `numpy` array"""
|
|
21
|
+
np_bytes = BytesIO( b )
|
|
22
|
+
return np.load( np_bytes, allow_pickle = True )
|
atdata/dataset.py
CHANGED
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
|
|
6
6
|
import webdataset as wds
|
|
7
7
|
|
|
8
|
+
import functools
|
|
8
9
|
from dataclasses import dataclass
|
|
9
10
|
import uuid
|
|
10
11
|
|
|
@@ -57,38 +58,38 @@ DT = TypeVar( 'DT' )
|
|
|
57
58
|
|
|
58
59
|
MsgpackRawSample: TypeAlias = Dict[str, Any]
|
|
59
60
|
|
|
60
|
-
@dataclass
|
|
61
|
-
class ArrayBytes:
|
|
62
|
-
|
|
63
|
-
|
|
61
|
+
# @dataclass
|
|
62
|
+
# class ArrayBytes:
|
|
63
|
+
# """Annotates bytes that should be interpreted as the raw contents of a
|
|
64
|
+
# numpy NDArray"""
|
|
64
65
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
66
|
+
# raw_bytes: bytes
|
|
67
|
+
# """The raw bytes of the corresponding NDArray"""
|
|
68
|
+
|
|
69
|
+
# def __init__( self,
|
|
70
|
+
# array: Optional[ArrayLike] = None,
|
|
71
|
+
# raw: Optional[bytes] = None,
|
|
72
|
+
# ):
|
|
73
|
+
# """TODO"""
|
|
74
|
+
|
|
75
|
+
# if array is not None:
|
|
76
|
+
# array = np.array( array )
|
|
77
|
+
# self.raw_bytes = eh.array_to_bytes( array )
|
|
77
78
|
|
|
78
|
-
|
|
79
|
-
|
|
79
|
+
# elif raw is not None:
|
|
80
|
+
# self.raw_bytes = raw
|
|
80
81
|
|
|
81
|
-
|
|
82
|
-
|
|
82
|
+
# else:
|
|
83
|
+
# raise ValueError( 'Must provide either `array` or `raw` bytes' )
|
|
83
84
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
85
|
+
# @property
|
|
86
|
+
# def to_numpy( self ) -> NDArray:
|
|
87
|
+
# """Return the `raw_bytes` data as an NDArray"""
|
|
88
|
+
# return eh.bytes_to_array( self.raw_bytes )
|
|
88
89
|
|
|
89
90
|
def _make_packable( x ):
|
|
90
|
-
if isinstance( x, ArrayBytes ):
|
|
91
|
-
|
|
91
|
+
# if isinstance( x, ArrayBytes ):
|
|
92
|
+
# return x.raw_bytes
|
|
92
93
|
if isinstance( x, np.ndarray ):
|
|
93
94
|
return eh.array_to_bytes( x )
|
|
94
95
|
return x
|
|
@@ -114,8 +115,8 @@ class PackableSample( ABC ):
|
|
|
114
115
|
# we're good!
|
|
115
116
|
pass
|
|
116
117
|
|
|
117
|
-
elif isinstance( var_cur_value, ArrayBytes ):
|
|
118
|
-
|
|
118
|
+
# elif isinstance( var_cur_value, ArrayBytes ):
|
|
119
|
+
# setattr( self, var_name, var_cur_value.to_numpy )
|
|
119
120
|
|
|
120
121
|
elif isinstance( var_cur_value, bytes ):
|
|
121
122
|
setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
|
|
@@ -172,7 +173,7 @@ def _batch_aggregate( xs: Sequence ):
|
|
|
172
173
|
|
|
173
174
|
return list( xs )
|
|
174
175
|
|
|
175
|
-
class
|
|
176
|
+
class SampleBatch( Generic[DT] ):
|
|
176
177
|
|
|
177
178
|
def __init__( self, samples: Sequence[DT] ):
|
|
178
179
|
"""TODO"""
|
|
@@ -233,7 +234,7 @@ class Dataset( Generic[ST] ):
|
|
|
233
234
|
def batch_type( self ) -> Type:
|
|
234
235
|
"""The type of a batch built from `sample_class`"""
|
|
235
236
|
# return self.__orig_class__.__args__[1]
|
|
236
|
-
return
|
|
237
|
+
return SampleBatch[self.sample_type]
|
|
237
238
|
|
|
238
239
|
|
|
239
240
|
# _schema_registry_sample: dict[str, Type]
|
|
@@ -396,7 +397,7 @@ class Dataset( Generic[ST] ):
|
|
|
396
397
|
value = sample,
|
|
397
398
|
)
|
|
398
399
|
|
|
399
|
-
def wrap_batch( self, batch: WDSRawBatch ) ->
|
|
400
|
+
def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
|
|
400
401
|
"""Wrap a `batch` of samples into the appropriate dataset-specific type
|
|
401
402
|
|
|
402
403
|
This default implementation simply creates a list one sample at a time
|
|
@@ -405,7 +406,7 @@ class Dataset( Generic[ST] ):
|
|
|
405
406
|
assert 'msgpack' in batch
|
|
406
407
|
batch_unpacked = [ self.sample_type.from_bytes( bs )
|
|
407
408
|
for bs in batch['msgpack'] ]
|
|
408
|
-
return
|
|
409
|
+
return SampleBatch[self.sample_type]( batch_unpacked )
|
|
409
410
|
|
|
410
411
|
|
|
411
412
|
# # @classmethod
|
|
@@ -415,4 +416,40 @@ class Dataset( Generic[ST] ):
|
|
|
415
416
|
# This default implementation simply creates a list one sample at a time
|
|
416
417
|
# """
|
|
417
418
|
# assert cls.batch_class is not None, 'No batch class specified'
|
|
418
|
-
# return cls.batch_class( **batch )
|
|
419
|
+
# return cls.batch_class( **batch )
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
##
|
|
423
|
+
# Shortcut decorators
|
|
424
|
+
|
|
425
|
+
# def packable( cls ):
|
|
426
|
+
# """TODO"""
|
|
427
|
+
|
|
428
|
+
# def decorator( cls ):
|
|
429
|
+
# # Create a new class dynamically
|
|
430
|
+
# # The new class inherits from the new_parent_class first, then the original cls
|
|
431
|
+
# new_bases = (PackableSample,) + cls.__bases__
|
|
432
|
+
# new_cls = type(cls.__name__, new_bases, dict(cls.__dict__))
|
|
433
|
+
|
|
434
|
+
# # Optionally, update __module__ and __qualname__ for better introspection
|
|
435
|
+
# new_cls.__module__ = cls.__module__
|
|
436
|
+
# new_cls.__qualname__ = cls.__qualname__
|
|
437
|
+
|
|
438
|
+
# return new_cls
|
|
439
|
+
# return decorator
|
|
440
|
+
|
|
441
|
+
def packable( cls ):
|
|
442
|
+
"""TODO"""
|
|
443
|
+
|
|
444
|
+
##
|
|
445
|
+
|
|
446
|
+
as_dataclass = dataclass( cls )
|
|
447
|
+
|
|
448
|
+
class as_packable( as_dataclass, PackableSample ):
|
|
449
|
+
pass
|
|
450
|
+
|
|
451
|
+
as_packable.__name__ = cls.__name__
|
|
452
|
+
|
|
453
|
+
##
|
|
454
|
+
|
|
455
|
+
return as_packable
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
atdata/__init__.py,sha256=jPZVd_6UIo0DSbCnXAnYZ2eMwHYzOk--5vtEDTZvwqw,173
|
|
2
|
+
atdata/_helpers.py,sha256=R63JhXewAKZYnZ9Th7R6yZh0IOUPYGBsth3FpRUMD-U,503
|
|
3
|
+
atdata/dataset.py,sha256=xPxDkQk1fBkU3sbLMT_Rm4CmvoBXIpEpmRNxrir4nis,13045
|
|
4
|
+
atdata-0.1.2a1.dist-info/METADATA,sha256=8aK-P0A7YZ2Cl6r_GC3Gi1huZRDrcR9zME2gMdT-fFc,434
|
|
5
|
+
atdata-0.1.2a1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
atdata-0.1.2a1.dist-info/entry_points.txt,sha256=6-iQr1veSTq-ac94bLyfcyGHprrZWevPEd12BWX37tQ,39
|
|
7
|
+
atdata-0.1.2a1.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
|
|
8
|
+
atdata-0.1.2a1.dist-info/RECORD,,
|
atdata-0.1.1a2.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
atdata/__init__.py,sha256=yN07kW_3UcMlYZrM_Jrpy6DMCzTp9kvu2ICcU7n1-5w,52
|
|
2
|
-
atdata/_helpers.py,sha256=CjIvLruNOhHRl1Arse5SahGTmI0Et3BoNqsWC9b8noE,515
|
|
3
|
-
atdata/dataset.py,sha256=mvmCYtL6wD9961qq4lprZSkone56ubTKp3vDgEnWdPI,12158
|
|
4
|
-
atdata-0.1.1a2.dist-info/METADATA,sha256=Mf75Ai8KKLhm7r2m65dP_QV6e9dkE7mTHtGSASGEKBE,434
|
|
5
|
-
atdata-0.1.1a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
atdata-0.1.1a2.dist-info/entry_points.txt,sha256=KuQtj4ZAwWLSyJUxhpQEHYfwSG-0ZXuj5hcZ1uAgGRQ,39
|
|
7
|
-
atdata-0.1.1a2.dist-info/licenses/LICENSE,sha256=Pz2eACSxkhsGfW9_iN60pgy-enjnbGTj8df8O3ebnQQ,16726
|
|
8
|
-
atdata-0.1.1a2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|