cyreal 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cyreal-0.1.1/PKG-INFO +179 -0
- cyreal-0.1.1/README.md +156 -0
- cyreal-0.1.1/cyreal/__init__.py +38 -0
- cyreal-0.1.1/cyreal/dataset_protocol.py +14 -0
- cyreal-0.1.1/cyreal/datasets/__init__.py +14 -0
- cyreal-0.1.1/cyreal/datasets/cifar10.py +194 -0
- cyreal-0.1.1/cyreal/datasets/mnist.py +219 -0
- cyreal-0.1.1/cyreal/loader.py +145 -0
- cyreal-0.1.1/cyreal/sources.py +496 -0
- cyreal-0.1.1/cyreal/transforms.py +497 -0
- cyreal-0.1.1/cyreal.egg-info/PKG-INFO +179 -0
- cyreal-0.1.1/cyreal.egg-info/SOURCES.txt +17 -0
- cyreal-0.1.1/cyreal.egg-info/dependency_links.txt +1 -0
- cyreal-0.1.1/cyreal.egg-info/requires.txt +6 -0
- cyreal-0.1.1/cyreal.egg-info/top_level.txt +1 -0
- cyreal-0.1.1/setup.cfg +4 -0
- cyreal-0.1.1/setup.py +32 -0
- cyreal-0.1.1/tests/test_dataloader.py +625 -0
- cyreal-0.1.1/tests/test_readme_examples.py +180 -0
cyreal-0.1.1/PKG-INFO
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cyreal
|
|
3
|
+
Version: 0.1.1
|
|
4
|
+
Summary: Jittable data loading utilities for JAX.
|
|
5
|
+
Author:
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Requires-Python: >=3.10
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
Requires-Dist: jax
|
|
12
|
+
Requires-Dist: jaxlib
|
|
13
|
+
Requires-Dist: numpy
|
|
14
|
+
Provides-Extra: dev
|
|
15
|
+
Requires-Dist: pytest; extra == "dev"
|
|
16
|
+
Dynamic: classifier
|
|
17
|
+
Dynamic: description
|
|
18
|
+
Dynamic: description-content-type
|
|
19
|
+
Dynamic: provides-extra
|
|
20
|
+
Dynamic: requires-dist
|
|
21
|
+
Dynamic: requires-python
|
|
22
|
+
Dynamic: summary
|
|
23
|
+
|
|
24
|
+
# Cyreal - Another JAX DataLoader
|
|
25
|
+
|
|
26
|
+
> `grain` for the corporations, `cyreal` for the people
|
|
27
|
+
|
|
28
|
+
Pure `jax` utilities for iterating over finite datasets without ever touching `torch` or `tensorflow`. Dataloaders support `jax.jit`, `jax.grad`, `jax.lax.scan`, and other function transformations.
|
|
29
|
+
|
|
30
|
+
## Installation
|
|
31
|
+
|
|
32
|
+
The only runtime dependencies are `jax` (with `jaxlib`) and `numpy`. On GPU machines, install the
|
|
33
|
+
appropriate JAX build for your CUDA version.
|
|
34
|
+
|
|
35
|
+
`pip install cyreal`
|
|
36
|
+
|
|
37
|
+
## Quick start with MNIST
|
|
38
|
+
Write `torch`-style dataloaders without `torch`
|
|
39
|
+
|
|
40
|
+
```python
|
|
41
|
+
import jax
|
|
42
|
+
import jax.numpy as jnp
|
|
43
|
+
|
|
44
|
+
from cyreal import (
|
|
45
|
+
ArraySampleSource,
|
|
46
|
+
BatchTransform,
|
|
47
|
+
DataLoader,
|
|
48
|
+
DevicePutTransform,
|
|
49
|
+
MNISTDataset,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
train_data = MNISTDataset(split="train").as_array_dict()
|
|
53
|
+
pipeline = [
|
|
54
|
+
ArraySampleSource(train_data, ordering="shuffle"),
|
|
55
|
+
BatchTransform(batch_size=128),
|
|
56
|
+
DevicePutTransform(),
|
|
57
|
+
]
|
|
58
|
+
loader = DataLoader(pipeline=pipeline)
|
|
59
|
+
state = loader.init_state(jax.random.PRNGKey(0))
|
|
60
|
+
|
|
61
|
+
for batch, mask in loader.iterate(state):
|
|
62
|
+
... # train your network!
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
## Scan and Avoid Boilerplate
|
|
66
|
+
|
|
67
|
+
`DataLoader.scan_epoch` rolls a full pass through the dataset into a single
|
|
68
|
+
`jax.lax.scan` to minimize dispatch overhead. This will `jit` the `body_fn`.
|
|
69
|
+
|
|
70
|
+
```python
|
|
71
|
+
def body_fn(model_state, batch, mask):
|
|
72
|
+
model_state = update_model(model_state, batch, mask)
|
|
73
|
+
return model_state, None
|
|
74
|
+
|
|
75
|
+
loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
## JIT Capabilities
|
|
79
|
+
Do you enjoy premature optimization? Why not `jit` the entire train epoch?
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
```python
|
|
83
|
+
import jax
|
|
84
|
+
import jax.numpy as jnp
|
|
85
|
+
|
|
86
|
+
from cyreal import (
|
|
87
|
+
ArraySampleSource,
|
|
88
|
+
BatchTransform,
|
|
89
|
+
DataLoader,
|
|
90
|
+
DevicePutTransform,
|
|
91
|
+
MNISTDataset,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
train_data = MNISTDataset(split="train").as_array_dict()
|
|
95
|
+
pipeline = [
|
|
96
|
+
ArraySampleSource(train_data, ordering="shuffle"),
|
|
97
|
+
BatchTransform(batch_size=128),
|
|
98
|
+
DevicePutTransform(),
|
|
99
|
+
]
|
|
100
|
+
loader = DataLoader(pipeline)
|
|
101
|
+
loader_state = loader.init_state(jax.random.PRNGKey(0))
|
|
102
|
+
model_state = model_init()
|
|
103
|
+
|
|
104
|
+
@jax.jit
|
|
105
|
+
def train_epoch(model_state, loader_state):
|
|
106
|
+
def body_fn(model_state, batch, mask):
|
|
107
|
+
# Update the network using your train fn
|
|
108
|
+
new_model_state = model_update(model_state, batch, mask)
|
|
109
|
+
return new_model_state, None
|
|
110
|
+
|
|
111
|
+
loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
|
|
112
|
+
return model_state, loader_state
|
|
113
|
+
|
|
114
|
+
model_state, loader_state = train_epoch(model_state, loader_state)
|
|
115
|
+
```
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
## Streaming from Disk
|
|
119
|
+
Is your dataset enormous? Swap in a disk-backed source.
|
|
120
|
+
|
|
121
|
+
```python
|
|
122
|
+
import jax
|
|
123
|
+
|
|
124
|
+
from cyreal import (
|
|
125
|
+
BatchTransform,
|
|
126
|
+
DataLoader,
|
|
127
|
+
DevicePutTransform,
|
|
128
|
+
MNISTDataset,
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
pipeline = [
|
|
132
|
+
MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
|
|
133
|
+
BatchTransform(batch_size=128),
|
|
134
|
+
DevicePutTransform(),
|
|
135
|
+
]
|
|
136
|
+
|
|
137
|
+
loader = DataLoader(pipeline=pipeline)
|
|
138
|
+
state = loader.init_state(jax.random.PRNGKey(0))
|
|
139
|
+
|
|
140
|
+
for batch, mask in loader.iterate(state):
|
|
141
|
+
... # stream without holding the dataset in RAM
|
|
142
|
+
```
|
|
143
|
+
|
|
144
|
+
## For the Dirty and Impure
|
|
145
|
+
Want to `jit` but also log some metrics? Use `HostCallbackTransform` which utilizes `jax.experimental.io_callback` under the hood.
|
|
146
|
+
|
|
147
|
+
```python
|
|
148
|
+
import jax.numpy as jnp
|
|
149
|
+
import numpy as np
|
|
150
|
+
|
|
151
|
+
from cyreal import (
|
|
152
|
+
ArraySampleSource,
|
|
153
|
+
BatchTransform,
|
|
154
|
+
DataLoader,
|
|
155
|
+
HostCallbackTransform,
|
|
156
|
+
MNISTDataset,
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
def model(images):
|
|
160
|
+
return jnp.mean(images.astype(jnp.float32), axis=(1, 2, 3))
|
|
161
|
+
|
|
162
|
+
def cross_entropy(logits, labels):
|
|
163
|
+
labels = labels.astype(jnp.float32)
|
|
164
|
+
return (logits - labels) ** 2
|
|
165
|
+
|
|
166
|
+
def log_loss(batch, mask):
|
|
167
|
+
logits = model(batch["image"])
|
|
168
|
+
loss = jnp.mean(cross_entropy(logits, batch["label"]) * mask[:, None])
|
|
169
|
+
print("loss:", float(np.asarray(loss)))
|
|
170
|
+
return batch
|
|
171
|
+
|
|
172
|
+
loader = DataLoader(
|
|
173
|
+
pipeline=[
|
|
174
|
+
ArraySampleSource(MNISTDataset(split="train").as_array_dict(), ordering="shuffle"),
|
|
175
|
+
BatchTransform(batch_size=128),
|
|
176
|
+
HostCallbackTransform(fn=log_loss),
|
|
177
|
+
],
|
|
178
|
+
)
|
|
179
|
+
```
|
cyreal-0.1.1/README.md
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
# Cyreal - Another JAX DataLoader
|
|
2
|
+
|
|
3
|
+
> `grain` for the corporations, `cyreal` for the people
|
|
4
|
+
|
|
5
|
+
Pure `jax` utilities for iterating over finite datasets without ever touching `torch` or `tensorflow`. Dataloaders support `jax.jit`, `jax.grad`, `jax.lax.scan`, and other function transformations.
|
|
6
|
+
|
|
7
|
+
## Installation
|
|
8
|
+
|
|
9
|
+
The only runtime dependencies are `jax` (with `jaxlib`) and `numpy`. On GPU machines, install the
|
|
10
|
+
appropriate JAX build for your CUDA version.
|
|
11
|
+
|
|
12
|
+
`pip install cyreal`
|
|
13
|
+
|
|
14
|
+
## Quick start with MNIST
|
|
15
|
+
Write `torch`-style dataloaders without `torch`
|
|
16
|
+
|
|
17
|
+
```python
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
|
|
21
|
+
from cyreal import (
|
|
22
|
+
ArraySampleSource,
|
|
23
|
+
BatchTransform,
|
|
24
|
+
DataLoader,
|
|
25
|
+
DevicePutTransform,
|
|
26
|
+
MNISTDataset,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
train_data = MNISTDataset(split="train").as_array_dict()
|
|
30
|
+
pipeline = [
|
|
31
|
+
ArraySampleSource(train_data, ordering="shuffle"),
|
|
32
|
+
BatchTransform(batch_size=128),
|
|
33
|
+
DevicePutTransform(),
|
|
34
|
+
]
|
|
35
|
+
loader = DataLoader(pipeline=pipeline)
|
|
36
|
+
state = loader.init_state(jax.random.PRNGKey(0))
|
|
37
|
+
|
|
38
|
+
for batch, mask in loader.iterate(state):
|
|
39
|
+
... # train your network!
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
## Scan and Avoid Boilerplate
|
|
43
|
+
|
|
44
|
+
`DataLoader.scan_epoch` rolls a full pass through the dataset into a single
|
|
45
|
+
`jax.lax.scan` to minimize dispatch overhead. This will `jit` the `body_fn`.
|
|
46
|
+
|
|
47
|
+
```python
|
|
48
|
+
def body_fn(model_state, batch, mask):
|
|
49
|
+
model_state = update_model(model_state, batch, mask)
|
|
50
|
+
return model_state, None
|
|
51
|
+
|
|
52
|
+
loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## JIT Capabilities
|
|
56
|
+
Do you enjoy premature optimization? Why not `jit` the entire train epoch?
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
import jax
|
|
61
|
+
import jax.numpy as jnp
|
|
62
|
+
|
|
63
|
+
from cyreal import (
|
|
64
|
+
ArraySampleSource,
|
|
65
|
+
BatchTransform,
|
|
66
|
+
DataLoader,
|
|
67
|
+
DevicePutTransform,
|
|
68
|
+
MNISTDataset,
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
train_data = MNISTDataset(split="train").as_array_dict()
|
|
72
|
+
pipeline = [
|
|
73
|
+
ArraySampleSource(train_data, ordering="shuffle"),
|
|
74
|
+
BatchTransform(batch_size=128),
|
|
75
|
+
DevicePutTransform(),
|
|
76
|
+
]
|
|
77
|
+
loader = DataLoader(pipeline)
|
|
78
|
+
loader_state = loader.init_state(jax.random.PRNGKey(0))
|
|
79
|
+
model_state = model_init()
|
|
80
|
+
|
|
81
|
+
@jax.jit
|
|
82
|
+
def train_epoch(model_state, loader_state):
|
|
83
|
+
def body_fn(model_state, batch, mask):
|
|
84
|
+
# Update the network using your train fn
|
|
85
|
+
new_model_state = model_update(model_state, batch, mask)
|
|
86
|
+
return new_model_state, None
|
|
87
|
+
|
|
88
|
+
loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
|
|
89
|
+
return model_state, loader_state
|
|
90
|
+
|
|
91
|
+
model_state, loader_state = train_epoch(model_state, loader_state)
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
## Streaming from Disk
|
|
96
|
+
Is your dataset enormous? Swap in a disk-backed source.
|
|
97
|
+
|
|
98
|
+
```python
|
|
99
|
+
import jax
|
|
100
|
+
|
|
101
|
+
from cyreal import (
|
|
102
|
+
BatchTransform,
|
|
103
|
+
DataLoader,
|
|
104
|
+
DevicePutTransform,
|
|
105
|
+
MNISTDataset,
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
pipeline = [
|
|
109
|
+
MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
|
|
110
|
+
BatchTransform(batch_size=128),
|
|
111
|
+
DevicePutTransform(),
|
|
112
|
+
]
|
|
113
|
+
|
|
114
|
+
loader = DataLoader(pipeline=pipeline)
|
|
115
|
+
state = loader.init_state(jax.random.PRNGKey(0))
|
|
116
|
+
|
|
117
|
+
for batch, mask in loader.iterate(state):
|
|
118
|
+
... # stream without holding the dataset in RAM
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
## For the Dirty and Impure
|
|
122
|
+
Want to `jit` but also log some metrics? Use `HostCallbackTransform` which utilizes `jax.experimental.io_callback` under the hood.
|
|
123
|
+
|
|
124
|
+
```python
|
|
125
|
+
import jax.numpy as jnp
|
|
126
|
+
import numpy as np
|
|
127
|
+
|
|
128
|
+
from cyreal import (
|
|
129
|
+
ArraySampleSource,
|
|
130
|
+
BatchTransform,
|
|
131
|
+
DataLoader,
|
|
132
|
+
HostCallbackTransform,
|
|
133
|
+
MNISTDataset,
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
def model(images):
|
|
137
|
+
return jnp.mean(images.astype(jnp.float32), axis=(1, 2, 3))
|
|
138
|
+
|
|
139
|
+
def cross_entropy(logits, labels):
|
|
140
|
+
labels = labels.astype(jnp.float32)
|
|
141
|
+
return (logits - labels) ** 2
|
|
142
|
+
|
|
143
|
+
def log_loss(batch, mask):
|
|
144
|
+
logits = model(batch["image"])
|
|
145
|
+
loss = jnp.mean(cross_entropy(logits, batch["label"]) * mask[:, None])
|
|
146
|
+
print("loss:", float(np.asarray(loss)))
|
|
147
|
+
return batch
|
|
148
|
+
|
|
149
|
+
loader = DataLoader(
|
|
150
|
+
pipeline=[
|
|
151
|
+
ArraySampleSource(MNISTDataset(split="train").as_array_dict(), ordering="shuffle"),
|
|
152
|
+
BatchTransform(batch_size=128),
|
|
153
|
+
HostCallbackTransform(fn=log_loss),
|
|
154
|
+
],
|
|
155
|
+
)
|
|
156
|
+
```
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Jittable dataset utilities for JAX."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from .dataset_protocol import DatasetProtocol
|
|
5
|
+
from .datasets import CIFAR10Dataset, CIFAR10DiskSource, MNISTDataset, MNISTDiskSource
|
|
6
|
+
from .loader import (
|
|
7
|
+
DataLoader,
|
|
8
|
+
LoaderState,
|
|
9
|
+
)
|
|
10
|
+
from .sources import ArraySampleSource, DiskSampleSource, GymnaxSource, Source
|
|
11
|
+
from .transforms import (
|
|
12
|
+
BatchTransform,
|
|
13
|
+
DevicePutTransform,
|
|
14
|
+
FlattenTransform,
|
|
15
|
+
HostCallbackTransform,
|
|
16
|
+
MapTransform,
|
|
17
|
+
NormalizeImageTransform,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"DatasetProtocol",
|
|
22
|
+
"DataLoader",
|
|
23
|
+
"LoaderState",
|
|
24
|
+
"CIFAR10Dataset",
|
|
25
|
+
"CIFAR10DiskSource",
|
|
26
|
+
"MNISTDataset",
|
|
27
|
+
"MNISTDiskSource",
|
|
28
|
+
"ArraySampleSource",
|
|
29
|
+
"DiskSampleSource",
|
|
30
|
+
"GymnaxSource",
|
|
31
|
+
"Source",
|
|
32
|
+
"BatchTransform",
|
|
33
|
+
"DevicePutTransform",
|
|
34
|
+
"FlattenTransform",
|
|
35
|
+
"HostCallbackTransform",
|
|
36
|
+
"MapTransform",
|
|
37
|
+
"NormalizeImageTransform",
|
|
38
|
+
]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Simple dataset protocol used by the jittable dataloader."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any, Protocol
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DatasetProtocol(Protocol):
    """Structural type for datasets with a known length and integer indexing.

    Any object providing ``__len__`` and ``__getitem__`` satisfies this
    protocol; inheriting from it is optional.
    """

    def __len__(self) -> int:
        """Return the number of items in the dataset."""
        ...

    def __getitem__(self, index: int) -> Any:
        """Return the item stored at position ``index``."""
        ...
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Dataset helpers bundled with cyreal."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from .cifar10 import CIFAR10Dataset, CIFAR10DiskSource
|
|
5
|
+
from .mnist import MNISTDataset, MNISTDiskSource, MNIST_URLS, _ensure_file
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"CIFAR10Dataset",
|
|
9
|
+
"CIFAR10DiskSource",
|
|
10
|
+
"MNISTDataset",
|
|
11
|
+
"MNISTDiskSource",
|
|
12
|
+
"MNIST_URLS",
|
|
13
|
+
"_ensure_file",
|
|
14
|
+
]
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""CIFAR-10 dataset that stays within NumPy/JAX dependencies."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import pickle
|
|
5
|
+
import tarfile
|
|
6
|
+
import urllib.request
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Literal
|
|
10
|
+
|
|
11
|
+
import jax
|
|
12
|
+
import jax.numpy as jnp
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
from ..dataset_protocol import DatasetProtocol
|
|
16
|
+
from ..sources import DiskSampleSource
|
|
17
|
+
|
|
18
|
+
CIFAR10_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _to_host_jax_array(array: np.ndarray) -> jax.Array:
    """Convert ``array`` to a JAX array, created under the first CPU device.

    Args:
        array: NumPy data to wrap.

    Returns:
        The result of ``jnp.asarray(array)``. When ``jax.devices("cpu")``
        lists at least one device, the conversion runs with that device set
        as the default; otherwise JAX's default placement applies.
    """
    host_devices = jax.devices("cpu")
    if not host_devices:
        # No CPU device reported: fall back to whatever JAX picks by default.
        return jnp.asarray(array)
    with jax.default_device(host_devices[0]):
        return jnp.asarray(array)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _resolve_cache_dir(cache_dir: str | Path | None) -> Path:
    """Return the cache directory as a ``Path``, creating it if necessary.

    Args:
        cache_dir: Explicit location, or ``None`` to use
            ``~/.cache/cyreal_cifar10``.

    Returns:
        The directory path; it exists (including parents) on return.
    """
    if cache_dir is None:
        resolved = Path.home() / ".cache" / "cyreal_cifar10"
    else:
        resolved = Path(cache_dir)
    resolved.mkdir(parents=True, exist_ok=True)
    return resolved
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _download(url: str, path: Path) -> Path:
    """Fetch ``url`` into ``path`` unless ``path`` already exists.

    The payload is first written to a sibling ``<name>.part`` file and only
    renamed to ``path`` once the transfer has completed. An interrupted or
    failed download therefore never leaves a truncated file at ``path`` that
    a later call would mistake for a valid cached copy.

    Args:
        url: Location to download from.
        path: Final destination; parent directories are created as needed.

    Returns:
        ``path``.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises on failure (for
        example ``urllib.error.URLError``); ``path`` is not created then.
    """
    if path.exists():
        return path
    path.parent.mkdir(parents=True, exist_ok=True)
    partial = path.with_name(path.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial)
        # Same-directory rename, so the final file appears all at once.
        partial.replace(path)
    finally:
        # After a successful rename this is a no-op; after a failure it
        # removes the incomplete payload.
        partial.unlink(missing_ok=True)
    return path
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _ensure_extracted(archive: Path, extract_root: Path) -> Path:
    """Extract the CIFAR-10 archive under ``extract_root`` if not done yet.

    Args:
        archive: Path to the gzip-compressed tarball.
        extract_root: Directory that receives the archive contents.

    Returns:
        ``extract_root / "cifar-10-batches-py"``. If that path already exists
        the archive is not opened at all.

    Raises:
        tarfile.TarError: If the archive is unreadable or, when extraction
            filters are available, a member is rejected by the ``"data"``
            filter (e.g. it would be written outside ``extract_root``).
    """
    target = extract_root / "cifar-10-batches-py"
    if target.exists():
        return target
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # The archive is downloaded content: refuse absolute paths,
            # ``..`` escapes, device nodes and similar hostile members.
            tar.extractall(path=extract_root, filter="data")
        else:
            # Interpreter predates extraction filters; keep legacy behaviour.
            tar.extractall(path=extract_root)
    return target
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _batch_names(split: Literal["train", "test"]) -> list[str]:
    """List the batch file names that make up ``split``.

    Args:
        split: ``"train"`` or ``"test"``.

    Returns:
        ``data_batch_1`` .. ``data_batch_5`` for the training split, or the
        single ``test_batch`` for the test split.

    Raises:
        ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
    """
    if split == "test":
        return ["test_batch"]
    if split != "train":
        raise ValueError("split must be 'train' or 'test'.")
    return ["data_batch_" + str(number) for number in range(1, 6)]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _load_split_numpy(split: Literal["train", "test"], extract_dir: Path) -> tuple[np.ndarray, np.ndarray]:
    """Read every batch file of ``split`` and stack it into two NumPy arrays.

    Args:
        split: ``"train"`` or ``"test"``; selects the batch files via
            ``_batch_names``.
        extract_dir: Directory containing the unpacked batch files.

    Returns:
        ``(images, labels)``: images are ``uint8`` with the per-sample data
        reshaped to ``(3, 32, 32)`` and then moved to channels-last
        ``(32, 32, 3)``; labels are ``int32``. Batches are concatenated along
        axis 0 in the order given by ``_batch_names``.

    Raises:
        FileNotFoundError: If an expected batch file is absent.
        ValueError: Propagated from ``_batch_names`` for an unknown split.
    """
    image_chunks = []
    label_chunks = []
    for batch_name in _batch_names(split):
        batch_file = extract_dir / batch_name
        if not batch_file.exists():
            raise FileNotFoundError(f"Missing CIFAR-10 batch '{batch_name}'.")
        # NOTE: unpickling can execute arbitrary code; these files come from
        # the downloaded archive, so the cache directory must be trusted.
        with open(batch_file, "rb") as handle:
            payload = pickle.load(handle, encoding="latin1")
        channels_first = payload["data"].reshape(-1, 3, 32, 32)
        channels_last = np.transpose(channels_first, (0, 2, 3, 1))
        image_chunks.append(channels_last.astype(np.uint8))
        label_chunks.append(np.asarray(payload["labels"], dtype=np.int32))

    stacked_images = np.concatenate(image_chunks, axis=0)
    stacked_labels = np.concatenate(label_chunks, axis=0)
    return stacked_images, stacked_labels
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _ensure_split_numpy_cache(
    split: Literal["train", "test"],
    base_dir: Path,
    extract_dir: Path,
) -> tuple[Path, Path]:
    """Make sure ``.npy`` copies of ``split`` exist and return their paths.

    The files live in ``base_dir / "disk_cache"`` and are named
    ``<split>_images.npy`` and ``<split>_labels.npy``. If either one is
    missing, the split is reloaded with ``_load_split_numpy`` and both files
    are written again.

    Args:
        split: ``"train"`` or ``"test"``.
        base_dir: Cache root that holds the ``disk_cache`` directory.
        extract_dir: Directory with the unpacked batch files; only read when
            the cache has to be (re)built.

    Returns:
        ``(images_path, labels_path)``.
    """
    cache_root = base_dir / "disk_cache"
    cache_root.mkdir(parents=True, exist_ok=True)
    images_path, labels_path = (
        cache_root / f"{split}_{kind}.npy" for kind in ("images", "labels")
    )
    already_cached = images_path.exists() and labels_path.exists()
    if not already_cached:
        images_np, labels_np = _load_split_numpy(split, extract_dir)
        np.save(images_path, images_np)
        np.save(labels_path, labels_np)
    return images_path, labels_path
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
class CIFAR10Dataset(DatasetProtocol):
    """CIFAR-10 split held in memory, ready for `ArraySampleSource`.

    Construction downloads the archive from ``CIFAR10_URL`` into the cache
    directory when it is not already present, extracts it, and loads the
    requested split via ``_load_split_numpy`` and ``_to_host_jax_array``.

    Attributes:
        split: Which split to load, ``"train"`` or ``"test"``.
        cache_dir: Directory for the archive and extracted files; ``None``
            defers to the default chosen by ``_resolve_cache_dir``.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"test"``. This is
            raised before any download is attempted.
    """

    split: Literal["train", "test"] = "train"
    cache_dir: str | Path | None = None

    def __post_init__(self) -> None:
        # Reject a bad split up front instead of after fetching the archive.
        _batch_names(self.split)
        _, extract_dir = self._prepare_files(self.cache_dir)

        images_np, labels_np = _load_split_numpy(self.split, extract_dir)
        self._images = _to_host_jax_array(images_np)
        self._labels = _to_host_jax_array(labels_np)

    @staticmethod
    def _prepare_files(cache_dir: str | Path | None) -> tuple[Path, Path]:
        """Download and extract the archive; return ``(base_dir, extract_dir)``."""
        base_dir = _resolve_cache_dir(cache_dir)
        archive_path = _download(CIFAR10_URL, base_dir / "cifar-10-python.tar.gz")
        return base_dir, _ensure_extracted(archive_path, base_dir)

    def __len__(self) -> int:
        """Return the number of samples in the loaded split."""
        return int(self._images.shape[0])

    def __getitem__(self, index: int):
        """Return ``{"image": ..., "label": ...}`` for the sample at ``index``."""
        return {
            "image": self._images[index],
            "label": self._labels[index],
        }

    def as_array_dict(self) -> dict[str, jax.Array]:
        """Expose the full dataset as a PyTree of JAX arrays."""

        return {
            "image": self._images,
            "label": self._labels,
        }

    @classmethod
    def make_disk_source(
        cls,
        *,
        split: Literal["train", "test"] = "train",
        cache_dir: str | Path | None = None,
        ordering: Literal["sequential", "shuffle"] = "shuffle",
        prefetch_size: int = 64,
    ) -> DiskSampleSource:
        """Build a `DiskSampleSource` that reads samples from memory-mapped files.

        The split is converted once to ``.npy`` files (see
        ``_ensure_split_numpy_cache``) which are then opened with
        ``mmap_mode="r"``, so individual samples are read on demand.

        Args:
            split: ``"train"`` or ``"test"``.
            cache_dir: Cache location; ``None`` uses the default of
                ``_resolve_cache_dir``.
            ordering: Forwarded unchanged to `DiskSampleSource`.
            prefetch_size: Forwarded unchanged to `DiskSampleSource`.

        Returns:
            A `DiskSampleSource` whose samples are dicts with a ``uint8``
            ``"image"`` and a scalar ``int32`` ``"label"``.

        Raises:
            ValueError: If ``split`` is invalid (checked before downloading)
                or the cached image and label counts disagree.
        """
        # Reject a bad split up front instead of after fetching the archive.
        _batch_names(split)
        base_dir, extract_dir = cls._prepare_files(cache_dir)

        images_path, labels_path = _ensure_split_numpy_cache(split, base_dir, extract_dir)
        images_memmap = np.load(images_path, mmap_mode="r")
        labels_memmap = np.load(labels_path, mmap_mode="r")

        if images_memmap.shape[0] != labels_memmap.shape[0]:
            raise ValueError("CIFAR-10 image and label counts do not match.")

        def _read_sample(index: int | np.ndarray) -> dict[str, np.ndarray]:
            # The index may arrive as a 0-d array; normalise to a Python int.
            idx = int(np.asarray(index))
            image = np.asarray(images_memmap[idx], dtype=np.uint8)
            label = np.asarray(labels_memmap[idx], dtype=np.int32)
            return {"image": image, "label": label}

        sample_spec = {
            "image": jax.ShapeDtypeStruct(shape=tuple(images_memmap.shape[1:]), dtype=jnp.uint8),
            "label": jax.ShapeDtypeStruct(shape=(), dtype=jnp.int32),
        }

        return DiskSampleSource(
            length=int(labels_memmap.shape[0]),
            sample_fn=_read_sample,
            sample_spec=sample_spec,
            ordering=ordering,
            prefetch_size=prefetch_size,
        )
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@dataclass
class CIFAR10DiskSource:
    """Stream CIFAR-10 samples directly from cached numpy arrays on disk.

    Thin wrapper: the real source is built by
    ``CIFAR10Dataset.make_disk_source`` with this instance's field values,
    and every method forwards to it.
    """

    split: Literal["train", "test"] = "train"
    cache_dir: str | Path | None = None
    ordering: Literal["sequential", "shuffle"] = "shuffle"
    prefetch_size: int = 64

    def __post_init__(self) -> None:
        """Create the wrapped source and mirror its ``steps_per_epoch``."""
        wrapped = CIFAR10Dataset.make_disk_source(
            split=self.split,
            cache_dir=self.cache_dir,
            ordering=self.ordering,
            prefetch_size=self.prefetch_size,
        )
        self._disk_source = wrapped
        self.steps_per_epoch = wrapped.steps_per_epoch

    def element_spec(self):
        """Forward to the wrapped source's ``element_spec``."""
        source = self._disk_source
        return source.element_spec()

    def init_state(self, key=None):
        """Forward to the wrapped source's ``init_state`` with ``key``."""
        source = self._disk_source
        return source.init_state(key)

    def next(self, state):
        """Forward to the wrapped source's ``next`` with ``state``."""
        source = self._disk_source
        return source.next(state)
|