cyreal 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cyreal-0.1.1/PKG-INFO ADDED
@@ -0,0 +1,179 @@
1
+ Metadata-Version: 2.4
2
+ Name: cyreal
3
+ Version: 0.1.1
4
+ Summary: Jittable data loading utilities for JAX.
5
+ Author:
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: License :: OSI Approved :: MIT License
8
+ Classifier: Operating System :: OS Independent
9
+ Requires-Python: >=3.10
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: jax
12
+ Requires-Dist: jaxlib
13
+ Requires-Dist: numpy
14
+ Provides-Extra: dev
15
+ Requires-Dist: pytest; extra == "dev"
16
+ Dynamic: classifier
17
+ Dynamic: description
18
+ Dynamic: description-content-type
19
+ Dynamic: provides-extra
20
+ Dynamic: requires-dist
21
+ Dynamic: requires-python
22
+ Dynamic: summary
23
+
24
+ # Cyreal - Another JAX DataLoader
25
+
26
+ > `grain` for the corporations, `cyreal` for the people
27
+
28
+ Pure `jax` utilities for iterating over finite datasets without ever touching `torch` or `tensorflow`. Dataloaders support `jax.jit`, `jax.grad`, `jax.lax.scan`, and other function transformations.
29
+
30
+ ## Installation
31
+
32
+ The only runtime dependencies are `jax`, `jaxlib`, and `numpy`. On GPU machines, install the
33
+ appropriate JAX build for your CUDA version.
34
+
35
+ `pip install cyreal`
36
+
37
+ ## Quick start with MNIST
38
+ Write `torch`-style dataloaders without `torch`.
39
+
40
+ ```python
41
+ import jax
42
+ import jax.numpy as jnp
43
+
44
+ from cyreal import (
45
+ ArraySampleSource,
46
+ BatchTransform,
47
+ DataLoader,
48
+ DevicePutTransform,
49
+ MNISTDataset,
50
+ )
51
+
52
+ train_data = MNISTDataset(split="train").as_array_dict()
53
+ pipeline = [
54
+ ArraySampleSource(train_data, ordering="shuffle"),
55
+ BatchTransform(batch_size=128),
56
+ DevicePutTransform(),
57
+ ]
58
+ loader = DataLoader(pipeline=pipeline)
59
+ state = loader.init_state(jax.random.PRNGKey(0))
60
+
61
+ for batch, mask in loader.iterate(state):
62
+ ... # train your network!
63
+ ```
64
+
65
+ ## Scan and Avoid Boilerplate
66
+
67
+ `DataLoader.scan_epoch` folds a full pass through the dataset into a single
68
+ `jax.lax.scan` to minimize dispatch overhead. This will `jit` the `body_fn`.
69
+
70
+ ```python
71
+ def body_fn(model_state, batch, mask):
72
+ model_state = update_model(model_state, batch, mask)
73
+ return model_state, None
74
+
75
+ loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
76
+ ```
77
+
78
+ ## JIT Capabilities
79
+ Do you enjoy premature optimization? Why not `jit` the entire train epoch?
80
+
81
+
82
+ ```python
83
+ import jax
84
+ import jax.numpy as jnp
85
+
86
+ from cyreal import (
87
+ ArraySampleSource,
88
+ BatchTransform,
89
+ DataLoader,
90
+ DevicePutTransform,
91
+ MNISTDataset,
92
+ )
93
+
94
+ train_data = MNISTDataset(split="train").as_array_dict()
95
+ pipeline = [
96
+ ArraySampleSource(train_data, ordering="shuffle"),
97
+ BatchTransform(batch_size=128),
98
+ DevicePutTransform(),
99
+ ]
100
+ loader = DataLoader(pipeline)
101
+ loader_state = loader.init_state(jax.random.PRNGKey(0))
102
+ model_state = model_init()
103
+
104
+ @jax.jit
105
+ def train_epoch(model_state, loader_state):
106
+ def body_fn(model_state, batch, mask):
107
+ # Update the network using your train fn
108
+ new_model_state = model_update(model_state, batch, mask)
109
+ return new_model_state, None
110
+
111
+ loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
112
+ return model_state, loader_state
113
+
114
+ model_state, loader_state = train_epoch(model_state, loader_state)
115
+ ```
116
+
117
+
118
+ ## Streaming from Disk
119
+ Is your dataset enormous? Swap in a disk-backed source.
120
+
121
+ ```python
122
+ import jax
123
+
124
+ from cyreal import (
125
+ BatchTransform,
126
+ DataLoader,
127
+ DevicePutTransform,
128
+ MNISTDataset,
129
+ )
130
+
131
+ pipeline = [
132
+ MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
133
+ BatchTransform(batch_size=128),
134
+ DevicePutTransform(),
135
+ ]
136
+
137
+ loader = DataLoader(pipeline=pipeline)
138
+ state = loader.init_state(jax.random.PRNGKey(0))
139
+
140
+ for batch, mask in loader.iterate(state):
141
+ ... # stream without holding the dataset in RAM
142
+ ```
143
+
144
+ ## For the Dirty and Impure
145
+ Want to `jit` but also log some metrics? Use `HostCallbackTransform`, which uses `jax.experimental.io_callback` under the hood.
146
+
147
+ ```python
148
+ import jax.numpy as jnp
149
+ import numpy as np
150
+
151
+ from cyreal import (
152
+ ArraySampleSource,
153
+ BatchTransform,
154
+ DataLoader,
155
+ HostCallbackTransform,
156
+ MNISTDataset,
157
+ )
158
+
159
+ def model(images):
160
+ return jnp.mean(images.astype(jnp.float32), axis=(1, 2, 3))
161
+
162
+ def cross_entropy(logits, labels):
163
+ labels = labels.astype(jnp.float32)
164
+ return (logits - labels) ** 2
165
+
166
+ def log_loss(batch, mask):
167
+ logits = model(batch["image"])
168
+ loss = jnp.mean(cross_entropy(logits, batch["label"]) * mask[:, None])
169
+ print("loss:", float(np.asarray(loss)))
170
+ return batch
171
+
172
+ loader = DataLoader(
173
+ pipeline=[
174
+ ArraySampleSource(MNISTDataset(split="train").as_array_dict(), ordering="shuffle"),
175
+ BatchTransform(batch_size=128),
176
+ HostCallbackTransform(fn=log_loss),
177
+ ],
178
+ )
179
+ ```
cyreal-0.1.1/README.md ADDED
@@ -0,0 +1,156 @@
1
+ # Cyreal - Another JAX DataLoader
2
+
3
+ > `grain` for the corporations, `cyreal` for the people
4
+
5
+ Pure `jax` utilities for iterating over finite datasets without ever touching `torch` or `tensorflow`. Dataloaders support `jax.jit`, `jax.grad`, `jax.lax.scan`, and other function transformations.
6
+
7
+ ## Installation
8
+
9
+ The only runtime dependencies are `jax`, `jaxlib`, and `numpy`. On GPU machines, install the
10
+ appropriate JAX build for your CUDA version.
11
+
12
+ `pip install cyreal`
13
+
14
+ ## Quick start with MNIST
15
+ Write `torch`-style dataloaders without `torch`.
16
+
17
+ ```python
18
+ import jax
19
+ import jax.numpy as jnp
20
+
21
+ from cyreal import (
22
+ ArraySampleSource,
23
+ BatchTransform,
24
+ DataLoader,
25
+ DevicePutTransform,
26
+ MNISTDataset,
27
+ )
28
+
29
+ train_data = MNISTDataset(split="train").as_array_dict()
30
+ pipeline = [
31
+ ArraySampleSource(train_data, ordering="shuffle"),
32
+ BatchTransform(batch_size=128),
33
+ DevicePutTransform(),
34
+ ]
35
+ loader = DataLoader(pipeline=pipeline)
36
+ state = loader.init_state(jax.random.PRNGKey(0))
37
+
38
+ for batch, mask in loader.iterate(state):
39
+ ... # train your network!
40
+ ```
41
+
42
+ ## Scan and Avoid Boilerplate
43
+
44
+ `DataLoader.scan_epoch` folds a full pass through the dataset into a single
45
+ `jax.lax.scan` to minimize dispatch overhead. This will `jit` the `body_fn`.
46
+
47
+ ```python
48
+ def body_fn(model_state, batch, mask):
49
+ model_state = update_model(model_state, batch, mask)
50
+ return model_state, None
51
+
52
+ loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
53
+ ```
54
+
55
+ ## JIT Capabilities
56
+ Do you enjoy premature optimization? Why not `jit` the entire train epoch?
57
+
58
+
59
+ ```python
60
+ import jax
61
+ import jax.numpy as jnp
62
+
63
+ from cyreal import (
64
+ ArraySampleSource,
65
+ BatchTransform,
66
+ DataLoader,
67
+ DevicePutTransform,
68
+ MNISTDataset,
69
+ )
70
+
71
+ train_data = MNISTDataset(split="train").as_array_dict()
72
+ pipeline = [
73
+ ArraySampleSource(train_data, ordering="shuffle"),
74
+ BatchTransform(batch_size=128),
75
+ DevicePutTransform(),
76
+ ]
77
+ loader = DataLoader(pipeline)
78
+ loader_state = loader.init_state(jax.random.PRNGKey(0))
79
+ model_state = model_init()
80
+
81
+ @jax.jit
82
+ def train_epoch(model_state, loader_state):
83
+ def body_fn(model_state, batch, mask):
84
+ # Update the network using your train fn
85
+ new_model_state = model_update(model_state, batch, mask)
86
+ return new_model_state, None
87
+
88
+ loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
89
+ return model_state, loader_state
90
+
91
+ model_state, loader_state = train_epoch(model_state, loader_state)
92
+ ```
93
+
94
+
95
+ ## Streaming from Disk
96
+ Is your dataset enormous? Swap in a disk-backed source.
97
+
98
+ ```python
99
+ import jax
100
+
101
+ from cyreal import (
102
+ BatchTransform,
103
+ DataLoader,
104
+ DevicePutTransform,
105
+ MNISTDataset,
106
+ )
107
+
108
+ pipeline = [
109
+ MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
110
+ BatchTransform(batch_size=128),
111
+ DevicePutTransform(),
112
+ ]
113
+
114
+ loader = DataLoader(pipeline=pipeline)
115
+ state = loader.init_state(jax.random.PRNGKey(0))
116
+
117
+ for batch, mask in loader.iterate(state):
118
+ ... # stream without holding the dataset in RAM
119
+ ```
120
+
121
+ ## For the Dirty and Impure
122
+ Want to `jit` but also log some metrics? Use `HostCallbackTransform`, which uses `jax.experimental.io_callback` under the hood.
123
+
124
+ ```python
125
+ import jax.numpy as jnp
126
+ import numpy as np
127
+
128
+ from cyreal import (
129
+ ArraySampleSource,
130
+ BatchTransform,
131
+ DataLoader,
132
+ HostCallbackTransform,
133
+ MNISTDataset,
134
+ )
135
+
136
+ def model(images):
137
+ return jnp.mean(images.astype(jnp.float32), axis=(1, 2, 3))
138
+
139
+ def cross_entropy(logits, labels):
140
+ labels = labels.astype(jnp.float32)
141
+ return (logits - labels) ** 2
142
+
143
+ def log_loss(batch, mask):
144
+ logits = model(batch["image"])
145
+ loss = jnp.mean(cross_entropy(logits, batch["label"]) * mask[:, None])
146
+ print("loss:", float(np.asarray(loss)))
147
+ return batch
148
+
149
+ loader = DataLoader(
150
+ pipeline=[
151
+ ArraySampleSource(MNISTDataset(split="train").as_array_dict(), ordering="shuffle"),
152
+ BatchTransform(batch_size=128),
153
+ HostCallbackTransform(fn=log_loss),
154
+ ],
155
+ )
156
+ ```
@@ -0,0 +1,38 @@
1
+ """Jittable dataset utilities for JAX."""
2
+ from __future__ import annotations
3
+
4
+ from .dataset_protocol import DatasetProtocol
5
+ from .datasets import CIFAR10Dataset, CIFAR10DiskSource, MNISTDataset, MNISTDiskSource
6
+ from .loader import (
7
+ DataLoader,
8
+ LoaderState,
9
+ )
10
+ from .sources import ArraySampleSource, DiskSampleSource, GymnaxSource, Source
11
+ from .transforms import (
12
+ BatchTransform,
13
+ DevicePutTransform,
14
+ FlattenTransform,
15
+ HostCallbackTransform,
16
+ MapTransform,
17
+ NormalizeImageTransform,
18
+ )
19
+
20
# Public API of ``cyreal``: every name imported from the subpackages above is
# re-exported here, so ``from cyreal import *`` exposes exactly these symbols.
__all__ = [
    "DatasetProtocol",
    "DataLoader",
    "LoaderState",
    "CIFAR10Dataset",
    "CIFAR10DiskSource",
    "MNISTDataset",
    "MNISTDiskSource",
    "ArraySampleSource",
    "DiskSampleSource",
    "GymnaxSource",
    "Source",
    "BatchTransform",
    "DevicePutTransform",
    "FlattenTransform",
    "HostCallbackTransform",
    "MapTransform",
    "NormalizeImageTransform",
]
@@ -0,0 +1,14 @@
1
+ """Simple dataset protocol used by the jittable dataloader."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any, Protocol
5
+
6
+
7
class DatasetProtocol(Protocol):
    """Structural type for datasets with a known length that support ``ds[i]``.

    For static type checking, any class that provides ``__len__`` and integer
    ``__getitem__`` conforms; inheriting from this protocol is optional.
    """

    def __len__(self) -> int:
        """Return how many samples the dataset holds."""

    def __getitem__(self, index: int) -> Any:
        """Return the sample stored at position ``index``."""
@@ -0,0 +1,14 @@
1
+ """Dataset helpers bundled with cyreal."""
2
+ from __future__ import annotations
3
+
4
+ from .cifar10 import CIFAR10Dataset, CIFAR10DiskSource
5
+ from .mnist import MNISTDataset, MNISTDiskSource, MNIST_URLS, _ensure_file
6
+
7
# Names exported by ``cyreal.datasets``. Besides the dataset classes, the MNIST
# module's ``MNIST_URLS`` and its underscore-prefixed ``_ensure_file`` helper are
# listed, so a star-import exposes ``_ensure_file`` despite the leading underscore.
# NOTE(review): exporting a private helper looks unintentional — confirm callers need it.
__all__ = [
    "CIFAR10Dataset",
    "CIFAR10DiskSource",
    "MNISTDataset",
    "MNISTDiskSource",
    "MNIST_URLS",
    "_ensure_file",
]
@@ -0,0 +1,194 @@
1
+ """CIFAR-10 dataset that stays within NumPy/JAX dependencies."""
2
+ from __future__ import annotations
3
+
4
+ import pickle
5
+ import tarfile
6
+ import urllib.request
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Literal
10
+
11
+ import jax
12
+ import jax.numpy as jnp
13
+ import numpy as np
14
+
15
+ from ..dataset_protocol import DatasetProtocol
16
+ from ..sources import DiskSampleSource
17
+
18
# Source of the pickled ("python version") CIFAR-10 archive. It is downloaded once
# into the cache directory and unpacked to ``cifar-10-batches-py``.
CIFAR10_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
19
+
20
+
21
+ def _to_host_jax_array(array: np.ndarray) -> jax.Array:
22
+ cpu_devices = jax.devices("cpu")
23
+ if cpu_devices:
24
+ with jax.default_device(cpu_devices[0]):
25
+ return jnp.asarray(array)
26
+ return jnp.asarray(array)
27
+
28
+
29
+ def _resolve_cache_dir(cache_dir: str | Path | None) -> Path:
30
+ base_dir = Path(cache_dir) if cache_dir is not None else Path.home() / ".cache" / "cyreal_cifar10"
31
+ base_dir.mkdir(parents=True, exist_ok=True)
32
+ return base_dir
33
+
34
+
35
+ def _download(url: str, path: Path) -> Path:
36
+ if path.exists():
37
+ return path
38
+ path.parent.mkdir(parents=True, exist_ok=True)
39
+ urllib.request.urlretrieve(url, path)
40
+ return path
41
+
42
+
43
+ def _ensure_extracted(archive: Path, extract_root: Path) -> Path:
44
+ target = extract_root / "cifar-10-batches-py"
45
+ if target.exists():
46
+ return target
47
+ with tarfile.open(archive, "r:gz") as tar:
48
+ tar.extractall(path=extract_root)
49
+ return target
50
+
51
+
52
+ def _batch_names(split: Literal["train", "test"]) -> list[str]:
53
+ if split == "train":
54
+ return [f"data_batch_{i}" for i in range(1, 6)]
55
+ if split == "test":
56
+ return ["test_batch"]
57
+ raise ValueError("split must be 'train' or 'test'.")
58
+
59
+
60
def _load_split_numpy(split: Literal["train", "test"], extract_dir: Path) -> tuple[np.ndarray, np.ndarray]:
    """Read every pickled batch of ``split`` from ``extract_dir`` into NumPy arrays.

    Returns:
        ``(images, labels)``: ``uint8`` images reshaped to channels-last
        ``(N, 32, 32, 3)`` and ``int32`` labels, each concatenated along axis 0.

    Raises:
        FileNotFoundError: If an expected batch file is absent.
        ValueError: Propagated from ``_batch_names`` for an unknown ``split``.
    """
    image_chunks: list[np.ndarray] = []
    label_chunks: list[np.ndarray] = []
    for batch_name in _batch_names(split):
        batch_file = extract_dir / batch_name
        if not batch_file.exists():
            raise FileNotFoundError(f"Missing CIFAR-10 batch '{batch_name}'.")
        # NOTE: pickle.load can execute arbitrary code; this trusts the archive
        # fetched from CIFAR10_URL.
        with open(batch_file, "rb") as handle:
            record = pickle.load(handle, encoding="latin1")
        channels_first = record["data"].reshape(-1, 3, 32, 32)
        channels_last = channels_first.transpose(0, 2, 3, 1)
        image_chunks.append(channels_last.astype(np.uint8))
        label_chunks.append(np.asarray(record["labels"], dtype=np.int32))

    return np.concatenate(image_chunks, axis=0), np.concatenate(label_chunks, axis=0)
76
+
77
+
78
+ def _ensure_split_numpy_cache(
79
+ split: Literal["train", "test"],
80
+ base_dir: Path,
81
+ extract_dir: Path,
82
+ ) -> tuple[Path, Path]:
83
+ cache_root = base_dir / "disk_cache"
84
+ cache_root.mkdir(parents=True, exist_ok=True)
85
+ images_path = cache_root / f"{split}_images.npy"
86
+ labels_path = cache_root / f"{split}_labels.npy"
87
+ if not images_path.exists() or not labels_path.exists():
88
+ images_np, labels_np = _load_split_numpy(split, extract_dir)
89
+ np.save(images_path, images_np)
90
+ np.save(labels_path, labels_np)
91
+ return images_path, labels_path
92
+
93
+
94
@dataclass
class CIFAR10Dataset(DatasetProtocol):
    """CIFAR-10 split loaded into memory, ready for ``ArraySampleSource``.

    On construction, the archive at ``CIFAR10_URL`` is downloaded into ``cache_dir``
    if it is not already there (default ``~/.cache/cyreal_cifar10``) and extracted.
    The requested split is then loaded as ``uint8`` images of shape
    ``(N, 32, 32, 3)`` and ``int32`` labels. Both are kept as JAX arrays placed on
    the host CPU device when one is available.
    """

    split: Literal["train", "test"] = "train"
    cache_dir: str | Path | None = None

    @staticmethod
    def _prepare_extract_dir(cache_dir: str | Path | None) -> tuple[Path, Path]:
        """Download and extract the archive if needed; return ``(base_dir, extract_dir)``."""
        base_dir = _resolve_cache_dir(cache_dir)
        archive_path = base_dir / "cifar-10-python.tar.gz"
        _download(CIFAR10_URL, archive_path)
        return base_dir, _ensure_extracted(archive_path, base_dir)

    def __post_init__(self) -> None:
        """Fetch, extract and load the requested split into host JAX arrays.

        Raises:
            ValueError: If the loaded image and label counts differ.
        """
        _, extract_dir = self._prepare_extract_dir(self.cache_dir)
        images_np, labels_np = _load_split_numpy(self.split, extract_dir)
        # Same sanity check as the disk-streaming path in ``make_disk_source``.
        if images_np.shape[0] != labels_np.shape[0]:
            raise ValueError("CIFAR-10 image and label counts do not match.")
        self._images = _to_host_jax_array(images_np)
        self._labels = _to_host_jax_array(labels_np)

    def __len__(self) -> int:
        """Return the number of samples in the split."""
        return int(self._images.shape[0])

    def __getitem__(self, index: int) -> dict[str, jax.Array]:
        """Return ``{"image": ..., "label": ...}`` for the sample at ``index``."""
        return {
            "image": self._images[index],
            "label": self._labels[index],
        }

    def as_array_dict(self) -> dict[str, jax.Array]:
        """Expose the full dataset as a PyTree of JAX arrays."""

        return {
            "image": self._images,
            "label": self._labels,
        }

    @classmethod
    def make_disk_source(
        cls,
        *,
        split: Literal["train", "test"] = "train",
        cache_dir: str | Path | None = None,
        ordering: Literal["sequential", "shuffle"] = "shuffle",
        prefetch_size: int = 64,
    ) -> DiskSampleSource:
        """Build a ``DiskSampleSource`` that reads samples from memory-mapped ``.npy`` files.

        On first use, this downloads and extracts the archive if needed and writes
        per-split ``.npy`` caches under ``<cache_dir>/disk_cache``.

        Args:
            split: Which split to stream.
            cache_dir: Cache location, resolved like ``CIFAR10Dataset.cache_dir``.
            ordering: Sample ordering forwarded to ``DiskSampleSource``.
            prefetch_size: Prefetch size forwarded to ``DiskSampleSource``.

        Raises:
            ValueError: If the cached image and label arrays differ in length.
        """
        base_dir, extract_dir = cls._prepare_extract_dir(cache_dir)

        images_path, labels_path = _ensure_split_numpy_cache(split, base_dir, extract_dir)
        images_memmap = np.load(images_path, mmap_mode="r")
        labels_memmap = np.load(labels_path, mmap_mode="r")

        if images_memmap.shape[0] != labels_memmap.shape[0]:
            raise ValueError("CIFAR-10 image and label counts do not match.")

        def _read_sample(index: int | np.ndarray) -> dict[str, np.ndarray]:
            # Accept a Python int or a scalar array, as the annotation allows.
            idx = int(np.asarray(index))
            image = np.asarray(images_memmap[idx], dtype=np.uint8)
            label = np.asarray(labels_memmap[idx], dtype=np.int32)
            return {"image": image, "label": label}

        sample_spec = {
            "image": jax.ShapeDtypeStruct(shape=tuple(images_memmap.shape[1:]), dtype=jnp.uint8),
            "label": jax.ShapeDtypeStruct(shape=(), dtype=jnp.int32),
        }

        return DiskSampleSource(
            length=int(labels_memmap.shape[0]),
            sample_fn=_read_sample,
            sample_spec=sample_spec,
            ordering=ordering,
            prefetch_size=prefetch_size,
        )
167
+
168
+
169
@dataclass
class CIFAR10DiskSource:
    """Disk-streaming CIFAR-10 source backed by the cached ``.npy`` arrays.

    A thin wrapper: it builds ``CIFAR10Dataset.make_disk_source(...)`` from its own
    fields and forwards ``element_spec``, ``init_state`` and ``next`` to the result.
    """

    split: Literal["train", "test"] = "train"
    cache_dir: str | Path | None = None
    ordering: Literal["sequential", "shuffle"] = "shuffle"
    prefetch_size: int = 64

    def __post_init__(self) -> None:
        inner = CIFAR10Dataset.make_disk_source(
            split=self.split,
            cache_dir=self.cache_dir,
            ordering=self.ordering,
            prefetch_size=self.prefetch_size,
        )
        self._disk_source = inner
        # Mirror the wrapped source's epoch length so callers can read it here.
        self.steps_per_epoch = inner.steps_per_epoch

    def element_spec(self):
        """Return the wrapped source's element spec."""
        return self._disk_source.element_spec()

    def init_state(self, key=None):
        """Create the wrapped source's initial state, forwarding ``key`` unchanged."""
        return self._disk_source.init_state(key)

    def next(self, state):
        """Advance the wrapped source once and return whatever it returns."""
        return self._disk_source.next(state)