atdata 0.1.1a2__tar.gz → 0.1.2a1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atdata-0.1.2a1/.github/workflows/uv-publish-pypi.yml +46 -0
- atdata-0.1.2a1/.github/workflows/uv-test.yml +40 -0
- {atdata-0.1.1a2 → atdata-0.1.2a1}/.gitignore +2 -0
- {atdata-0.1.1a2 → atdata-0.1.2a1}/PKG-INFO +1 -1
- {atdata-0.1.1a2 → atdata-0.1.2a1}/pyproject.toml +3 -3
- atdata-0.1.2a1/src/atdata/__init__.py +14 -0
- atdata-0.1.2a1/src/atdata/_helpers.py +22 -0
- {atdata-0.1.1a2 → atdata-0.1.2a1}/src/atdata/dataset.py +70 -33
- atdata-0.1.2a1/tests/test_dataset.py +301 -0
- atdata-0.1.1a2/.github/workflows/python-package.yml +0 -66
- atdata-0.1.1a2/.github/workflows/python-publish.yml +0 -129
- atdata-0.1.1a2/src/atdata/__init__.py +0 -2
- atdata-0.1.1a2/src/atdata/_helpers.py +0 -30
- atdata-0.1.1a2/tests/test_dataset.py +0 -69
- {atdata-0.1.1a2 → atdata-0.1.2a1}/.python-version +0 -0
- {atdata-0.1.1a2 → atdata-0.1.2a1}/LICENSE +0 -0
- {atdata-0.1.1a2 → atdata-0.1.2a1}/README.md +0 -0
atdata-0.1.2a1/.github/workflows/uv-publish-pypi.yml (new file)
@@ -0,0 +1,46 @@
+#
+
+name: Build and upload package to PyPI
+
+on:
+  release:
+    types:
+      - published
+
+permissions:
+  contents: read
+
+jobs:
+
+  uv-build-release-pypi-publish:
+    name: "Build release distribution and publish to PyPI"
+    runs-on: ubuntu-latest
+    environment:
+      name: pypi
+
+    steps:
+      - uses: actions/checkout@v5
+
+      - name: "Set up Python"
+        uses: actions/setup-python@v5
+        with:
+          python-version-file: "pyproject.toml"
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6
+
+      - name: Install project
+        run: uv sync --all-extras --dev
+        # TODO Better to use --locked for author control over versions?
+        # run: uv sync --locked --all-extras --dev
+
+      - name: Build release distributions
+        run: uv build
+
+      - name: Publish to PyPI
+        env:
+          UV_PUBLISH_TOKEN: ${{ secrets.UV_PUBLISH_TOKEN }}
+        run: uv publish
+
+
+##
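This workflow supersedes the deleted python-publish.yml further down: it keeps only the active `uv` job from that file (checkout, `uv sync`, `uv build`, then `uv publish` on a published release) and drops the commented-out variants. Publishing still authenticates with the `UV_PUBLISH_TOKEN` repository secret rather than PyPI trusted publishing, which the old file carried only as a commented reference.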
atdata-0.1.2a1/.github/workflows/uv-test.yml (new file)
@@ -0,0 +1,40 @@
+#
+
+name: Run tests with `uv`
+
+on:
+  push:
+    branches:
+      - main
+      - release/*
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  uv-test:
+    name: Run tests
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v5
+
+      - name: "Set up Python"
+        uses: actions/setup-python@v5
+        with:
+          python-version-file: "pyproject.toml"
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6
+
+      - name: Install the project
+        run: uv sync --all-extras --dev
+        # TODO Better to use --locked for author control over versions?
+        # run: uv sync --locked --all-extras --dev
+
+      - name: Run tests
+        # For example, using `pytest`
+        run: uv run pytest tests
+
+
+#
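Relative to the deleted python-package.yml further down, this is the same `uv sync` plus `uv run pytest tests` job under a clearer filename, with one behavioral change: pushes to `release/*` branches now also trigger the test run, not just `main`.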
{atdata-0.1.1a2 → atdata-0.1.2a1}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "atdata"
-version = "0.1.1a2"
+version = "0.1.2a1"
 description = "A loose federation of distributed, typed datasets"
 readme = "README.md"
 authors = [
@@ -15,14 +15,14 @@ dependencies = [
 ]
 
 [project.scripts]
-
+atdata = "atdata:main"
 
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"
 
 [tool.pytest.ini_options]
-addopts = "--cov=atdata"
+addopts = "--cov=atdata --cov-report=html"
 
 [dependency-groups]
 dev = [
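Three substantive changes here: the version bump, a console-script entry point (`atdata = "atdata:main"`, so an installed environment gains an `atdata` command that invokes `atdata.main`), and the extra `--cov-report=html` flag, which has `pytest-cov` write an HTML coverage report (under `htmlcov/` by default) in addition to its terminal summary.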
atdata-0.1.2a1/src/atdata/_helpers.py (new file)
@@ -0,0 +1,22 @@
+"""Assorted helper methods for `atdata`"""
+
+##
+# Imports
+
+from io import BytesIO
+
+import numpy as np
+
+
+##
+
+def array_to_bytes( x: np.ndarray ) -> bytes:
+    """Convert `numpy` array to a format suitable for packing"""
+    np_bytes = BytesIO()
+    np.save( np_bytes, x, allow_pickle = True )
+    return np_bytes.getvalue()
+
+def bytes_to_array( b: bytes ) -> np.ndarray:
+    """Convert packed bytes back to a `numpy` array"""
+    np_bytes = BytesIO( b )
+    return np.load( np_bytes, allow_pickle = True )
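The two helpers are exact inverses: `np.save` and `np.load` preserve dtype and shape through the byte round trip, so an array can be packed into a msgpack-friendly `bytes` field and recovered intact. A minimal sketch of the intended use (the `atdata._helpers` import path just follows the file layout above):

import numpy as np

from atdata._helpers import array_to_bytes, bytes_to_array

# Serialize an array into plain bytes suitable for a packed sample field ...
x = np.arange( 6, dtype = np.float32 ).reshape( 2, 3 )
packed = array_to_bytes( x )
assert isinstance( packed, bytes )

# ... and recover it unchanged on the other side.
y = bytes_to_array( packed )
assert y.dtype == x.dtype and np.array_equal( x, y )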
{atdata-0.1.1a2 → atdata-0.1.2a1}/src/atdata/dataset.py
@@ -5,6 +5,7 @@
 
 import webdataset as wds
 
+import functools
 from dataclasses import dataclass
 import uuid
 
@@ -57,38 +58,38 @@ DT = TypeVar( 'DT' )
 
 MsgpackRawSample: TypeAlias = Dict[str, Any]
 
-@dataclass
-class ArrayBytes:
-    """Annotates bytes that should be interpreted as the raw contents of a
-    numpy NDArray"""
+# @dataclass
+# class ArrayBytes:
+#     """Annotates bytes that should be interpreted as the raw contents of a
+#     numpy NDArray"""
 
-    raw_bytes: bytes
-    """The raw bytes of the corresponding NDArray"""
+#     raw_bytes: bytes
+#     """The raw bytes of the corresponding NDArray"""
 
-    def __init__( self,
-            array: Optional[ArrayLike] = None,
-            raw: Optional[bytes] = None,
-            ):
-        """TODO"""
+#     def __init__( self,
+#             array: Optional[ArrayLike] = None,
+#             raw: Optional[bytes] = None,
+#             ):
+#         """TODO"""
 
-        if array is not None:
-            array = np.array( array )
-            self.raw_bytes = eh.array_to_bytes( array )
+#         if array is not None:
+#             array = np.array( array )
+#             self.raw_bytes = eh.array_to_bytes( array )
 
-        elif raw is not None:
-            self.raw_bytes = raw
+#         elif raw is not None:
+#             self.raw_bytes = raw
 
-        else:
-            raise ValueError( 'Must provide either `array` or `raw` bytes' )
+#         else:
+#             raise ValueError( 'Must provide either `array` or `raw` bytes' )
 
-    @property
-    def to_numpy( self ) -> NDArray:
-        """Return the `raw_bytes` data as an NDArray"""
-        return eh.bytes_to_array( self.raw_bytes )
+#     @property
+#     def to_numpy( self ) -> NDArray:
+#         """Return the `raw_bytes` data as an NDArray"""
+#         return eh.bytes_to_array( self.raw_bytes )
 
 def _make_packable( x ):
-    if isinstance( x, ArrayBytes ):
-        return x.raw_bytes
+    # if isinstance( x, ArrayBytes ):
+    #     return x.raw_bytes
     if isinstance( x, np.ndarray ):
         return eh.array_to_bytes( x )
     return x
@@ -114,8 +115,8 @@ class PackableSample( ABC ):
                 # we're good!
                 pass
 
-            elif isinstance( var_cur_value, ArrayBytes ):
-                setattr( self, var_name, var_cur_value.to_numpy )
+            # elif isinstance( var_cur_value, ArrayBytes ):
+            #     setattr( self, var_name, var_cur_value.to_numpy )
 
             elif isinstance( var_cur_value, bytes ):
                 setattr( self, var_name, eh.bytes_to_array( var_cur_value ) )
@@ -172,7 +173,7 @@ def _batch_aggregate( xs: Sequence ):
 
     return list( xs )
 
-class
+class SampleBatch( Generic[DT] ):
 
     def __init__( self, samples: Sequence[DT] ):
         """TODO"""
@@ -233,7 +234,7 @@ class Dataset( Generic[ST] ):
     def batch_type( self ) -> Type:
         """The type of a batch built from `sample_class`"""
         # return self.__orig_class__.__args__[1]
-        return
+        return SampleBatch[self.sample_type]
 
 
     # _schema_registry_sample: dict[str, Type]
@@ -396,7 +397,7 @@ class Dataset( Generic[ST] ):
             value = sample,
         )
 
-    def wrap_batch( self, batch: WDSRawBatch ) ->
+    def wrap_batch( self, batch: WDSRawBatch ) -> SampleBatch[ST]:
         """Wrap a `batch` of samples into the appropriate dataset-specific type
 
         This default implementation simply creates a list one sample at a time
@@ -405,7 +406,7 @@ class Dataset( Generic[ST] ):
         assert 'msgpack' in batch
         batch_unpacked = [ self.sample_type.from_bytes( bs )
                            for bs in batch['msgpack'] ]
-        return
+        return SampleBatch[self.sample_type]( batch_unpacked )
 
 
 # # @classmethod
@@ -415,4 +416,40 @@ class Dataset( Generic[ST] ):
 #     This default implementation simply creates a list one sample at a time
 #     """
 #     assert cls.batch_class is not None, 'No batch class specified'
-#     return cls.batch_class( **batch )
+#     return cls.batch_class( **batch )
+
+
+##
+# Shortcut decorators
+
+# def packable( cls ):
+#     """TODO"""
+
+#     def decorator( cls ):
+#         # Create a new class dynamically
+#         # The new class inherits from the new_parent_class first, then the original cls
+#         new_bases = (PackableSample,) + cls.__bases__
+#         new_cls = type(cls.__name__, new_bases, dict(cls.__dict__))
+
+#         # Optionally, update __module__ and __qualname__ for better introspection
+#         new_cls.__module__ = cls.__module__
+#         new_cls.__qualname__ = cls.__qualname__
+
+#         return new_cls
+#     return decorator
+
+def packable( cls ):
+    """TODO"""
+
+    ##
+
+    as_dataclass = dataclass( cls )
+
+    class as_packable( as_dataclass, PackableSample ):
+        pass
+
+    as_packable.__name__ = cls.__name__
+
+    ##
+
+    return as_packable
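The new `packable` decorator replaces the commented-out `type()`-based draft with plain subclassing: it runs the class through `dataclass` and mixes in `PackableSample`. A minimal sketch of the resulting usage, mirroring the new tests (the `ImageSample` name is illustrative, not part of the package):

import numpy as np
from numpy.typing import NDArray

import atdata

@atdata.packable
class ImageSample:          # hypothetical sample type for illustration
    label: int
    image: NDArray

# The decorated class is a dataclass that also inherits the packing machinery,
# so it can be built from semi-structured data like any PackableSample subclass.
sample = ImageSample.from_data( { 'label': 7, 'image': np.zeros( ( 4, 4 ) ) } )
assert isinstance( sample, atdata.PackableSample )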
atdata-0.1.2a1/tests/test_dataset.py (new file)
@@ -0,0 +1,301 @@
+"""Test dataset functionality."""
+
+##
+# Imports
+
+# Tests
+import pytest
+
+# System
+from dataclasses import dataclass
+
+# External
+import numpy as np
+import webdataset as wds
+
+# Local
+import atdata
+import atdata.dataset as atds
+
+# Typing
+from numpy.typing import NDArray
+from typing import (
+    Type,
+    Any,
+)
+
+
+##
+# Sample test cases
+
+@dataclass
+class BasicTestSample( atdata.PackableSample ):
+    name: str
+    position: int
+    value: float
+
+@dataclass
+class NumpyTestSample( atdata.PackableSample ):
+    label: int
+    image: NDArray
+
+test_cases = [
+    {
+        'SampleType': BasicTestSample,
+        'sample_data': {
+            'name': 'Hello, world!',
+            'position': 42,
+            'value': 1024.768,
+        },
+        'sample_wds_stem': 'basic_test',
+    },
+    {
+        'SampleType': NumpyTestSample,
+        'sample_data':
+        {
+            'label': 9_001,
+            'image': np.random.randn( 1024, 1024 ),
+        },
+        'sample_wds_stem': 'numpy_test',
+    },
+]
+
+
+## Tests
+
+@pytest.mark.parametrize(
+    ('SampleType', 'sample_data'),
+    [ (case['SampleType'], case['sample_data'])
+      for case in test_cases ]
+)
+def test_create_sample(
+        SampleType: Type[atdata.PackableSample],
+        sample_data: atds.MsgpackRawSample,
+        ):
+    """Test our ability to create samples from semi-structured data"""
+
+    sample = SampleType.from_data( sample_data )
+    assert isinstance( sample, SampleType ), \
+        f'Did not properly form sample for test type {SampleType}'
+
+    for k, v in sample_data.items():
+        cur_assertion: bool
+        if isinstance( v, np.ndarray ):
+            cur_assertion = np.all( getattr( sample, k ) == v ) == True
+        else:
+            cur_assertion = getattr( sample, k ) == v
+        assert cur_assertion, \
+            f'Did not properly incorporate property {k} of test type {SampleType}'
+
+#
+
+def test_decorator_syntax():
+    """Test use of decorator syntax for sample types"""
+
+    @atdata.packable
+    class BasicTestSampleDecorated:
+        name: str
+        position: int
+        value: float
+
+    @atdata.packable
+    class NumpyTestSampleDecorated:
+        label: int
+        image: NDArray
+
+    ##
+
+    test_create_sample( BasicTestSampleDecorated, {
+        'name': 'Hello, world!',
+        'position': 42,
+        'value': 1024.768,
+    } )
+
+    test_create_sample( NumpyTestSampleDecorated, {
+        'label': 9_001,
+        'image': np.random.randn( 1024, 1024 ),
+    } )
+
+#
+
+@pytest.mark.parametrize(
+    ('SampleType', 'sample_data', 'sample_wds_stem'),
+    [ (case['SampleType'], case['sample_data'], case['sample_wds_stem'])
+      for case in test_cases ]
+)
+def test_wds(
+        SampleType: Type[atdata.PackableSample],
+        sample_data: atds.MsgpackRawSample,
+        sample_wds_stem: str,
+        tmp_path
+        ):
+    """Test our ability to write samples as `WebDatasets` to disk"""
+
+    ## Testing hyperparameters
+
+    n_copies = 100
+    shard_maxcount = 10
+    batch_size = 4
+    n_iterate = 10
+
+
+    ## Write sharded dataset
+
+    file_pattern = (
+        tmp_path
+        / (f'{sample_wds_stem}' + '-{shard_id}.tar')
+    ).as_posix()
+    file_wds_pattern = file_pattern.format( shard_id = '%06d' )
+
+    with wds.ShardWriter(
+            pattern = file_wds_pattern,
+            maxcount = shard_maxcount,
+            ) as sink:
+
+        for i_sample in range( n_copies ):
+            new_sample = SampleType.from_data( sample_data )
+            assert isinstance( new_sample, SampleType ), \
+                f'Did not properly form sample for test type {SampleType}'
+
+            sink.write( new_sample.as_wds )
+
+
+    ## Ordered
+
+    # Read first shard, no batches
+
+    first_filename = file_pattern.format( shard_id = f'{0:06d}' )
+    dataset = atdata.Dataset[SampleType]( first_filename )
+
+    iterations_run = 0
+    for i_iterate, cur_sample in enumerate( dataset.ordered( batch_size = None ) ):
+
+        assert isinstance( cur_sample, SampleType ), \
+            f'Single sample for {SampleType} written to `wds` is of wrong type'
+
+        # Check sample values
+
+        for k, v in sample_data.items():
+            if isinstance( v, np.ndarray ):
+                is_correct = np.all( getattr( cur_sample, k ) == v )
+            else:
+                is_correct = getattr( cur_sample, k ) == v
+            assert is_correct, \
+                f'{SampleType}: Incorrect sample value found for {k}'
+
+        iterations_run += 1
+        if iterations_run >= n_iterate:
+            break
+
+    assert iterations_run == n_iterate, \
+        f"Only found {iterations_run} samples, not {n_iterate}"
+
+    # Read all shards, batches
+
+    start_id = f'{0:06d}'
+    end_id = f'{9:06d}'
+    first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' )
+    print( first_filename )
+    dataset = atdata.Dataset[SampleType]( first_filename )
+
+    iterations_run = 0
+    for i_iterate, cur_batch in enumerate( dataset.ordered( batch_size = batch_size ) ):
+
+        assert isinstance( cur_batch, atdata.SampleBatch ), \
+            f'{SampleType}: Batch sample is not correctly a batch'
+
+        assert cur_batch.sample_type == SampleType, \
+            f'{SampleType}: Batch `sample_type` is incorrect type'
+
+        if i_iterate == 0:
+            cur_n = len( cur_batch.samples )
+            assert cur_n == batch_size, \
+                f'{SampleType}: Batch has {cur_n} samples, not {batch_size}'
+
+            assert isinstance( cur_batch.samples[0], SampleType ), \
+                f'{SampleType}: Batch sample of wrong type ({type( cur_batch.samples[0])})'
+
+            # Check batch values
+            for k, v in sample_data.items():
+                cur_batch_data = getattr( cur_batch, k )
+
+                if isinstance( v, np.ndarray ):
+                    assert isinstance( cur_batch_data, np.ndarray ), \
+                        f'{SampleType}: `NDArray` not carried through to batch'
+
+                    is_correct = all(
+                        [ np.all( cur_batch_data[i] == v )
+                          for i in range( cur_batch_data.shape[0] ) ]
+                    )
+
+                else:
+                    is_correct = all(
+                        [ cur_batch_data[i] == v
+                          for i in range( len( cur_batch_data ) ) ]
+                    )
+
+                assert is_correct, \
+                    f'{SampleType}: Incorrect sample value found for {k}'
+
+        iterations_run += 1
+        if iterations_run >= n_iterate:
+            break
+
+    assert iterations_run == n_iterate, \
+        f"Only found {iterations_run} samples, not {n_iterate}"
+
+
+    ## Shuffled
+
+    # Read first shard, no batches
+
+    first_filename = file_pattern.format( shard_id = f'{0:06d}' )
+    dataset = atdata.Dataset[SampleType]( first_filename )
+
+    iterations_run = 0
+    for i_iterate, cur_sample in enumerate( dataset.shuffled( batch_size = None ) ):
+
+        assert isinstance( cur_sample, SampleType ), \
+            f'Single sample for {SampleType} written to `wds` is of wrong type'
+
+        iterations_run += 1
+        if iterations_run >= n_iterate:
+            break
+
+    assert iterations_run == n_iterate, \
+        f"Only found {iterations_run} samples, not {n_iterate}"
+
+    # Read all shards, batches
+
+    start_id = f'{0:06d}'
+    end_id = f'{9:06d}'
+    first_filename = file_pattern.format( shard_id = '{' + start_id + '..' + end_id + '}' )
+    print( first_filename )
+    dataset = atdata.Dataset[SampleType]( first_filename )
+
+    iterations_run = 0
+    for i_iterate, cur_sample in enumerate( dataset.shuffled( batch_size = batch_size ) ):
+
+        assert isinstance( cur_sample, atdata.SampleBatch ), \
+            f'{SampleType}: Batch sample is not correctly a batch'
+
+        assert cur_sample.sample_type == SampleType, \
+            f'{SampleType}: Batch `sample_type` is incorrect type'
+
+        if i_iterate == 0:
+            cur_n = len( cur_sample.samples )
+            assert cur_n == batch_size, \
+                f'{SampleType}: Batch has {cur_n} samples, not {batch_size}'
+
+            assert isinstance( cur_sample.samples[0], SampleType ), \
+                f'{SampleType}: Batch sample of wrong type ({type( cur_sample.samples[0])})'
+
+        iterations_run += 1
+        if iterations_run >= n_iterate:
+            break
+
+    assert iterations_run == n_iterate, \
+        f"Only found {iterations_run} samples, not {n_iterate}"
+
+
+##
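Condensed, the shard round trip these tests exercise looks like the sketch below; `DemoSample` and the output path are illustrative, while `ShardWriter`, `as_wds`, `Dataset[...]`, `ordered`, and `SampleBatch` are all taken from the code above. The tests bound their read loops with an explicit break, and the sketch does the same:

import numpy as np
from numpy.typing import NDArray
import webdataset as wds

import atdata

@atdata.packable
class DemoSample:           # hypothetical sample type for illustration
    label: int
    image: NDArray

# Write ten packed samples into a single shard on disk ...
with wds.ShardWriter( pattern = 'demo-%06d.tar', maxcount = 10 ) as sink:
    for i in range( 10 ):
        sample = DemoSample.from_data( { 'label': i, 'image': np.zeros( ( 8, 8 ) ) } )
        sink.write( sample.as_wds )

# ... then stream them back, four samples to a batch.
dataset = atdata.Dataset[DemoSample]( 'demo-000000.tar' )
for batch in dataset.ordered( batch_size = 4 ):
    assert isinstance( batch, atdata.SampleBatch )
    assert batch.sample_type == DemoSample
    break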
atdata-0.1.1a2/.github/workflows/python-package.yml (deleted)
@@ -1,66 +0,0 @@
-# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
-
-name: Python package
-
-on:
-  push:
-    branches: [ "main" ]
-  pull_request:
-    branches: [ "main" ]
-
-jobs:
-  uv-test:
-    name: python
-    runs-on: ubuntu-latest
-
-    steps:
-      - uses: actions/checkout@v5
-
-      - name: "Set up Python"
-        uses: actions/setup-python@v5
-        with:
-          python-version-file: "pyproject.toml"
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@v6
-
-      - name: Install the project
-        run: uv sync --all-extras --dev
-        # TODO Better to use --locked for author control over versions?
-        # run: uv sync --locked --all-extras --dev
-
-      - name: Run tests
-        # For example, using `pytest`
-        run: uv run pytest tests
-
-# OLD - kept for legacy
-# jobs:
-#   build:
-
-#     runs-on: ubuntu-latest
-#     strategy:
-#       fail-fast: false
-#       matrix:
-#         python-version: ["3.12", "3.13"]
-
-#     steps:
-#     - uses: actions/checkout@v4
-#     - name: Set up Python ${{ matrix.python-version }}
-#       uses: actions/setup-python@v3
-#       with:
-#         python-version: ${{ matrix.python-version }}
-#     - name: Install dependencies
-#       run: |
-#         python -m pip install --upgrade pip
-#         python -m pip install flake8 pytest
-#         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
-#     - name: Lint with flake8
-#       run: |
-#         # stop the build if there are Python syntax errors or undefined names
-#         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-#         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-#         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-#     - name: Test with pytest
-#       run: |
-#         pytest
atdata-0.1.1a2/.github/workflows/python-publish.yml (deleted)
@@ -1,129 +0,0 @@
-# This workflow will upload a Python Package to PyPI when a release is created
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
-
-# This workflow uses actions that are not certified by GitHub.
-# They are provided by a third-party and are governed by
-# separate terms of service, privacy policy, and support
-# documentation.
-
-name: Build and upload package to PyPI
-
-on:
-  release:
-    types: [published]
-
-permissions:
-  contents: read
-
-
-jobs:
-
-  # uv-release-build-publish:
-  #   # name: python
-  #   runs-on: ubuntu-latest
-
-  #   steps:
-  #     - uses: actions/checkout@v5
-
-  #     - name: "Set up Python"
-  #       uses: actions/setup-python@v5
-  #       with:
-  #         python-version-file: "pyproject.toml"
-
-  #     - name: Install uv
-  #       uses: astral-sh/setup-uv@v6
-
-  #     - name: Install the project
-  #       run: uv sync --all-extras --dev
-  #       # TODO Better to use --locked for author control over versions?
-  #       # run: uv sync --locked --all-extras --dev
-
-  #     - name: Build release distributions
-  #       run: uv build
-
-  uv-build-release-pypi-publish:
-    name: "Build release distribution and publish to PyPI"
-    runs-on: ubuntu-latest
-    # needs:
-    #   - uv-release-build
-    environment:
-      name: pypi
-
-    steps:
-      - uses: actions/checkout@v5
-
-      - name: "Set up Python"
-        uses: actions/setup-python@v5
-        with:
-          python-version-file: "pyproject.toml"
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@v6
-
-      - name: Install the project
-        run: uv sync --all-extras --dev
-        # TODO Better to use --locked for author control over versions?
-        # run: uv sync --locked --all-extras --dev
-
-      - name: Build release distributions
-        run: uv build
-
-      - name: Publish to PyPI
-        env:
-          UV_PUBLISH_TOKEN: ${{ secrets.UV_PUBLISH_TOKEN }}
-        run: uv publish
-
-# TODO Original variant kept for reference
-# jobs:
-#   release-build:
-#     runs-on: ubuntu-latest
-
-#     steps:
-#       - uses: actions/checkout@v4
-
-#       - uses: actions/setup-python@v5
-#         with:
-#           python-version: "3.x"
-
-#       - name: Build release distributions
-#         run: |
-#           # NOTE: put your own distribution build steps here.
-#           python -m pip install build
-#           python -m build
-
-#       - name: Upload distributions
-#         uses: actions/upload-artifact@v4
-#         with:
-#           name: release-dists
-#           path: dist/
-
-#   pypi-publish:
-#     runs-on: ubuntu-latest
-#     needs:
-#       - release-build
-#     permissions:
-#       # IMPORTANT: this permission is mandatory for trusted publishing
-#       id-token: write
-
-#     # Dedicated environments with protections for publishing are strongly recommended.
-#     # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules
-#     environment:
-#       name: pypi
-#       # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
-#       # url: https://pypi.org/p/YOURPROJECT
-#       #
-#       # ALTERNATIVE: if your GitHub Release name is the PyPI project version string
-#       # ALTERNATIVE: exactly, uncomment the following line instead:
-#       # url: https://pypi.org/project/YOURPROJECT/${{ github.event.release.name }}
-
-#     steps:
-#       - name: Retrieve release distributions
-#         uses: actions/download-artifact@v4
-#         with:
-#           name: release-dists
-#           path: dist/
-
-#       - name: Publish release distributions to PyPI
-#         uses: pypa/gh-action-pypi-publish@release/v1
-#         with:
-#           packages-dir: dist/
atdata-0.1.1a2/src/atdata/_helpers.py (deleted)
@@ -1,30 +0,0 @@
-"""Assorted helper methods for `ekumen`"""
-
-##
-# Imports
-
-from io import BytesIO
-import ormsgpack as omp
-
-import numpy as np
-
-
-##
-#
-
-def pack_instance( x ) -> bytes:
-    return omp.packb( x )
-
-def unpack( bs: bytes ):
-    return omp.unpackb( bs )
-
-##
-
-def array_to_bytes(x: np.ndarray) -> bytes:
-    np_bytes = BytesIO()
-    np.save(np_bytes, x, allow_pickle=True)
-    return np_bytes.getvalue()
-
-def bytes_to_array(b: bytes) -> np.ndarray:
-    np_bytes = BytesIO(b)
-    return np.load(np_bytes, allow_pickle=True)
atdata-0.1.1a2/tests/test_dataset.py (deleted)
@@ -1,69 +0,0 @@
-"""Test dataset functionality."""
-
-##
-
-import pytest
-
-from dataclasses import dataclass
-
-import numpy as np
-
-from numpy.typing import NDArray
-from typing import (
-    Type,
-    Any,
-)
-
-import atdata.dataset as ekd
-
-
-## Sample test cases
-
-@dataclass
-class BasicTestSample( ekd.PackableSample ):
-    name: str
-    position: int
-    value: float
-
-@dataclass
-class NumpyTestSample( ekd.PackableSample ):
-    label: int
-    image: NDArray
-
-test_sample_classes = [
-    (
-        BasicTestSample, {
-            'name': 'Hello, world!',
-            'position': 42,
-            'value': 1024.768,
-        }
-    ),
-    (
-        NumpyTestSample, {
-            'label': 9_001,
-            'image': np.random.randn( 1024, 1024 ),
-        }
-    )
-]
-
-
-## Tests
-
-@pytest.mark.parametrize( ('SampleType', 'sample_data'), test_sample_classes )
-def test_create_sample(
-        SampleType: Type[ekd.PackableSample],
-        sample_data: ekd.MsgpackRawSample,
-        ):
-    """
-    Test our ability to create samples from semi-structured data
-    """
-    sample = SampleType.from_data( sample_data )
-    assert isinstance( sample, SampleType ), f'Did not properly form sample for test type {SampleType}'
-
-    for k, v in sample_data.items():
-        cur_assertion: bool
-        if isinstance( v, np.ndarray ):
-            cur_assertion = np.all( getattr( sample, k ) == v ) == True
-        else:
-            cur_assertion = getattr( sample, k ) == v
-        assert cur_assertion, f'Did not properly incorporate property {k} of test type {SampleType}'
{atdata-0.1.1a2 → atdata-0.1.2a1}/.python-version (file without changes)
{atdata-0.1.1a2 → atdata-0.1.2a1}/LICENSE (file without changes)
{atdata-0.1.1a2 → atdata-0.1.2a1}/README.md (file without changes)