insitubatch 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- insitubatch-0.0.1/LICENSE +21 -0
- insitubatch-0.0.1/PKG-INFO +90 -0
- insitubatch-0.0.1/README.md +61 -0
- insitubatch-0.0.1/pyproject.toml +72 -0
- insitubatch-0.0.1/src/insitubatch/__init__.py +44 -0
- insitubatch-0.0.1/src/insitubatch/buffer.py +88 -0
- insitubatch-0.0.1/src/insitubatch/io.py +172 -0
- insitubatch-0.0.1/src/insitubatch/plan.py +112 -0
- insitubatch-0.0.1/src/insitubatch/py.typed +0 -0
- insitubatch-0.0.1/src/insitubatch/shuffle.py +74 -0
- insitubatch-0.0.1/src/insitubatch/source.py +118 -0
- insitubatch-0.0.1/src/insitubatch/split.py +97 -0
- insitubatch-0.0.1/src/insitubatch/store.py +75 -0
- insitubatch-0.0.1/src/insitubatch/types.py +108 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 David Stuebe
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: insitubatch
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Train in place on n-dimensional cloud tensors: the data-loader orchestration layer on top of solved async IO (obstore / zarr v3 / icechunk).
|
|
5
|
+
Keywords: zarr,xarray,pytorch,dataloader,obstore,machine-learning,cloud,async
|
|
6
|
+
Author: David Stuebe
|
|
7
|
+
Author-email: David Stuebe <stu3b3+emfdavid@gmail.com>
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Classifier: Development Status :: 2 - Pre-Alpha
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering
|
|
15
|
+
Requires-Dist: numpy>=1.26
|
|
16
|
+
Requires-Dist: zarr>=3.0
|
|
17
|
+
Requires-Dist: xarray>=2024.9
|
|
18
|
+
Requires-Dist: obstore>=0.3
|
|
19
|
+
Requires-Dist: cupy-cuda12x>=13.0 ; extra == 'gpu'
|
|
20
|
+
Requires-Dist: kvikio-cu12>=24.10 ; extra == 'gpu'
|
|
21
|
+
Requires-Dist: torch>=2.2 ; extra == 'torch'
|
|
22
|
+
Requires-Dist: torchdata>=0.10 ; extra == 'torch'
|
|
23
|
+
Requires-Dist: virtualizarr>=1.2 ; extra == 'virtual'
|
|
24
|
+
Requires-Python: >=3.12
|
|
25
|
+
Provides-Extra: gpu
|
|
26
|
+
Provides-Extra: torch
|
|
27
|
+
Provides-Extra: virtual
|
|
28
|
+
Description-Content-Type: text/markdown
|
|
29
|
+
|
|
30
|
+
# insitubatch
|
|
31
|
+
|
|
32
|
+
**Train in place on n-dimensional cloud tensors.**
|
|
33
|
+
|
|
34
|
+
`insitubatch` is the data-loader orchestration layer that sits on top of
|
|
35
|
+
*already-solved* async cloud IO (obstore / zarr v3 / icechunk). It turns an
|
|
36
|
+
existing Zarr archive into a shuffled, split-aware, GPU-saturating PyTorch
|
|
37
|
+
source — **with no reshard** — and a Python hot path that scales with **chunks,
|
|
38
|
+
not samples**.
|
|
39
|
+
|
|
40
|
+
> The IO race is over (obstore/icechunk saturate the NIC). The *loader* race is
|
|
41
|
+
> open. `insitubatch` builds the layer that projects like light-speed-io and
|
|
42
|
+
> hypergrib stopped one step short of. See [DESIGN.md](DESIGN.md).
|
|
43
|
+
|
|
44
|
+
## Why
|
|
45
|
+
|
|
46
|
+
The classic PyTorch `DataLoader` spreads work across worker **processes**, each
|
|
47
|
+
running a *synchronous* `__getitem__`. Against cloud Zarr that means no shared
|
|
48
|
+
chunk cache (every worker re-reads the same chunk), no way to drive async
|
|
49
|
+
obstore, and dask thread pools nested inside forked workers. `insitubatch`
|
|
50
|
+
**inverts** it: one async event loop drives concurrent reads; a bounded
|
|
51
|
+
shuffle-block buffer assembles batches; torch runs `num_workers=0`.
|
|
52
|
+
|
|
53
|
+
## Status
|
|
54
|
+
|
|
55
|
+
🚧 **Pre-alpha skeleton.** Abstractions and control flow are in place; the live
|
|
56
|
+
store read in `io.py` and the GPU path are stubbed. Not yet usable for real
|
|
57
|
+
training — this is the design substrate.
|
|
58
|
+
|
|
59
|
+
## Install (dev)
|
|
60
|
+
|
|
61
|
+
```bash
|
|
62
|
+
uv sync # core engine + dev tools
|
|
63
|
+
uv sync --extra torch # add the torch IterableDataset surface
|
|
64
|
+
uv sync --extra gpu # CUDA box only: cupy + kvikio zero-copy path
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
## Shape of the API (target)
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
from insitubatch import split_by_chunk, ArrayGeometry, SplitName
|
|
71
|
+
from insitubatch.source import InSituDataset
|
|
72
|
+
from torch.utils.data import DataLoader
|
|
73
|
+
|
|
74
|
+
geom = ArrayGeometry("t2m", shape=(8760, 721, 1440), chunks=(24, 721, 1440), dtype=...)
|
|
75
|
+
manifest = split_by_chunk(geom, fractions=(0.8, 0.1, 0.1))
|
|
76
|
+
|
|
77
|
+
ds = InSituDataset(store, {"t2m": geom}, manifest, SplitName.TRAIN,
|
|
78
|
+
batch_size=32, block_chunks=16)
|
|
79
|
+
|
|
80
|
+
# parallelism lives in insitubatch's event loop, not in workers:
|
|
81
|
+
loader = DataLoader(ds, batch_size=None, num_workers=0)
|
|
82
|
+
for epoch in range(n_epochs):
|
|
83
|
+
ds.set_epoch(epoch)
|
|
84
|
+
for batch in loader:
|
|
85
|
+
...
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
## License
|
|
89
|
+
|
|
90
|
+
MIT — see [LICENSE](LICENSE).
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# insitubatch
|
|
2
|
+
|
|
3
|
+
**Train in place on n-dimensional cloud tensors.**
|
|
4
|
+
|
|
5
|
+
`insitubatch` is the data-loader orchestration layer that sits on top of
|
|
6
|
+
*already-solved* async cloud IO (obstore / zarr v3 / icechunk). It turns an
|
|
7
|
+
existing Zarr archive into a shuffled, split-aware, GPU-saturating PyTorch
|
|
8
|
+
source — **with no reshard** — and a Python hot path that scales with **chunks,
|
|
9
|
+
not samples**.
|
|
10
|
+
|
|
11
|
+
> The IO race is over (obstore/icechunk saturate the NIC). The *loader* race is
|
|
12
|
+
> open. `insitubatch` builds the layer that projects like light-speed-io and
|
|
13
|
+
> hypergrib stopped one step short of. See [DESIGN.md](DESIGN.md).
|
|
14
|
+
|
|
15
|
+
## Why
|
|
16
|
+
|
|
17
|
+
The classic PyTorch `DataLoader` spreads work across worker **processes**, each
|
|
18
|
+
running a *synchronous* `__getitem__`. Against cloud Zarr that means no shared
|
|
19
|
+
chunk cache (every worker re-reads the same chunk), no way to drive async
|
|
20
|
+
obstore, and dask thread pools nested inside forked workers. `insitubatch`
|
|
21
|
+
**inverts** it: one async event loop drives concurrent reads; a bounded
|
|
22
|
+
shuffle-block buffer assembles batches; torch runs `num_workers=0`.
|
|
23
|
+
|
|
24
|
+
## Status
|
|
25
|
+
|
|
26
|
+
🚧 **Pre-alpha skeleton.** Abstractions and control flow are in place; the live
|
|
27
|
+
store read in `io.py` and the GPU path are stubbed. Not yet usable for real
|
|
28
|
+
training — this is the design substrate.
|
|
29
|
+
|
|
30
|
+
## Install (dev)
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
uv sync # core engine + dev tools
|
|
34
|
+
uv sync --extra torch # add the torch IterableDataset surface
|
|
35
|
+
uv sync --extra gpu # CUDA box only: cupy + kvikio zero-copy path
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
## Shape of the API (target)
|
|
39
|
+
|
|
40
|
+
```python
|
|
41
|
+
from insitubatch import split_by_chunk, ArrayGeometry, SplitName
|
|
42
|
+
from insitubatch.source import InSituDataset
|
|
43
|
+
from torch.utils.data import DataLoader
|
|
44
|
+
|
|
45
|
+
geom = ArrayGeometry("t2m", shape=(8760, 721, 1440), chunks=(24, 721, 1440), dtype=...)
|
|
46
|
+
manifest = split_by_chunk(geom, fractions=(0.8, 0.1, 0.1))
|
|
47
|
+
|
|
48
|
+
ds = InSituDataset(store, {"t2m": geom}, manifest, SplitName.TRAIN,
|
|
49
|
+
batch_size=32, block_chunks=16)
|
|
50
|
+
|
|
51
|
+
# parallelism lives in insitubatch's event loop, not in workers:
|
|
52
|
+
loader = DataLoader(ds, batch_size=None, num_workers=0)
|
|
53
|
+
for epoch in range(n_epochs):
|
|
54
|
+
ds.set_epoch(epoch)
|
|
55
|
+
for batch in loader:
|
|
56
|
+
...
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## License
|
|
60
|
+
|
|
61
|
+
MIT — see [LICENSE](LICENSE).
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "insitubatch"
|
|
3
|
+
version = "0.0.1"
|
|
4
|
+
description = "Train in place on n-dimensional cloud tensors: the data-loader orchestration layer on top of solved async IO (obstore / zarr v3 / icechunk)."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
license = "MIT"
|
|
7
|
+
license-files = ["LICENSE"]
|
|
8
|
+
authors = [
|
|
9
|
+
{ name = "David Stuebe", email = "stu3b3+emfdavid@gmail.com" }
|
|
10
|
+
]
|
|
11
|
+
requires-python = ">=3.12"
|
|
12
|
+
keywords = ["zarr", "xarray", "pytorch", "dataloader", "obstore", "machine-learning", "cloud", "async"]
|
|
13
|
+
classifiers = [
|
|
14
|
+
"Development Status :: 2 - Pre-Alpha",
|
|
15
|
+
"Programming Language :: Python :: 3.12",
|
|
16
|
+
"Programming Language :: Python :: 3.13",
|
|
17
|
+
"Intended Audience :: Science/Research",
|
|
18
|
+
"Topic :: Scientific/Engineering",
|
|
19
|
+
]
|
|
20
|
+
dependencies = [
|
|
21
|
+
"numpy>=1.26",
|
|
22
|
+
"zarr>=3.0",
|
|
23
|
+
"xarray>=2024.9",
|
|
24
|
+
"obstore>=0.3",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[project.optional-dependencies]
|
|
28
|
+
# Torch handoff surface (IterableDataset / torchdata.nodes). Kept optional so the
|
|
29
|
+
# core engine stays framework-agnostic and importable without torch.
|
|
30
|
+
torch = [
|
|
31
|
+
"torch>=2.2",
|
|
32
|
+
"torchdata>=0.10",
|
|
33
|
+
]
|
|
34
|
+
# GPU-native path (zero-copy chunk -> cupy -> dlpack -> torch, optional nvCOMP
|
|
35
|
+
# decode). Pulls CUDA wheels; only resolvable on a CUDA box, hence isolated here.
|
|
36
|
+
gpu = [
|
|
37
|
+
"cupy-cuda12x>=13.0",
|
|
38
|
+
"kvikio-cu12>=24.10",
|
|
39
|
+
]
|
|
40
|
+
# Virtual-zarr views over GRIB/NetCDF so the engine only ever speaks zarr-async.
|
|
41
|
+
virtual = [
|
|
42
|
+
"virtualizarr>=1.2",
|
|
43
|
+
]
|
|
44
|
+
|
|
45
|
+
[dependency-groups]
|
|
46
|
+
dev = [
|
|
47
|
+
"pytest>=8.0",
|
|
48
|
+
"pytest-asyncio>=0.24",
|
|
49
|
+
"ruff>=0.6",
|
|
50
|
+
"mypy>=1.11",
|
|
51
|
+
]
|
|
52
|
+
|
|
53
|
+
[build-system]
|
|
54
|
+
requires = ["uv_build>=0.11.21,<0.12.0"]
|
|
55
|
+
build-backend = "uv_build"
|
|
56
|
+
|
|
57
|
+
[tool.ruff]
|
|
58
|
+
line-length = 100
|
|
59
|
+
target-version = "py312"
|
|
60
|
+
|
|
61
|
+
[tool.ruff.lint]
|
|
62
|
+
select = ["E", "F", "I", "UP", "B", "SIM"]
|
|
63
|
+
|
|
64
|
+
[tool.pytest.ini_options]
|
|
65
|
+
asyncio_mode = "auto"
|
|
66
|
+
testpaths = ["tests"]
|
|
67
|
+
|
|
68
|
+
[tool.mypy]
|
|
69
|
+
python_version = "3.12"
|
|
70
|
+
warn_unused_ignores = true
|
|
71
|
+
disallow_untyped_defs = true
|
|
72
|
+
ignore_missing_imports = true
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""insitubatch -- train in place on n-dimensional cloud tensors.
|
|
2
|
+
|
|
3
|
+
The loader-orchestration layer that sits on top of *already-solved* async cloud
|
|
4
|
+
IO (obstore / zarr v3 / icechunk): turns an existing Zarr archive into a
|
|
5
|
+
shuffled, split-aware, GPU-saturating PyTorch source with no reshard and a
|
|
6
|
+
Python hot path that scales with chunks, not samples.
|
|
7
|
+
|
|
8
|
+
See DESIGN.md for the full rationale.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from .buffer import BufferConfig, ShuffleBlockBuffer
|
|
14
|
+
from .io import AsyncChunkReader, IOConfig
|
|
15
|
+
from .plan import ReadPlan, build_read_plan, dedup_ratio
|
|
16
|
+
from .shuffle import block_shuffled_order, chunk_permutation, shuffle_quality
|
|
17
|
+
from .split import SplitManifest, split_by_chunk
|
|
18
|
+
from .store import ensure_local_dir, open_geometries, store_from_url
|
|
19
|
+
from .types import ArrayGeometry, Batch, ChunkRead, DecodedChunk, SplitName
|
|
20
|
+
|
|
21
|
+
__version__ = "0.0.1"
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"ArrayGeometry",
|
|
25
|
+
"AsyncChunkReader",
|
|
26
|
+
"Batch",
|
|
27
|
+
"BufferConfig",
|
|
28
|
+
"ChunkRead",
|
|
29
|
+
"DecodedChunk",
|
|
30
|
+
"IOConfig",
|
|
31
|
+
"ReadPlan",
|
|
32
|
+
"ShuffleBlockBuffer",
|
|
33
|
+
"SplitManifest",
|
|
34
|
+
"SplitName",
|
|
35
|
+
"block_shuffled_order",
|
|
36
|
+
"build_read_plan",
|
|
37
|
+
"chunk_permutation",
|
|
38
|
+
"dedup_ratio",
|
|
39
|
+
"ensure_local_dir",
|
|
40
|
+
"open_geometries",
|
|
41
|
+
"shuffle_quality",
|
|
42
|
+
"split_by_chunk",
|
|
43
|
+
"store_from_url",
|
|
44
|
+
]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Bounded shuffle-block buffer: residency + batch assembly.
|
|
2
|
+
|
|
3
|
+
Holds decoded chunks for a window of ``block_chunks`` chunks, draws shuffled
|
|
4
|
+
batches across that window, and evicts chunks once fully drained. This is the
|
|
5
|
+
memory-bounding component: peak residency is O(block_chunks), independent of the
|
|
6
|
+
number of samples per epoch or the batch size.
|
|
7
|
+
|
|
8
|
+
The batch assembly does a single coalesced gather per variable (one fancy-index
|
|
9
|
+
copy), never a Python per-sample loop -- the constraint David's S3 benchmark
|
|
10
|
+
imposed (Python per-chunk overhead bounds throughput).
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from collections import deque
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from .types import Batch, DecodedChunk
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(slots=True)
|
|
24
|
+
class BufferConfig:
|
|
25
|
+
block_chunks: int = 16
|
|
26
|
+
"""Window size in chunks. Larger == better shuffle, more memory."""
|
|
27
|
+
|
|
28
|
+
batch_size: int = 32
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(slots=True)
|
|
32
|
+
class ShuffleBlockBuffer:
|
|
33
|
+
"""Accumulates decoded chunks and emits shuffled, coalesced batches."""
|
|
34
|
+
|
|
35
|
+
config: BufferConfig
|
|
36
|
+
seed: int = 0
|
|
37
|
+
_chunks: dict[tuple[str, int], DecodedChunk] = field(default_factory=dict)
|
|
38
|
+
# Pending draws: rows of (array, chunk_index, within) flattened per variable.
|
|
39
|
+
_pending: deque[int] = field(default_factory=deque)
|
|
40
|
+
|
|
41
|
+
def add(self, chunk: DecodedChunk) -> None:
|
|
42
|
+
self._chunks[(chunk.read.array, chunk.read.chunk_index)] = chunk
|
|
43
|
+
|
|
44
|
+
def ready(self) -> bool:
|
|
45
|
+
"""Enough buffered to safely emit a well-mixed batch."""
|
|
46
|
+
return len(self._chunks) >= self.config.block_chunks
|
|
47
|
+
|
|
48
|
+
def gather_batch(
|
|
49
|
+
self,
|
|
50
|
+
rows: np.ndarray,
|
|
51
|
+
variables: list[str],
|
|
52
|
+
sample_chunk_size: int,
|
|
53
|
+
) -> Batch:
|
|
54
|
+
"""Assemble one batch from ``rows`` of ``[chunk_id, within]`` draws.
|
|
55
|
+
|
|
56
|
+
``rows`` are pre-shuffled draw coordinates (see shuffle.block_shuffled_order).
|
|
57
|
+
Draws are grouped by chunk so each resident array is touched once (one
|
|
58
|
+
coalesced fancy-index per chunk); ``data`` and ``sample_indices`` are
|
|
59
|
+
emitted in the *same* grouped order so row ``i`` of every variable and
|
|
60
|
+
``sample_indices[i]`` refer to the same sample. Intra-batch order is thus
|
|
61
|
+
grouped-by-chunk -- irrelevant for training, and the cross-batch shuffle
|
|
62
|
+
is preserved.
|
|
63
|
+
|
|
64
|
+
``sample_chunk_size`` is the array's true chunk length (from geometry),
|
|
65
|
+
used to recover global sample indices -- NOT inferred from a resident
|
|
66
|
+
chunk, which may be a short final chunk.
|
|
67
|
+
"""
|
|
68
|
+
chunk_ids = rows[:, 0]
|
|
69
|
+
within = rows[:, 1]
|
|
70
|
+
uniq = np.unique(chunk_ids)
|
|
71
|
+
|
|
72
|
+
out: dict[str, list[np.ndarray]] = {v: [] for v in variables}
|
|
73
|
+
idx_pieces: list[np.ndarray] = []
|
|
74
|
+
for cid in uniq:
|
|
75
|
+
w = within[chunk_ids == cid]
|
|
76
|
+
idx_pieces.append(cid * sample_chunk_size + w)
|
|
77
|
+
for var in variables:
|
|
78
|
+
out[var].append(self._chunks[(var, int(cid))].data[w])
|
|
79
|
+
|
|
80
|
+
arrays = {var: np.concatenate(pieces, axis=0) for var, pieces in out.items()}
|
|
81
|
+
return Batch(arrays=arrays, sample_indices=np.concatenate(idx_pieces))
|
|
82
|
+
|
|
83
|
+
def evict_drained(self, still_needed: set[tuple[str, int]]) -> int:
|
|
84
|
+
"""Drop chunks no longer referenced by any pending draw. Returns count."""
|
|
85
|
+
drop = [k for k in self._chunks if k not in still_needed]
|
|
86
|
+
for k in drop:
|
|
87
|
+
del self._chunks[k]
|
|
88
|
+
return len(drop)
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""Async IO driver: the obstore win.
|
|
2
|
+
|
|
3
|
+
This is where insitubatch *stands on* solved cloud IO rather than reinventing
|
|
4
|
+
it. A single dedicated asyncio event loop (one OS thread) issues many concurrent
|
|
5
|
+
chunk reads against a zarr v3 store -- ideally the obstore-backed store, whose
|
|
6
|
+
``get_ranges_async`` coalesces concurrent range requests in a single coroutine
|
|
7
|
+
and saturates the NIC without spawning Python threads per request.
|
|
8
|
+
|
|
9
|
+
Design rules (DESIGN.md, "the inversion"):
|
|
10
|
+
* Parallelism lives *here*, in the event loop's concurrency, NOT in
|
|
11
|
+
torch.DataLoader worker processes. The torch surface runs with
|
|
12
|
+
``num_workers=0``.
|
|
13
|
+
* A bounded in-flight window (semaphore of ``max_inflight`` chunks) caps memory
|
|
14
|
+
at O(in-flight chunks), independent of batch size.
|
|
15
|
+
* Decode releases the GIL (numcodecs C codecs do) so decode overlaps IO. The
|
|
16
|
+
Python hot path stays O(reads); never decode per-sample in Python.
|
|
17
|
+
|
|
18
|
+
Status: SKELETON. The async store wiring is stubbed at the marked TODOs so the
|
|
19
|
+
module imports and the control flow is reviewable without a live store.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import asyncio
|
|
25
|
+
import queue
|
|
26
|
+
import threading
|
|
27
|
+
from collections.abc import Iterator
|
|
28
|
+
from dataclasses import dataclass
|
|
29
|
+
|
|
30
|
+
import numpy as np
|
|
31
|
+
import zarr.api.asynchronous as za
|
|
32
|
+
|
|
33
|
+
from .plan import ReadPlan
|
|
34
|
+
from .store import store_from_url
|
|
35
|
+
from .types import ArrayGeometry, ChunkRead, DecodedChunk
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass(slots=True)
|
|
39
|
+
class IOConfig:
|
|
40
|
+
max_inflight: int = 16
|
|
41
|
+
"""Upper bound on chunks in flight. Memory ~= max_inflight * chunk_nbytes."""
|
|
42
|
+
|
|
43
|
+
decode_threads: int = 4
|
|
44
|
+
"""Worker threads for GIL-releasing decode, overlapped with IO."""
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class AsyncChunkReader:
    """Owns one asyncio event loop on a background thread and fans out reads.

    The public API is deliberately *synchronous and iterator-shaped* so the rest
    of the engine (buffer, torch surface) needs no async knowledge -- the loop is
    an implementation detail hidden behind a thread-safe queue.

    Use as a context manager (or call :meth:`close`) so the loop thread is
    stopped and the loop's OS resources are released.
    """

    def __init__(
        self,
        store_url: str,
        geometries: dict[str, ArrayGeometry],
        config: IOConfig | None = None,
        **store_kwargs: object,
    ) -> None:
        """Start the IO loop thread and wait until it is running.

        Parameters
        ----------
        store_url:
            Passed to ``store_from_url`` the first time a plan is read.
        geometries:
            Variable name -> :class:`ArrayGeometry`; one array is opened per key.
        config:
            Concurrency settings; defaults to ``IOConfig()``.
        **store_kwargs:
            Forwarded verbatim to ``store_from_url``.

        Raises
        ------
        ValueError
            If ``config.max_inflight`` is smaller than 1.
        """
        self._url = store_url
        self._store_kwargs = store_kwargs
        self._geometries = geometries
        self._config = config or IOConfig()
        # Validate before the thread starts: a negative value would make the
        # Semaphore constructor raise inside the loop thread (so ``_ready`` is
        # never set and this constructor blocks forever), and zero would make
        # every read wait on the semaphore indefinitely.
        if self._config.max_inflight < 1:
            raise ValueError(f"max_inflight must be >= 1, got {self._config.max_inflight}")
        self._arrays: dict[str, za.AsyncArray] = {}  # opened lazily on the loop
        self._loop = asyncio.new_event_loop()
        self._sem: asyncio.Semaphore | None = None  # created on the loop bootstrap
        self._open_lock: asyncio.Lock | None = None
        self._ready = threading.Event()
        self._thread = threading.Thread(target=self._run_loop, daemon=True, name="insitu-io")
        self._thread.start()
        self._ready.wait()  # don't return until the loop + primitives exist

    # -- loop lifecycle -----------------------------------------------------

    def _run_loop(self) -> None:
        """Thread target: create the loop-side primitives and run the loop."""
        asyncio.set_event_loop(self._loop)
        self._sem = asyncio.Semaphore(self._config.max_inflight)
        self._open_lock = asyncio.Lock()
        # Signal readiness from inside the loop so it is known to be running.
        self._loop.call_soon(self._ready.set)
        self._loop.run_forever()

    def close(self) -> None:
        """Stop the loop thread and close the event loop. Safe to call twice."""
        if self._loop.is_closed():
            return
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join(timeout=5)
        # Closing releases the loop's selector and self-pipe. Only do so once
        # the thread has really exited; closing a running loop raises.
        if not self._thread.is_alive():
            self._loop.close()

    def __enter__(self) -> AsyncChunkReader:
        return self

    def __exit__(self, *exc: object) -> None:
        self.close()

    # -- public, synchronous surface ---------------------------------------

    _SENTINEL = object()

    def read_plan(self, plan: ReadPlan) -> Iterator[DecodedChunk]:
        """Fetch every read in ``plan`` concurrently; yield decoded chunks ASAP.

        Yields in completion order (not plan order) so a slow chunk never stalls
        the others -- the buffer downstream is responsible for reordering/gather.

        The bridge uses a thread-safe ``queue.Queue`` and a done-callback so the
        sentinel is *always* delivered -- even if the driver raises -- and any
        exception is re-raised on the caller's thread rather than deadlocking it.

        If the consumer stops iterating early (or an error is re-raised), the
        driver is cancelled so no further chunks are fetched into a queue that
        nobody reads.

        Raises
        ------
        RuntimeError
            If the reader has already been closed.
        Exception
            Whatever the driver raised, re-raised after the chunks that were
            queued before it.
        """
        if self._loop.is_closed():
            # Without this the coroutine would be handed to a dead loop.
            raise RuntimeError("AsyncChunkReader is closed")

        out_q: queue.Queue = queue.Queue()

        def _on_done(fut: object) -> None:
            try:
                fut.result()  # type: ignore[attr-defined]  # surface driver errors
            except Exception as exc:  # noqa: BLE001 - forwarded to the consumer
                out_q.put(exc)
            finally:
                out_q.put(self._SENTINEL)

        fut = asyncio.run_coroutine_threadsafe(self._drive(plan, out_q), self._loop)
        fut.add_done_callback(_on_done)

        try:
            while True:
                item = out_q.get()
                if item is self._SENTINEL:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            # No-op when the driver already finished; otherwise stops the
            # remaining reads (early ``break`` / generator close).
            fut.cancel()

    # -- async internals ----------------------------------------------------

    async def _drive(self, plan: ReadPlan, out_q: queue.Queue) -> None:
        """Run every read of ``plan`` with at most ``max_inflight`` in flight."""
        sem = self._sem
        assert sem is not None  # guaranteed by _ready.wait() in __init__
        await self._ensure_arrays()

        async def one(read: ChunkRead) -> None:
            async with sem:  # caps concurrent fetches at max_inflight
                decoded = await self._fetch_and_decode(read)
                out_q.put(decoded)  # stdlib queue: thread-safe, non-blocking

        tasks = [asyncio.ensure_future(one(r)) for r in plan.reads]
        try:
            await asyncio.gather(*tasks)
        except BaseException:
            # gather() does not stop siblings when one read fails; cancel them
            # so a failed plan stops issuing IO. The original error propagates.
            for task in tasks:
                task.cancel()
            raise

    async def _ensure_arrays(self) -> None:
        """Open one AsyncArray per variable, once, sharing the store."""
        if self._arrays:
            return
        assert self._open_lock is not None
        async with self._open_lock:
            if self._arrays:  # double-checked: another coroutine may have won
                return
            store = store_from_url(self._url, **self._store_kwargs)  # type: ignore[arg-type]
            # Build locally and publish in one assignment: if any open fails,
            # ``self._arrays`` stays empty and the next call retries instead of
            # treating a partial dict as fully opened.
            opened: dict[str, za.AsyncArray] = {}
            for name in self._geometries:
                opened[name] = await za.open_array(store=store, path=name, mode="r")
            self._arrays = opened

    async def _fetch_and_decode(self, read: ChunkRead) -> DecodedChunk:
        """Fetch + decode one chunk via the zarr v3 async codec pipeline.

        The selection is exactly one chunk along the sample axis, full on the
        inner dims (the v1 sample-geometry contract). zarr fans the underlying
        byte-range reads out through obstore and runs the decode pipeline; for
        single-chunk inner dims this touches exactly one stored chunk.

        TODO(perf): decode currently runs inside zarr's pipeline on the loop. If
        it shows up as a GIL bottleneck, move it to a thread pool
        (numcodecs C codecs release the GIL) -- the bounded-fan-out structure
        here already isolates that change to this method.
        """
        geom = self._geometries[read.array]
        arr = self._arrays[read.array]
        samples = geom.samples_in_chunk(read.chunk_index)
        # One slice on the sample axis, then a full slice per inner dimension.
        selection = (slice(samples.start, samples.stop), *(slice(None) for _ in geom.inner_shape))
        block = await arr.getitem(selection)
        return DecodedChunk(read=read, data=np.asarray(block), sample_offset=samples.start)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Read planning: samples -> deduplicated chunk reads.
|
|
2
|
+
|
|
3
|
+
This is the crux abstraction. Given the samples required for a window of the
|
|
4
|
+
epoch and the array geometries, produce the *minimal* set of chunk reads plus a
|
|
5
|
+
gather map describing where each sample lives once those chunks are decoded.
|
|
6
|
+
|
|
7
|
+
Why this matters (DESIGN.md, "the spectrum"):
|
|
8
|
+
- Fat chunks: many samples share one chunk -> dedup collapses N samples to 1
|
|
9
|
+
read; the shared decoded chunk is gathered N times. This is the shared-cache
|
|
10
|
+
win that the classic per-worker DataLoader cannot get.
|
|
11
|
+
- GRIB-per-timestep: one sample per chunk -> no dedup possible, but the plan
|
|
12
|
+
still drives a single wide async fan-out (B samples == B concurrent reads),
|
|
13
|
+
which is exactly where obstore earns its keep.
|
|
14
|
+
|
|
15
|
+
The Python hot path here is O(reads), never O(samples) once gathered, which is
|
|
16
|
+
the constraint David's S3 benchmark imposed (Python per-chunk overhead bounds
|
|
17
|
+
throughput; never loop per-sample in Python).
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
from collections.abc import Sequence
|
|
23
|
+
from dataclasses import dataclass
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
from .types import ArrayGeometry, ChunkRead
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(slots=True)
|
|
31
|
+
class Gather:
|
|
32
|
+
"""Where one sample lives within a decoded chunk.
|
|
33
|
+
|
|
34
|
+
``read_index`` indexes into ``ReadPlan.reads``; ``within`` is the offset of
|
|
35
|
+
the sample inside that chunk's decoded array along the sample axis.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
read_index: int
|
|
39
|
+
within: int
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(slots=True)
|
|
43
|
+
class ReadPlan:
|
|
44
|
+
"""A deduplicated batch of chunk reads plus the gather map back to samples.
|
|
45
|
+
|
|
46
|
+
One ``ReadPlan`` typically covers enough samples to (a) saturate the async
|
|
47
|
+
fan-out and (b) fill the shuffle-block buffer. ``reads`` is what the IO
|
|
48
|
+
driver fetches; ``gathers[v]`` reconstructs the requested samples for
|
|
49
|
+
variable ``v`` from the decoded chunks.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
reads: list[ChunkRead]
|
|
53
|
+
gathers: dict[str, list[Gather]]
|
|
54
|
+
sample_indices: np.ndarray # global sample indices, in requested order
|
|
55
|
+
|
|
56
|
+
@property
|
|
57
|
+
def n_reads(self) -> int:
|
|
58
|
+
return len(self.reads)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def build_read_plan(
    sample_indices: Sequence[int],
    geometries: dict[str, ArrayGeometry],
) -> ReadPlan:
    """Turn requested samples into unique chunk reads plus a gather map.

    Every variable is taken to share the sample axis (same length, chunking may
    differ). For a global index ``s`` the owning chunk of a variable is
    ``s // geom.sample_chunk_size``; a chunk wanted by several samples appears
    in ``reads`` only once, at the position of its first request.

    Parameters
    ----------
    sample_indices:
        Global sample-axis indices to fetch, in the order they should appear.
    geometries:
        Variable name -> :class:`ArrayGeometry`.

    Returns
    -------
    ReadPlan
        ``reads`` in first-seen order (variable by variable), one ``Gather`` per
        sample per variable, and the indices as an ``int64`` array.
    """
    wanted = np.asarray(sample_indices, dtype=np.int64)
    unique_reads: list[ChunkRead] = []
    position_of: dict[ChunkRead, int] = {}
    gather_map: dict[str, list[Gather]] = {name: [] for name in geometries}

    for name, geom in geometries.items():
        # Chunk id and in-chunk offset for all samples at once (vectorized).
        owning_chunk = wanted // geom.sample_chunk_size
        offset_in_chunk = wanted - owning_chunk * geom.sample_chunk_size
        pairs = zip(owning_chunk.tolist(), offset_in_chunk.tolist(), strict=True)
        for chunk_id, offset in pairs:
            key = ChunkRead(array=name, chunk_index=int(chunk_id))
            slot = position_of.get(key)
            if slot is None:
                # First request for this chunk: register a new read.
                slot = len(unique_reads)
                position_of[key] = slot
                unique_reads.append(key)
            gather_map[name].append(Gather(read_index=slot, within=int(offset)))

    return ReadPlan(reads=unique_reads, gathers=gather_map, sample_indices=wanted)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def dedup_ratio(plan: ReadPlan) -> float:
    """Return requested samples per chunk read, counted across variables.

    The numerator is ``len(plan.sample_indices)`` times the number of variables
    in ``plan.gathers`` (treated as 1 when there are none); the denominator is
    ``plan.n_reads``. A value of 1.0 means no sharing (one sample per chunk);
    larger values mean fatter chunks with more reuse. A plan with no reads
    yields 0.0.
    """
    samples_per_variable = len(plan.sample_indices)
    variable_count = max(len(plan.gathers), 1)
    requested = samples_per_variable * variable_count
    if plan.n_reads:
        return requested / plan.n_reads
    return 0.0
|
|
File without changes
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Approximate-global shuffle for chunk-aligned data.
|
|
2
|
+
|
|
3
|
+
True global shuffle is incompatible with chunk-aligned, low-copy reads: it would
|
|
4
|
+
demand a random chunk per sample. The compromise (DESIGN.md, "shuffle"), adapted
|
|
5
|
+
from MosaicML Streaming's shuffle-block algorithms (py1e / py1br), is two-level:
|
|
6
|
+
|
|
7
|
+
1. **Chunk permutation** -- shuffle the *order chunks are scheduled* each epoch.
|
|
8
|
+
2. **Shuffle-block buffer** -- hold samples from a window of B chunks and draw
|
|
9
|
+
batches across the whole window, so samples from different chunks interleave.
|
|
10
|
+
|
|
11
|
+
Setting the block span B >= ~10x the samples-per-chunk yields shuffle quality
|
|
12
|
+
close to global, at memory cost O(B chunks). B is the single quality<->memory
|
|
13
|
+
knob. This module owns the *index math*; buffer.py owns the residency.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def chunk_permutation(chunk_ids: np.ndarray, *, seed: int, epoch: int) -> np.ndarray:
    """Return ``chunk_ids`` in a reproducible shuffled order for one epoch.

    The generator is seeded from the ``(seed, epoch)`` pair and nothing else
    (no world size, no worker count), so the same pair always yields the same
    order -- the "canonical" property from MosaicML that makes a run
    reproducible and resumable across hardware. The input is not modified.
    """
    epoch_key = (seed, epoch)
    generator = np.random.default_rng(epoch_key)
    shuffled = generator.permutation(chunk_ids)
    return shuffled
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def block_shuffled_order(
    chunk_ids: np.ndarray,
    samples_per_chunk: int,
    *,
    block_chunks: int,
    seed: int,
    epoch: int,
) -> np.ndarray:
    """Produce a globally-ordered list of (chunk_id, within) draws.

    Emulates the shuffle-block draw order the live buffer will realise, useful
    for the quality harness and for deterministic single-process iteration.

    The chunk ids are first permuted for ``(seed, epoch)``, then walked in
    consecutive blocks of ``block_chunks`` chunks. Within each block every
    ``(chunk, within)`` pair is materialised -- each chunk is treated as
    holding exactly ``samples_per_chunk`` samples -- and the pairs are shuffled
    together, so samples only interleave with others from the same block.

    Parameters
    ----------
    chunk_ids:
        Chunk indices to schedule. May be empty.
    samples_per_chunk:
        Number of ``within`` offsets (``0 .. samples_per_chunk - 1``) emitted
        for every chunk.
    block_chunks:
        How many chunks share one shuffle block. Must be >= 1.
    seed, epoch:
        Determinism key; the same values always give the same order.

    Returns
    -------
    Array of shape ``(n_samples, 2)`` of ``[chunk_id, within]`` rows; an empty
    ``(0, 2)`` int64 array when ``chunk_ids`` is empty.

    Raises
    ------
    ValueError
        If ``block_chunks`` is smaller than 1.
    """
    if block_chunks < 1:
        raise ValueError(f"block_chunks must be >= 1, got {block_chunks}")
    if len(chunk_ids) == 0:
        # Nothing to schedule (e.g. an empty split): np.concatenate would
        # reject an empty list, so hand back a well-shaped empty order.
        return np.empty((0, 2), dtype=np.int64)

    perm = chunk_permutation(chunk_ids, seed=seed, epoch=epoch)
    # Extra key component so the within-block shuffle draws from a different
    # random stream than the chunk permutation for the same (seed, epoch).
    rng = np.random.default_rng((seed, epoch, 7919))

    rows: list[np.ndarray] = []
    for start in range(0, len(perm), block_chunks):
        block = perm[start : start + block_chunks]
        # Materialise every (chunk, within) pair in this block, then shuffle.
        cc = np.repeat(block, samples_per_chunk)
        ww = np.tile(np.arange(samples_per_chunk), len(block))
        pairs = np.stack([cc, ww], axis=1)
        rng.shuffle(pairs)  # in-place, along axis 0
        rows.append(pairs)
    return np.concatenate(rows, axis=0)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def shuffle_quality(order: np.ndarray, samples_per_chunk: int) -> float:
    """A 0..1 score for how well an emitted order mixes the source.

    Heuristic: the mean absolute *source-rank* gap between consecutive emitted
    samples, normalised by the gap a perfect global shuffle would give. 1.0 ~=
    global; values near 0 mean adjacent samples still come out near each other
    (poor mixing). Cheap to compute, good enough to tune ``block_chunks``.

    Parameters
    ----------
    order:
        Array of shape ``(n, 2)`` of ``[chunk_id, within]`` rows, in emitted
        order.
    samples_per_chunk:
        Stride used to turn a ``[chunk_id, within]`` row into a source position
        (``chunk_id * samples_per_chunk + within``).

    Returns
    -------
    The score, capped at 1.0; 0.0 when fewer than two samples are given.
    """
    order = np.asarray(order)
    n = len(order)
    if n < 2:
        return 0.0

    # Widen before multiplying so a narrow integer dtype cannot overflow.
    position = order[:, 0].astype(np.int64) * samples_per_chunk + order[:, 1].astype(np.int64)
    # Convert source positions to dense ranks 0..k-1. The normaliser below
    # assumes a permutation of 0..n-1, so raw positions over sparse chunk ids
    # (e.g. a non-contiguous split) would inflate every gap and saturate the
    # score even for a completely unshuffled order.
    _, inverse = np.unique(position, return_inverse=True)
    source_rank = inverse.reshape(-1).astype(np.int64)

    gaps = np.abs(np.diff(source_rank))
    # Expected mean gap of a uniform random permutation of 0..n-1 is ~n/3.
    expected = n / 3.0
    return float(min(gaps.mean() / expected, 1.0))
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""Torch handoff surface.
|
|
2
|
+
|
|
3
|
+
Ties the pieces together and exposes them to PyTorch *without* using the classic
|
|
4
|
+
DataLoader worker model. Parallelism lives in :class:`AsyncChunkReader`'s event
|
|
5
|
+
loop, so the recommended configuration is::
|
|
6
|
+
|
|
7
|
+
loader = DataLoader(InSituDataset(...), batch_size=None, num_workers=0)
|
|
8
|
+
|
|
9
|
+
``batch_size=None`` because the dataset already yields assembled batches;
|
|
10
|
+
``num_workers=0`` because forking workers would re-introduce exactly the
|
|
11
|
+
redundant-read / nested-parallelism problems we set out to avoid.
|
|
12
|
+
|
|
13
|
+
torch (and torchdata.nodes) are optional imports so the core engine stays
|
|
14
|
+
framework-agnostic and importable on a box without torch installed.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from collections.abc import Iterator
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
from .buffer import BufferConfig, ShuffleBlockBuffer
|
|
24
|
+
from .io import AsyncChunkReader, IOConfig
|
|
25
|
+
from .plan import build_read_plan
|
|
26
|
+
from .shuffle import block_shuffled_order
|
|
27
|
+
from .split import SplitManifest
|
|
28
|
+
from .store import open_geometries
|
|
29
|
+
from .types import ArrayGeometry, Batch, SplitName
|
|
30
|
+
|
|
31
|
+
try: # optional torch surface
|
|
32
|
+
from torch.utils.data import IterableDataset
|
|
33
|
+
|
|
34
|
+
_HAS_TORCH = True
|
|
35
|
+
except ImportError: # pragma: no cover - exercised on torch-less installs
|
|
36
|
+
IterableDataset = object # type: ignore[assignment,misc]
|
|
37
|
+
_HAS_TORCH = False
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class InSituDataset(IterableDataset):  # type: ignore[misc]
    """An IterableDataset that streams shuffled batches from a Zarr archive.

    One epoch = permute the split's chunks -> walk shuffle-blocks -> for each
    block, async-fetch its chunks, fill the buffer, emit coalesced batches.

    Parameters
    ----------
    store_url:
        URL of the archive; handed to the geometry introspection and to the
        chunk reader.
    manifest:
        Split manifest; ``manifest.chunks[split.value]`` lists the chunk
        indices this dataset iterates.
    geometries:
        Per-variable geometry. When ``None`` it is introspected from
        ``store_url`` using ``store_kwargs``.
    split:
        Which split of ``manifest`` to stream.
    batch_size:
        Number of draws per emitted batch (the last batch may be shorter).
    block_chunks:
        Number of chunks per shuffle block.
    max_inflight:
        Passed to ``IOConfig``.
    seed:
        Base seed for the per-epoch shuffle.
    to_tensor:
        Emit ``{name: torch.Tensor}`` dicts instead of ``Batch`` objects.
        Silently disabled when torch is not importable.
    **store_kwargs:
        Extra store options, forwarded both to geometry introspection and to
        the chunk reader.
    """

    def __init__(
        self,
        store_url: str,
        manifest: SplitManifest,
        geometries: dict[str, ArrayGeometry] | None = None,
        split: SplitName = SplitName.TRAIN,
        *,
        batch_size: int = 32,
        block_chunks: int = 16,
        max_inflight: int = 16,
        seed: int = 0,
        to_tensor: bool = True,
        **store_kwargs: object,
    ) -> None:
        self.store_url = store_url
        self.store_kwargs = store_kwargs
        if geometries is None:
            # Introspect with the same store options the reader will get;
            # otherwise a store that needs them (region, credentials, ...)
            # could be opened for reads but not for metadata.
            geometries = open_geometries(store_url, **store_kwargs)
        self.geometries = geometries
        self.manifest = manifest
        self.split = split
        self.variables = list(self.geometries)
        self.io_config = IOConfig(max_inflight=max_inflight)
        self.buffer_config = BufferConfig(block_chunks=block_chunks, batch_size=batch_size)
        self.seed = seed
        self.to_tensor = to_tensor and _HAS_TORCH
        self._epoch = 0

    def set_epoch(self, epoch: int) -> None:
        """Call from the training loop so each epoch reshuffles deterministically."""
        self._epoch = epoch

    def __iter__(self) -> Iterator[Batch | dict]:
        """Yield one epoch of batches in block-shuffled order.

        Yields nothing when the selected split has no chunks.
        """
        chunk_ids = np.asarray(self.manifest.chunks[self.split.value], dtype=np.int64)
        if chunk_ids.size == 0:
            # Empty split: there is no draw order to build and nothing to read.
            return

        # The first variable's sample-axis chunk size is used for every
        # variable. NOTE(review): this presumes all variables share the same
        # sample-axis chunking, and that every chunk is full -- confirm how a
        # short final chunk is handled downstream.
        geom = self.geometries[self.variables[0]]
        spc = geom.sample_chunk_size

        order = block_shuffled_order(
            chunk_ids,
            spc,
            block_chunks=self.buffer_config.block_chunks,
            seed=self.seed,
            epoch=self._epoch,
        )

        with AsyncChunkReader(
            self.store_url, self.geometries, self.io_config, **self.store_kwargs
        ) as reader:
            buf = ShuffleBlockBuffer(self.buffer_config, seed=self.seed)
            bs = self.buffer_config.batch_size

            # Walk the emitted draw order in batch-sized windows. For each window
            # we ensure its chunks are resident (async fan-out), then gather.
            for start in range(0, len(order), bs):
                rows = order[start : start + bs]
                # Distinct chunks touched by this window, computed once rather
                # than once per variable.
                window_chunks = [int(c) for c in np.unique(rows[:, 0])]
                needed = {(v, c) for v in self.variables for c in window_chunks}
                # One representative sample index (the chunk's first sample)
                # per chunk that is not resident yet; the read plan maps these
                # back to chunk reads.
                # NOTE(review): reaches into the buffer's private ``_chunks``;
                # presumably keyed by (variable, chunk index) -- confirm.
                missing_samples = {c * spc for (v, c) in needed if (v, c) not in buf._chunks}
                if missing_samples:
                    plan = build_read_plan(sorted(missing_samples), self.geometries)
                    for decoded in reader.read_plan(plan):
                        buf.add(decoded)

                batch = buf.gather_batch(rows, self.variables, spc)
                yield self._maybe_tensor(batch)
                buf.evict_drained(needed)

    def _maybe_tensor(self, batch: Batch) -> Batch | dict:
        """Return ``batch`` unchanged, or its arrays as torch tensors when enabled."""
        if not self.to_tensor:
            return batch
        # Imported lazily so the module stays importable without torch.
        import torch

        return {k: torch.from_numpy(v) for k, v in batch.arrays.items()}
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Chunk-aligned train/val/test splits.
|
|
2
|
+
|
|
3
|
+
Splits are done *ahead of time* and at *chunk granularity* along the sample
|
|
4
|
+
axis. Two reasons (DESIGN.md, "splits"):
|
|
5
|
+
|
|
6
|
+
1. Leakage: splitting mid-chunk would scatter temporally adjacent, highly
|
|
7
|
+
autocorrelated samples across train and val. Chunk-aligned boundaries keep
|
|
8
|
+
a contiguous block of time in a single split.
|
|
9
|
+
2. Zero-copy: a split that respects chunk boundaries means every read serves
|
|
10
|
+
exactly one split, so the engine never decodes a chunk and throws half of
|
|
11
|
+
it away.
|
|
12
|
+
|
|
13
|
+
The manifest is a plain, serializable record of which chunk indices belong to
|
|
14
|
+
which split, so a run is reproducible and shareable.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import json
|
|
20
|
+
from dataclasses import asdict, dataclass
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
from .types import ArrayGeometry, SplitName
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(slots=True)
|
|
29
|
+
class SplitManifest:
|
|
30
|
+
"""Which sample-axis chunk indices belong to each split."""
|
|
31
|
+
|
|
32
|
+
n_chunks: int
|
|
33
|
+
sample_chunk_size: int
|
|
34
|
+
n_samples: int
|
|
35
|
+
chunks: dict[str, list[int]] # SplitName.value -> sorted chunk indices
|
|
36
|
+
seed: int
|
|
37
|
+
|
|
38
|
+
def sample_indices(self, split: SplitName, geom: ArrayGeometry) -> np.ndarray:
|
|
39
|
+
"""Expand a split's chunks into the global sample indices they contain."""
|
|
40
|
+
out: list[int] = []
|
|
41
|
+
for c in self.chunks[split.value]:
|
|
42
|
+
out.extend(geom.samples_in_chunk(c))
|
|
43
|
+
return np.asarray(out, dtype=np.int64)
|
|
44
|
+
|
|
45
|
+
def to_json(self, path: str | Path) -> None:
|
|
46
|
+
Path(path).write_text(json.dumps(asdict(self), indent=2))
|
|
47
|
+
|
|
48
|
+
@classmethod
|
|
49
|
+
def from_json(cls, path: str | Path) -> SplitManifest:
|
|
50
|
+
return cls(**json.loads(Path(path).read_text()))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def split_by_chunk(
    geom: ArrayGeometry,
    *,
    fractions: tuple[float, float, float] = (0.8, 0.1, 0.1),
    seed: int = 0,
    contiguous: bool = True,
) -> SplitManifest:
    """Partition a variable's sample-axis chunks into train/val/test.

    Train and val sizes are the rounded fraction of the chunk count; test
    receives whatever chunks remain, so rounding never drops a chunk.

    Parameters
    ----------
    geom:
        Geometry of the variable whose sample-axis chunks are split.
    fractions:
        (train, val, test) fractions of *chunks* (not samples). Must be
        non-negative and sum to ~1.
    seed:
        Seed for the chunk shuffle when ``contiguous`` is False; recorded in
        the manifest either way.
    contiguous:
        If True (default), assign contiguous blocks of chunks to each split --
        the safest choice for time series, where a randomly interleaved split
        still risks leakage through autocorrelation across chunk boundaries. If
        False, chunks are shuffled before partitioning (acceptable when samples
        are exchangeable, e.g. independent scenes).

    Returns
    -------
    A ``SplitManifest`` whose per-split chunk lists are sorted.

    Raises
    ------
    ValueError
        If the fractions do not sum to 1.0 or any fraction is negative.
    """
    if abs(sum(fractions) - 1.0) > 1e-6:
        raise ValueError(f"fractions must sum to 1.0, got {fractions} -> {sum(fractions)}")
    if any(f < 0 for f in fractions):
        # A negative fraction becomes a negative slice bound below, which
        # silently selects chunks from the wrong end instead of failing.
        raise ValueError(f"fractions must be non-negative, got {fractions}")

    n = geom.n_chunks
    order = np.arange(n)
    if not contiguous:
        order = np.random.default_rng(seed).permutation(n)

    n_train = int(round(fractions[0] * n))
    n_val = int(round(fractions[1] * n))
    train = sorted(order[:n_train].tolist())
    val = sorted(order[n_train : n_train + n_val].tolist())
    # Test takes the remainder rather than its own rounded count.
    test = sorted(order[n_train + n_val :].tolist())

    return SplitManifest(
        n_chunks=n,
        sample_chunk_size=geom.sample_chunk_size,
        n_samples=geom.n_samples,
        chunks={
            SplitName.TRAIN.value: train,
            SplitName.VAL.value: val,
            SplitName.TEST.value: test,
        },
        seed=seed,
    )
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""Storage shim: one URL, any backend.
|
|
2
|
+
|
|
3
|
+
The whole local-now / cloud-later story is a single function. ``obstore`` already
|
|
4
|
+
dispatches on URL scheme (``file://``, ``s3://``, ``gs://``, ``az://``,
|
|
5
|
+
``memory://``) and ``zarr.storage.ObjectStore`` wraps it for the async zarr path.
|
|
6
|
+
So Phase 0 (local ``file://``) and Phase 1 (``s3://...``) differ only in the URL --
|
|
7
|
+
no hot-path code change, and the read path stays pure Rust (no fsspec layer).
|
|
8
|
+
|
|
9
|
+
We deliberately do *not* route through fsspec / universal_pathlib on the read
|
|
10
|
+
hot path: the entire thesis is that obstore wins by bypassing the fsspec/s3fs
|
|
11
|
+
Python layer. (obstore.fsspec exists if path-style ergonomics are ever wanted
|
|
12
|
+
off the hot path -- but not here.)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
from urllib.parse import urlparse
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import obstore
|
|
22
|
+
import zarr
|
|
23
|
+
import zarr.storage
|
|
24
|
+
|
|
25
|
+
from .types import ArrayGeometry
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def store_from_url(
    url: str, *, read_only: bool = True, **kwargs: object
) -> zarr.storage.ObjectStore:
    """Build a zarr ``ObjectStore`` backed by the obstore store for ``url``.

    Use ``file:///abs/path.zarr`` for local data and ``s3://bucket/path.zarr``
    for cloud data. Any extra ``kwargs`` are forwarded untouched to
    ``obstore.store.from_url`` (region, credentials, client options, ...);
    ``read_only`` is forwarded to the zarr wrapper.
    """
    backend = obstore.store.from_url(url, **kwargs)
    wrapped = zarr.storage.ObjectStore(backend, read_only=read_only)
    return wrapped
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def ensure_local_dir(url: str) -> str:
    """For a ``file://`` URL, create the target directory so writes can land.

    obstore's LocalStore will not create the prefix for you. A URL with no
    scheme is treated as a plain filesystem path. No-op for non-file schemes
    and for an empty path.

    Parameters
    ----------
    url:
        A ``file://`` URL, a bare path, or any other URL.

    Returns
    -------
    The URL unchanged, for chaining.
    """
    # Local import: only ``urlparse`` is imported at file level.
    from urllib.parse import unquote

    parsed = urlparse(url)
    if parsed.scheme == "file":
        # URL paths are percent-encoded; decode so ``file:///a%20b`` creates
        # the directory "a b" rather than a literal "a%20b".
        target = unquote(parsed.path)
    elif parsed.scheme == "":
        # A bare path is not URL-encoded; use it verbatim.
        target = parsed.path
    else:
        return url

    if target:  # os.makedirs("") raises; nothing to create for an empty path
        os.makedirs(target, exist_ok=True)
    return url
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def open_geometries(
    url: str,
    variables: list[str] | None = None,
    **kwargs: object,
) -> dict[str, ArrayGeometry]:
    """Read the zarr group at ``url`` and describe its arrays as geometries.

    This is what lets ``InSituDataset`` be built from a URL alone: shape,
    chunks and dtype come from the array metadata instead of being specified
    by hand.

    ``variables`` restricts the result to the named arrays; when ``None``,
    every array the group reports is included. ``kwargs`` are forwarded to
    ``store_from_url``. The result maps each array name to its
    ``ArrayGeometry``.
    """
    root = zarr.open_group(store=store_from_url(url, **kwargs), mode="r")
    if variables is None:
        selected = [key for key, _ in root.arrays()]
    else:
        selected = variables

    geometries: dict[str, ArrayGeometry] = {}
    for key in selected:
        array = root[key]
        geometries[key] = ArrayGeometry(
            name=key,
            shape=tuple(array.shape),
            chunks=tuple(array.chunks),
            dtype=np.dtype(array.dtype),
        )
    return geometries
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""Core data types shared across the insitubatch engine.
|
|
2
|
+
|
|
3
|
+
The central design choice (see DESIGN.md): the unit of work is neither the
|
|
4
|
+
*sample* nor the *chunk* in isolation, but a **read plan** that maps the samples
|
|
5
|
+
required for a step to a *deduplicated* set of chunk reads. This lets the same
|
|
6
|
+
scheduler serve the whole spectrum from fat chunks (many samples per chunk,
|
|
7
|
+
shared-cache wins) to the degenerate GRIB-per-timestep case (one sample per
|
|
8
|
+
chunk, async fan-out is everything).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from enum import StrEnum
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SplitName(StrEnum):
    """Names of the dataset splits.

    Each member's value is its lowercase string; those values are the keys
    used in ``SplitManifest.chunks``.
    """

    TRAIN = "train"
    VAL = "val"
    TEST = "test"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True, slots=True)
|
|
26
|
+
class ArrayGeometry:
|
|
27
|
+
"""The minimal geometry the engine needs about one zarr array.
|
|
28
|
+
|
|
29
|
+
We only model the *sample axis* (the outer dimension, axis 0 by convention:
|
|
30
|
+
time for ERA5/HRRR) explicitly, because that is the axis we split, shuffle,
|
|
31
|
+
and batch along. The trailing dims are carried opaquely as ``inner_shape``
|
|
32
|
+
and are kept contiguous to preserve partial zero-copy.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
name: str
|
|
36
|
+
shape: tuple[int, ...]
|
|
37
|
+
chunks: tuple[int, ...]
|
|
38
|
+
dtype: np.dtype
|
|
39
|
+
|
|
40
|
+
@property
|
|
41
|
+
def n_samples(self) -> int:
|
|
42
|
+
"""Length of the sample (outer) axis."""
|
|
43
|
+
return self.shape[0]
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def sample_chunk_size(self) -> int:
|
|
47
|
+
"""How many samples live in one chunk along the sample axis."""
|
|
48
|
+
return self.chunks[0]
|
|
49
|
+
|
|
50
|
+
@property
|
|
51
|
+
def inner_shape(self) -> tuple[int, ...]:
|
|
52
|
+
"""Shape of a single sample (everything past the sample axis)."""
|
|
53
|
+
return self.shape[1:]
|
|
54
|
+
|
|
55
|
+
@property
|
|
56
|
+
def n_chunks(self) -> int:
|
|
57
|
+
"""Number of chunks along the sample axis."""
|
|
58
|
+
return -(-self.n_samples // self.sample_chunk_size) # ceil div
|
|
59
|
+
|
|
60
|
+
def chunk_of(self, sample_index: int) -> int:
|
|
61
|
+
"""Which sample-axis chunk a given sample index falls in."""
|
|
62
|
+
return sample_index // self.sample_chunk_size
|
|
63
|
+
|
|
64
|
+
def samples_in_chunk(self, chunk_index: int) -> range:
|
|
65
|
+
"""The half-open range of global sample indices in ``chunk_index``."""
|
|
66
|
+
start = chunk_index * self.sample_chunk_size
|
|
67
|
+
stop = min(start + self.sample_chunk_size, self.n_samples)
|
|
68
|
+
return range(start, stop)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass(frozen=True, slots=True)
class ChunkRead:
    """A single chunk to fetch, addressed along the sample axis.

    ``array`` names which zarr array (variable) this read belongs to; a training
    sample that concatenates several variables produces one ``ChunkRead`` per
    variable that must be co-scheduled.

    Frozen, so instances are hashable and can serve as dictionary keys.
    """

    array: str  # name of the zarr array (variable) to read from
    chunk_index: int  # index of the chunk along the sample axis
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass(slots=True)
class DecodedChunk:
    """A decoded, in-memory chunk, keyed by its read.

    ``data`` has shape ``(n_samples_in_chunk, *inner_shape)``. The buffer holds a
    bounded number of these; memory overhead is O(in-flight chunks), independent
    of batch size.
    """

    read: ChunkRead  # the read request this chunk satisfies
    data: np.ndarray  # decoded samples, sample axis first
    sample_offset: int  # global sample index of data[0]
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass(slots=True)
class Batch:
    """A model-ready batch.

    ``arrays`` maps variable name -> stacked array of shape ``(batch, *inner)``.
    ``sample_indices`` records provenance (which global samples, in order) for
    determinism checks and resumption.
    """

    arrays: dict[str, np.ndarray]  # variable name -> stacked batch array
    # Defaults to a fresh empty int64 array per instance (via default_factory)
    # when no provenance is supplied.
    sample_indices: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=np.int64))
|