insitubatch 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,44 @@
1
+ """insitubatch -- train in place on n-dimensional cloud tensors.
2
+
3
+ The loader-orchestration layer that sits on top of *already-solved* async cloud
4
+ IO (obstore / zarr v3 / icechunk): turns an existing Zarr archive into a
5
+ shuffled, split-aware, GPU-saturating PyTorch source with no reshard and a
6
+ Python hot path that scales with chunks, not samples.
7
+
8
+ See DESIGN.md for the full rationale.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from .buffer import BufferConfig, ShuffleBlockBuffer
14
+ from .io import AsyncChunkReader, IOConfig
15
+ from .plan import ReadPlan, build_read_plan, dedup_ratio
16
+ from .shuffle import block_shuffled_order, chunk_permutation, shuffle_quality
17
+ from .split import SplitManifest, split_by_chunk
18
+ from .store import ensure_local_dir, open_geometries, store_from_url
19
+ from .types import ArrayGeometry, Batch, ChunkRead, DecodedChunk, SplitName
20
+
21
+ __version__ = "0.0.1"
22
+
23
+ __all__ = [
24
+ "ArrayGeometry",
25
+ "AsyncChunkReader",
26
+ "Batch",
27
+ "BufferConfig",
28
+ "ChunkRead",
29
+ "DecodedChunk",
30
+ "IOConfig",
31
+ "ReadPlan",
32
+ "ShuffleBlockBuffer",
33
+ "SplitManifest",
34
+ "SplitName",
35
+ "block_shuffled_order",
36
+ "build_read_plan",
37
+ "chunk_permutation",
38
+ "dedup_ratio",
39
+ "ensure_local_dir",
40
+ "open_geometries",
41
+ "shuffle_quality",
42
+ "split_by_chunk",
43
+ "store_from_url",
44
+ ]
insitubatch/buffer.py ADDED
@@ -0,0 +1,88 @@
1
+ """Bounded shuffle-block buffer: residency + batch assembly.
2
+
3
+ Holds decoded chunks for a window of ``block_chunks`` chunks, draws shuffled
4
+ batches across that window, and evicts chunks once fully drained. This is the
5
+ memory-bounding component: peak residency is O(block_chunks), independent of the
6
+ number of samples per epoch or the batch size.
7
+
8
+ The batch assembly does a single coalesced gather per variable (one fancy-index
9
+ copy), never a Python per-sample loop -- the constraint David's S3 benchmark
10
+ imposed (Python per-chunk overhead bounds throughput).
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from collections import deque
16
+ from dataclasses import dataclass, field
17
+
18
+ import numpy as np
19
+
20
+ from .types import Batch, DecodedChunk
21
+
22
+
23
+ @dataclass(slots=True)
24
+ class BufferConfig:
25
+ block_chunks: int = 16
26
+ """Window size in chunks. Larger == better shuffle, more memory."""
27
+
28
+ batch_size: int = 32
29
+
30
+
31
+ @dataclass(slots=True)
32
+ class ShuffleBlockBuffer:
33
+ """Accumulates decoded chunks and emits shuffled, coalesced batches."""
34
+
35
+ config: BufferConfig
36
+ seed: int = 0
37
+ _chunks: dict[tuple[str, int], DecodedChunk] = field(default_factory=dict)
38
+ # Pending draws: rows of (array, chunk_index, within) flattened per variable.
39
+ _pending: deque[int] = field(default_factory=deque)
40
+
41
+ def add(self, chunk: DecodedChunk) -> None:
42
+ self._chunks[(chunk.read.array, chunk.read.chunk_index)] = chunk
43
+
44
+ def ready(self) -> bool:
45
+ """Enough buffered to safely emit a well-mixed batch."""
46
+ return len(self._chunks) >= self.config.block_chunks
47
+
48
+ def gather_batch(
49
+ self,
50
+ rows: np.ndarray,
51
+ variables: list[str],
52
+ sample_chunk_size: int,
53
+ ) -> Batch:
54
+ """Assemble one batch from ``rows`` of ``[chunk_id, within]`` draws.
55
+
56
+ ``rows`` are pre-shuffled draw coordinates (see shuffle.block_shuffled_order).
57
+ Draws are grouped by chunk so each resident array is touched once (one
58
+ coalesced fancy-index per chunk); ``data`` and ``sample_indices`` are
59
+ emitted in the *same* grouped order so row ``i`` of every variable and
60
+ ``sample_indices[i]`` refer to the same sample. Intra-batch order is thus
61
+ grouped-by-chunk -- irrelevant for training, and the cross-batch shuffle
62
+ is preserved.
63
+
64
+ ``sample_chunk_size`` is the array's true chunk length (from geometry),
65
+ used to recover global sample indices -- NOT inferred from a resident
66
+ chunk, which may be a short final chunk.
67
+ """
68
+ chunk_ids = rows[:, 0]
69
+ within = rows[:, 1]
70
+ uniq = np.unique(chunk_ids)
71
+
72
+ out: dict[str, list[np.ndarray]] = {v: [] for v in variables}
73
+ idx_pieces: list[np.ndarray] = []
74
+ for cid in uniq:
75
+ w = within[chunk_ids == cid]
76
+ idx_pieces.append(cid * sample_chunk_size + w)
77
+ for var in variables:
78
+ out[var].append(self._chunks[(var, int(cid))].data[w])
79
+
80
+ arrays = {var: np.concatenate(pieces, axis=0) for var, pieces in out.items()}
81
+ return Batch(arrays=arrays, sample_indices=np.concatenate(idx_pieces))
82
+
83
+ def evict_drained(self, still_needed: set[tuple[str, int]]) -> int:
84
+ """Drop chunks no longer referenced by any pending draw. Returns count."""
85
+ drop = [k for k in self._chunks if k not in still_needed]
86
+ for k in drop:
87
+ del self._chunks[k]
88
+ return len(drop)
insitubatch/io.py ADDED
@@ -0,0 +1,172 @@
1
+ """Async IO driver: the obstore win.
2
+
3
+ This is where insitubatch *stands on* solved cloud IO rather than reinventing
4
+ it. A single dedicated asyncio event loop (one OS thread) issues many concurrent
5
+ chunk reads against a zarr v3 store -- ideally the obstore-backed store, whose
6
+ ``get_ranges_async`` coalesces concurrent range requests in a single coroutine
7
+ and saturates the NIC without spawning Python threads per request.
8
+
9
+ Design rules (DESIGN.md, "the inversion"):
10
+ * Parallelism lives *here*, in the event loop's concurrency, NOT in
11
+ torch.DataLoader worker processes. The torch surface runs with
12
+ ``num_workers=0``.
13
+ * A bounded in-flight window (semaphore of ``max_inflight`` chunks) caps memory
14
+ at O(in-flight chunks), independent of batch size.
15
+ * Decode releases the GIL (numcodecs C codecs do) so decode overlaps IO. The
16
+ Python hot path stays O(reads); never decode per-sample in Python.
17
+
18
+ Status: SKELETON. The async store wiring is stubbed at the marked TODOs so the
19
+ module imports and the control flow is reviewable without a live store.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import asyncio
25
+ import queue
26
+ import threading
27
+ from collections.abc import Iterator
28
+ from dataclasses import dataclass
29
+
30
+ import numpy as np
31
+ import zarr.api.asynchronous as za
32
+
33
+ from .plan import ReadPlan
34
+ from .store import store_from_url
35
+ from .types import ArrayGeometry, ChunkRead, DecodedChunk
36
+
37
+
38
+ @dataclass(slots=True)
39
+ class IOConfig:
40
+ max_inflight: int = 16
41
+ """Upper bound on chunks in flight. Memory ~= max_inflight * chunk_nbytes."""
42
+
43
+ decode_threads: int = 4
44
+ """Worker threads for GIL-releasing decode, overlapped with IO."""
45
+
46
+
47
class AsyncChunkReader:
    """Owns one asyncio event loop on a background thread and fans out reads.

    The public API is deliberately *synchronous and iterator-shaped* so the rest
    of the engine (buffer, torch surface) needs no async knowledge -- the loop is
    an implementation detail hidden behind a thread-safe queue.

    Use it as a context manager (or call :meth:`close`) so the loop thread is
    stopped and the event loop itself is closed.
    """

    # Marks the end of one ``read_plan`` stream on the bridging queue.
    _SENTINEL = object()

    def __init__(
        self,
        store_url: str,
        geometries: dict[str, ArrayGeometry],
        config: IOConfig | None = None,
        **store_kwargs: object,
    ) -> None:
        """Start the background loop thread and block until it is running.

        ``store_kwargs`` are kept and passed to ``store_from_url`` when the
        arrays are first opened.
        """
        self._url = store_url
        self._store_kwargs = store_kwargs
        self._geometries = geometries
        self._config = config or IOConfig()
        self._arrays: dict[str, za.AsyncArray] = {}  # opened lazily on the loop
        self._loop = asyncio.new_event_loop()
        self._sem: asyncio.Semaphore | None = None  # created on the loop bootstrap
        self._open_lock: asyncio.Lock | None = None
        self._ready = threading.Event()
        self._thread = threading.Thread(target=self._run_loop, daemon=True, name="insitu-io")
        self._thread.start()
        self._ready.wait()  # don't return until the loop + primitives exist

    # -- loop lifecycle -----------------------------------------------------

    def _run_loop(self) -> None:
        """Thread target: create the loop-side primitives, then run the loop."""
        asyncio.set_event_loop(self._loop)
        self._sem = asyncio.Semaphore(self._config.max_inflight)
        self._open_lock = asyncio.Lock()
        # Signal readiness from inside the loop so __init__ only returns once
        # the loop is actually processing callbacks.
        self._loop.call_soon(self._ready.set)
        self._loop.run_forever()

    def close(self) -> None:
        """Stop the loop thread and close the event loop.

        Calling it again after the loop has been closed is a no-op.
        """
        if self._loop.is_closed():
            return
        if self._thread.is_alive():
            self._loop.call_soon_threadsafe(self._loop.stop)
            self._thread.join(timeout=5)
        if not self._thread.is_alive():
            # A loop may only be closed once it is no longer running. If the
            # join timed out the thread still owns it, so it is left open
            # rather than raising from close().
            self._loop.close()

    def __enter__(self) -> AsyncChunkReader:
        return self

    def __exit__(self, *exc: object) -> None:
        self.close()

    # -- public, synchronous surface ---------------------------------------

    def read_plan(self, plan: ReadPlan) -> Iterator[DecodedChunk]:
        """Fetch every read in ``plan`` concurrently; yield decoded chunks ASAP.

        Yields in completion order (not plan order) so a slow chunk never stalls
        the others -- the buffer downstream is responsible for reordering/gather.

        The bridge uses a thread-safe ``queue.Queue`` and a done-callback so the
        sentinel is *always* delivered -- even if the driver raises -- and any
        exception is re-raised on the caller's thread rather than deadlocking it.

        If the consumer stops iterating early, the driver is cancelled instead
        of being left to fetch the rest of the plan.

        Raises
        ------
        RuntimeError
            If the reader has already been closed.
        Exception
            Whatever the driver raised while opening or fetching.
        """
        if self._loop.is_closed():
            # Submitting to a closed loop can never complete; fail loudly.
            raise RuntimeError("AsyncChunkReader is closed")

        out_q: queue.Queue = queue.Queue()

        def _on_done(fut: object) -> None:
            try:
                fut.result()  # type: ignore[attr-defined]  # surface driver errors
            except Exception as exc:  # noqa: BLE001 - forwarded to the consumer
                out_q.put(exc)
            finally:
                out_q.put(self._SENTINEL)

        fut = asyncio.run_coroutine_threadsafe(self._drive(plan, out_q), self._loop)
        fut.add_done_callback(_on_done)

        try:
            while True:
                item = out_q.get()
                if item is self._SENTINEL:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            # No-op when the driver already finished; otherwise (generator
            # closed early) it stops the remaining fetches.
            fut.cancel()

    # -- async internals ----------------------------------------------------

    async def _drive(self, plan: ReadPlan, out_q: queue.Queue) -> None:
        """Run every read of ``plan`` under the in-flight semaphore."""
        assert self._sem is not None  # guaranteed by _ready.wait() in __init__
        await self._ensure_arrays()

        async def one(read: ChunkRead) -> None:
            async with self._sem:  # bound in-flight chunks -> bounded memory
                decoded = await self._fetch_and_decode(read)
            out_q.put(decoded)  # stdlib queue: thread-safe, non-blocking

        tasks = [asyncio.ensure_future(one(r)) for r in plan.reads]
        try:
            await asyncio.gather(*tasks)
        except BaseException:
            # gather() propagates the first failure but leaves the siblings
            # running; cancel them so they stop fetching into a queue that
            # nobody reads any more.
            for task in tasks:
                task.cancel()
            raise

    async def _ensure_arrays(self) -> None:
        """Open one AsyncArray per variable, once, sharing the store."""
        if self._arrays:
            return
        assert self._open_lock is not None
        async with self._open_lock:
            if self._arrays:  # re-check: another coroutine may have won
                return
            store = store_from_url(self._url, **self._store_kwargs)  # type: ignore[arg-type]
            # Build locally and publish in one assignment: the unlocked check
            # above must never observe a partially filled map, neither while
            # opens are awaited nor after one of them failed.
            opened: dict[str, za.AsyncArray] = {}
            for name in self._geometries:
                opened[name] = await za.open_array(store=store, path=name, mode="r")
            self._arrays = opened

    async def _fetch_and_decode(self, read: ChunkRead) -> DecodedChunk:
        """Fetch + decode one chunk via the zarr v3 async codec pipeline.

        The selection is exactly one chunk along the sample axis, full on the
        inner dims (the v1 sample-geometry contract). zarr fans the underlying
        byte-range reads out through obstore and runs the decode pipeline; for
        single-chunk inner dims this touches exactly one stored chunk.

        TODO(perf): decode currently runs inside zarr's pipeline on the loop. If
        it shows up as a GIL bottleneck, move it to a thread pool
        (numcodecs C codecs release the GIL) -- the bounded-fan-out structure
        here already isolates that change to this method.
        """
        geom = self._geometries[read.array]
        arr = self._arrays[read.array]
        samples = geom.samples_in_chunk(read.chunk_index)
        selection = (slice(samples.start, samples.stop), *(slice(None) for _ in geom.inner_shape))
        block = await arr.getitem(selection)
        return DecodedChunk(read=read, data=np.asarray(block), sample_offset=samples.start)
insitubatch/plan.py ADDED
@@ -0,0 +1,112 @@
1
+ """Read planning: samples -> deduplicated chunk reads.
2
+
3
+ This is the crux abstraction. Given the samples required for a window of the
4
+ epoch and the array geometries, produce the *minimal* set of chunk reads plus a
5
+ gather map describing where each sample lives once those chunks are decoded.
6
+
7
+ Why this matters (DESIGN.md, "the spectrum"):
8
+ - Fat chunks: many samples share one chunk -> dedup collapses N samples to 1
9
+ read; the shared decoded chunk is gathered N times. This is the shared-cache
10
+ win that the classic per-worker DataLoader cannot get.
11
+ - GRIB-per-timestep: one sample per chunk -> no dedup possible, but the plan
12
+ still drives a single wide async fan-out (B samples == B concurrent reads),
13
+ which is exactly where obstore earns its keep.
14
+
15
+ The Python hot path here is O(reads), never O(samples) once gathered, which is
16
+ the constraint David's S3 benchmark imposed (Python per-chunk overhead bounds
17
+ throughput; never loop per-sample in Python).
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from collections.abc import Sequence
23
+ from dataclasses import dataclass
24
+
25
+ import numpy as np
26
+
27
+ from .types import ArrayGeometry, ChunkRead
28
+
29
+
30
+ @dataclass(slots=True)
31
+ class Gather:
32
+ """Where one sample lives within a decoded chunk.
33
+
34
+ ``read_index`` indexes into ``ReadPlan.reads``; ``within`` is the offset of
35
+ the sample inside that chunk's decoded array along the sample axis.
36
+ """
37
+
38
+ read_index: int
39
+ within: int
40
+
41
+
42
+ @dataclass(slots=True)
43
+ class ReadPlan:
44
+ """A deduplicated batch of chunk reads plus the gather map back to samples.
45
+
46
+ One ``ReadPlan`` typically covers enough samples to (a) saturate the async
47
+ fan-out and (b) fill the shuffle-block buffer. ``reads`` is what the IO
48
+ driver fetches; ``gathers[v]`` reconstructs the requested samples for
49
+ variable ``v`` from the decoded chunks.
50
+ """
51
+
52
+ reads: list[ChunkRead]
53
+ gathers: dict[str, list[Gather]]
54
+ sample_indices: np.ndarray # global sample indices, in requested order
55
+
56
+ @property
57
+ def n_reads(self) -> int:
58
+ return len(self.reads)
59
+
60
+
61
def build_read_plan(
    sample_indices: Sequence[int],
    geometries: dict[str, ArrayGeometry],
) -> ReadPlan:
    """Turn ``sample_indices`` into a deduplicated read plan over every variable.

    For each variable, sample ``s`` is located in chunk
    ``s // geom.sample_chunk_size`` at offset ``s - chunk * sample_chunk_size``.
    A ``(variable, chunk)`` pair needed by several samples appears once in
    ``reads``, in order of first appearance, and every sample gets a
    :class:`Gather` pointing at it.

    Parameters
    ----------
    sample_indices:
        Global sample-axis indices to fetch, in the order they should appear.
    geometries:
        Variable name -> :class:`ArrayGeometry`.

    Returns
    -------
    ReadPlan
    """
    requested = np.asarray(sample_indices, dtype=np.int64)
    unique_reads: list[ChunkRead] = []
    position_of: dict[ChunkRead, int] = {}
    gather_map: dict[str, list[Gather]] = {}

    for var_name, geometry in geometries.items():
        # Chunk id and in-chunk offset of every sample, computed array-wide.
        owning_chunk = requested // geometry.sample_chunk_size
        offset_in_chunk = requested - owning_chunk * geometry.sample_chunk_size

        var_gathers = gather_map.setdefault(var_name, [])
        located = zip(owning_chunk.tolist(), offset_in_chunk.tolist(), strict=True)
        for chunk_index, offset in located:
            key = ChunkRead(array=var_name, chunk_index=int(chunk_index))
            slot = position_of.get(key)
            if slot is None:
                # First sample to need this chunk: register the read.
                slot = len(unique_reads)
                position_of[key] = slot
                unique_reads.append(key)
            var_gathers.append(Gather(read_index=slot, within=int(offset)))

    return ReadPlan(reads=unique_reads, gathers=gather_map, sample_indices=requested)
102
+
103
+
104
def dedup_ratio(plan: ReadPlan) -> float:
    """Return the number of (sample, variable) requests served per chunk read.

    Computed as ``len(sample_indices) * n_variables / n_reads``, where the
    variable count is taken as at least 1. A plan without reads yields 0.0.
    1.0 means no chunk is shared between samples; higher values mean more
    samples served by each read.
    """
    requests = len(plan.sample_indices) * max(len(plan.gathers), 1)
    if not plan.n_reads:
        return 0.0
    return requests / plan.n_reads
insitubatch/py.typed ADDED
File without changes
insitubatch/shuffle.py ADDED
@@ -0,0 +1,74 @@
1
+ """Approximate-global shuffle for chunk-aligned data.
2
+
3
+ True global shuffle is incompatible with chunk-aligned, low-copy reads: it would
4
+ demand a random chunk per sample. The compromise (DESIGN.md, "shuffle"), adapted
5
+ from MosaicML Streaming's shuffle-block algorithms (py1e / py1br), is two-level:
6
+
7
+ 1. **Chunk permutation** -- shuffle the *order chunks are scheduled* each epoch.
8
+ 2. **Shuffle-block buffer** -- hold samples from a window of B chunks and draw
9
+ batches across the whole window, so samples from different chunks interleave.
10
+
11
+ Setting the block span B >= ~10x the samples-per-chunk yields shuffle quality
12
+ close to global, at memory cost O(B chunks). B is the single quality<->memory
13
+ knob. This module owns the *index math*; buffer.py owns the residency.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import numpy as np
19
+
20
+
21
def chunk_permutation(chunk_ids: np.ndarray, *, seed: int, epoch: int) -> np.ndarray:
    """Return ``chunk_ids`` in a shuffled order that depends only on (seed, epoch).

    The generator is seeded from the ``(seed, epoch)`` pair and nothing else,
    so the same pair always produces the same order. The input array is left
    unmodified; a permuted copy is returned.
    """
    generator = np.random.default_rng((seed, epoch))
    shuffled = generator.permutation(chunk_ids)
    return shuffled
30
+
31
+
32
def block_shuffled_order(
    chunk_ids: np.ndarray,
    samples_per_chunk: int,
    *,
    block_chunks: int,
    seed: int,
    epoch: int,
    n_samples: int | None = None,
) -> np.ndarray:
    """Produce a globally-ordered list of (chunk_id, within) draws.

    The chunks are permuted for ``(seed, epoch)``, cut into consecutive blocks
    of ``block_chunks`` chunks, and every ``(chunk, within)`` pair of a block is
    shuffled within that block.

    Parameters
    ----------
    chunk_ids:
        Chunk indices to draw from. May be empty.
    samples_per_chunk:
        Number of ``within`` offsets generated per chunk.
    block_chunks:
        Chunks per shuffle block; must be >= 1.
    seed, epoch:
        Together they fully determine the emitted order.
    n_samples:
        Optional total length of the sample axis. When given, draws whose
        global index ``chunk_id * samples_per_chunk + within`` is not below
        ``n_samples`` are dropped, so a short final chunk never produces
        out-of-range offsets. ``None`` (default) keeps every generated pair.

    Returns
    -------
    np.ndarray
        Shape ``(n_draws, 2)`` of ``[chunk_id, within]`` rows; ``(0, 2)`` when
        there is nothing to draw.

    Raises
    ------
    ValueError
        If ``block_chunks`` is smaller than 1.
    """
    if block_chunks < 1:
        raise ValueError(f"block_chunks must be >= 1, got {block_chunks}")

    perm = chunk_permutation(chunk_ids, seed=seed, epoch=epoch)
    # Separate stream from the chunk permutation, still keyed on (seed, epoch).
    rng = np.random.default_rng((seed, epoch, 7919))
    offsets = np.arange(samples_per_chunk)  # loop-invariant

    rows: list[np.ndarray] = []
    for start in range(0, len(perm), block_chunks):
        block = perm[start : start + block_chunks]
        # Materialise every (chunk, within) pair in this block, then shuffle.
        cc = np.repeat(block, samples_per_chunk)
        ww = np.tile(offsets, len(block))
        pairs = np.stack([cc, ww], axis=1)
        rng.shuffle(pairs)  # in-place, along axis 0
        if n_samples is not None:
            # Filter after shuffling so the random stream, and therefore the
            # relative order of the surviving draws, does not depend on it.
            pairs = pairs[pairs[:, 0] * samples_per_chunk + pairs[:, 1] < n_samples]
        rows.append(pairs)

    if not rows:
        # np.concatenate rejects an empty list; an empty input is a valid
        # "nothing to draw" case (e.g. a split that received no chunks).
        return np.empty((0, 2), dtype=np.int64)
    return np.concatenate(rows, axis=0)
59
+
60
+
61
def shuffle_quality(order: np.ndarray, samples_per_chunk: int) -> float:
    """A 0..1 score for how well an emitted order mixes the source.

    Heuristic: the mean absolute *source-rank* gap between consecutive emitted
    samples, normalised by the gap a perfect global shuffle would give. 1.0 ~=
    global; values near 0 mean adjacent samples still come out near each other
    (poor mixing). Cheap to compute, good enough to tune ``block_chunks``.

    Each draw's global index ``chunk_id * samples_per_chunk + within`` is first
    replaced by its rank among the emitted draws. Without that step an order
    over a sparse set of chunk ids has gaps inflated by the holes between
    chunks and scores ~1.0 even when it is not shuffled at all.

    Parameters
    ----------
    order:
        Array of shape ``(n, 2)`` holding ``[chunk_id, within]`` rows.
    samples_per_chunk:
        Chunk length used to compute the global index.

    Returns
    -------
    float
        Score clipped to at most 1.0; 0.0 when fewer than two draws are given.
    """
    n = len(order)
    if n < 2:
        return 0.0

    source_index = order[:, 0].astype(np.int64) * samples_per_chunk + order[:, 1]
    # Dense rank of every draw in source order (equal indices share a rank).
    _, source_rank = np.unique(source_index, return_inverse=True)
    gaps = np.abs(np.diff(source_rank.reshape(-1).astype(np.int64)))

    # Expected mean gap of a uniform random permutation of 0..n-1 is ~n/3.
    expected = n / 3.0
    return float(min(gaps.mean() / expected, 1.0))
insitubatch/source.py ADDED
@@ -0,0 +1,118 @@
1
+ """Torch handoff surface.
2
+
3
+ Ties the pieces together and exposes them to PyTorch *without* using the classic
4
+ DataLoader worker model. Parallelism lives in :class:`AsyncChunkReader`'s event
5
+ loop, so the recommended configuration is::
6
+
7
+ loader = DataLoader(InSituDataset(...), batch_size=None, num_workers=0)
8
+
9
+ ``batch_size=None`` because the dataset already yields assembled batches;
10
+ ``num_workers=0`` because forking workers would re-introduce exactly the
11
+ redundant-read / nested-parallelism problems we set out to avoid.
12
+
13
+ torch (and torchdata.nodes) are optional imports so the core engine stays
14
+ framework-agnostic and importable on a box without torch installed.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from collections.abc import Iterator
20
+
21
+ import numpy as np
22
+
23
+ from .buffer import BufferConfig, ShuffleBlockBuffer
24
+ from .io import AsyncChunkReader, IOConfig
25
+ from .plan import build_read_plan
26
+ from .shuffle import block_shuffled_order
27
+ from .split import SplitManifest
28
+ from .store import open_geometries
29
+ from .types import ArrayGeometry, Batch, SplitName
30
+
31
+ try: # optional torch surface
32
+ from torch.utils.data import IterableDataset
33
+
34
+ _HAS_TORCH = True
35
+ except ImportError: # pragma: no cover - exercised on torch-less installs
36
+ IterableDataset = object # type: ignore[assignment,misc]
37
+ _HAS_TORCH = False
38
+
39
+
40
class InSituDataset(IterableDataset):  # type: ignore[misc]
    """An IterableDataset that streams shuffled batches from a Zarr archive.

    One epoch = permute the split's chunks -> walk the block-shuffled draw
    order in batch-sized windows -> for each window, async-fetch the chunks
    that are not resident yet, then emit one coalesced batch.
    """

    def __init__(
        self,
        store_url: str,
        manifest: SplitManifest,
        geometries: dict[str, ArrayGeometry] | None = None,
        split: SplitName = SplitName.TRAIN,
        *,
        batch_size: int = 32,
        block_chunks: int = 16,
        max_inflight: int = 16,
        seed: int = 0,
        to_tensor: bool = True,
        **store_kwargs: object,
    ) -> None:
        """Configure the dataset; no chunk data is read until iteration.

        When ``geometries`` is omitted it is introspected from the store with
        ``open_geometries``. ``store_kwargs`` are forwarded both to that call
        and to the reader created for each epoch. ``to_tensor`` only takes
        effect when torch is importable.
        """
        self.store_url = store_url
        self.store_kwargs = store_kwargs
        if geometries is None:
            # Introspection must use the same store options (credentials,
            # region, ...) as the reads, otherwise it can fail to open a store
            # that the reader itself could open.
            geometries = open_geometries(store_url, **store_kwargs)
        self.geometries = geometries
        self.manifest = manifest
        self.split = split
        self.variables = list(self.geometries)
        self.io_config = IOConfig(max_inflight=max_inflight)
        self.buffer_config = BufferConfig(block_chunks=block_chunks, batch_size=batch_size)
        self.seed = seed
        self.to_tensor = to_tensor and _HAS_TORCH
        self._epoch = 0

    def set_epoch(self, epoch: int) -> None:
        """Call from the training loop so each epoch reshuffles deterministically."""
        self._epoch = epoch

    def __iter__(self) -> Iterator[Batch | dict]:
        """Yield the batches of one epoch for the configured split.

        Chunk geometry (chunk length, sample count) is taken from the first
        variable. Yields nothing when the split contains no chunks.
        """
        geom = self.geometries[self.variables[0]]
        chunk_ids = np.asarray(self.manifest.chunks[self.split.value], dtype=np.int64)
        if chunk_ids.size == 0:
            return  # empty split: nothing to draw
        spc = geom.sample_chunk_size

        order = block_shuffled_order(
            chunk_ids,
            spc,
            block_chunks=self.buffer_config.block_chunks,
            seed=self.seed,
            epoch=self._epoch,
        )
        # The draw order assumes every chunk holds ``spc`` samples. Drop draws
        # that fall past the end of the sample axis (short final chunk);
        # otherwise they would index beyond the decoded chunk during gather.
        order = order[order[:, 0] * spc + order[:, 1] < geom.n_samples]
        if len(order) == 0:
            return

        # Row position of the LAST draw referencing each chunk. A chunk stays
        # resident until that draw has been emitted, so later batches of the
        # same shuffle block reuse it instead of fetching it again.
        ids_from_end, first_from_end = np.unique(order[::-1, 0], return_index=True)
        last_draw = dict(
            zip(ids_from_end.tolist(), (len(order) - 1 - first_from_end).tolist())
        )

        with AsyncChunkReader(
            self.store_url, self.geometries, self.io_config, **self.store_kwargs
        ) as reader:
            buf = ShuffleBlockBuffer(self.buffer_config, seed=self.seed)
            bs = self.buffer_config.batch_size

            # Walk the emitted draw order in batch-sized windows. For each window
            # we ensure its chunks are resident (async fan-out), then gather.
            for start in range(0, len(order), bs):
                stop = start + bs
                rows = order[start:stop]
                batch_chunks = np.unique(rows[:, 0]).tolist()
                needed = {(v, c) for v in self.variables for c in batch_chunks}
                # One representative sample index per chunk with a missing array.
                missing_samples = sorted(
                    {c * spc for (v, c) in needed if (v, c) not in buf._chunks}
                )
                if missing_samples:
                    plan = build_read_plan(missing_samples, self.geometries)
                    for decoded in reader.read_plan(plan):
                        buf.add(decoded)

                batch = buf.gather_batch(rows, self.variables, spc)
                yield self._maybe_tensor(batch)

                # Keep every resident chunk that a later draw still references.
                still_needed = {
                    key for key in buf._chunks if last_draw.get(key[1], -1) >= stop
                }
                buf.evict_drained(still_needed)

    def _maybe_tensor(self, batch: Batch) -> Batch | dict:
        """Return ``batch`` unchanged, or its arrays as torch tensors.

        With ``to_tensor`` enabled the result is a plain ``{name: tensor}``
        dict built from ``batch.arrays``; ``batch.sample_indices`` is not
        included in it.
        """
        if not self.to_tensor:
            return batch
        import torch

        return {k: torch.from_numpy(v) for k, v in batch.arrays.items()}
insitubatch/split.py ADDED
@@ -0,0 +1,97 @@
1
+ """Chunk-aligned train/val/test splits.
2
+
3
+ Splits are done *ahead of time* and at *chunk granularity* along the sample
4
+ axis. Two reasons (DESIGN.md, "splits"):
5
+
6
+ 1. Leakage: splitting mid-chunk would scatter temporally adjacent, highly
7
+ autocorrelated samples across train and val. Chunk-aligned boundaries keep
8
+ a contiguous block of time in a single split.
9
+ 2. Zero-copy: a split that respects chunk boundaries means every read serves
10
+ exactly one split, so the engine never decodes a chunk and throws half of
11
+ it away.
12
+
13
+ The manifest is a plain, serializable record of which chunk indices belong to
14
+ which split, so a run is reproducible and shareable.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import json
20
+ from dataclasses import asdict, dataclass
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+
25
+ from .types import ArrayGeometry, SplitName
26
+
27
+
28
+ @dataclass(slots=True)
29
+ class SplitManifest:
30
+ """Which sample-axis chunk indices belong to each split."""
31
+
32
+ n_chunks: int
33
+ sample_chunk_size: int
34
+ n_samples: int
35
+ chunks: dict[str, list[int]] # SplitName.value -> sorted chunk indices
36
+ seed: int
37
+
38
+ def sample_indices(self, split: SplitName, geom: ArrayGeometry) -> np.ndarray:
39
+ """Expand a split's chunks into the global sample indices they contain."""
40
+ out: list[int] = []
41
+ for c in self.chunks[split.value]:
42
+ out.extend(geom.samples_in_chunk(c))
43
+ return np.asarray(out, dtype=np.int64)
44
+
45
+ def to_json(self, path: str | Path) -> None:
46
+ Path(path).write_text(json.dumps(asdict(self), indent=2))
47
+
48
+ @classmethod
49
+ def from_json(cls, path: str | Path) -> SplitManifest:
50
+ return cls(**json.loads(Path(path).read_text()))
51
+
52
+
53
def split_by_chunk(
    geom: ArrayGeometry,
    *,
    fractions: tuple[float, float, float] = (0.8, 0.1, 0.1),
    seed: int = 0,
    contiguous: bool = True,
) -> SplitManifest:
    """Partition a variable's sample-axis chunks into train/val/test.

    Train receives the first ``round(fractions[0] * n_chunks)`` chunks of the
    (optionally shuffled) chunk order, val the next
    ``round(fractions[1] * n_chunks)``, and test everything that remains.

    Parameters
    ----------
    geom:
        Geometry whose ``n_chunks``, ``sample_chunk_size`` and ``n_samples``
        are recorded in the manifest.
    fractions:
        (train, val, test) fractions of *chunks* (not samples). Each must lie
        in [0, 1] and together they must sum to ~1.
    seed:
        Seed for the chunk shuffle (used only when ``contiguous`` is False);
        always recorded in the manifest.
    contiguous:
        If True (default), assign contiguous blocks of chunks to each split --
        the safest choice for time series, where a randomly interleaved split
        still risks leakage through autocorrelation across chunk boundaries. If
        False, chunks are shuffled before partitioning (acceptable when samples
        are exchangeable, e.g. independent scenes).

    Raises
    ------
    ValueError
        If a fraction is outside [0, 1] or the fractions do not sum to 1.
    """
    # A negative fraction turns into a negative slice bound below, which makes
    # the splits overlap (the same chunk ends up in train AND test) -- exactly
    # the leakage chunk-aligned splits exist to prevent. NaN fails this too.
    if any(not 0.0 <= f <= 1.0 for f in fractions):
        raise ValueError(f"each fraction must be within [0, 1], got {fractions}")
    if abs(sum(fractions) - 1.0) > 1e-6:
        raise ValueError(f"fractions must sum to 1.0, got {fractions} -> {sum(fractions)}")

    n = geom.n_chunks
    if contiguous:
        order = np.arange(n)
    else:
        order = np.random.default_rng(seed).permutation(n)

    n_train = int(round(fractions[0] * n))
    n_val = int(round(fractions[1] * n))
    val_end = n_train + n_val

    return SplitManifest(
        n_chunks=n,
        sample_chunk_size=geom.sample_chunk_size,
        n_samples=geom.n_samples,
        chunks={
            SplitName.TRAIN.value: sorted(order[:n_train].tolist()),
            SplitName.VAL.value: sorted(order[n_train:val_end].tolist()),
            # Test takes the remainder so rounding never drops a chunk.
            SplitName.TEST.value: sorted(order[val_end:].tolist()),
        },
        seed=seed,
    )
insitubatch/store.py ADDED
@@ -0,0 +1,75 @@
1
+ """Storage shim: one URL, any backend.
2
+
3
+ The whole local-now / cloud-later story is a single function. ``obstore`` already
4
+ dispatches on URL scheme (``file://``, ``s3://``, ``gs://``, ``az://``,
5
+ ``memory://``) and ``zarr.storage.ObjectStore`` wraps it for the async zarr path.
6
+ So Phase 0 (local ``file://``) and Phase 1 (``s3://...``) differ only in the URL --
7
+ no hot-path code change, and the read path stays pure Rust (no fsspec layer).
8
+
9
+ We deliberately do *not* route through fsspec / universal_pathlib on the read
10
+ hot path: the entire thesis is that obstore wins by bypassing the fsspec/s3fs
11
+ Python layer. (obstore.fsspec exists if path-style ergonomics are ever wanted
12
+ off the hot path -- but not here.)
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ from urllib.parse import urlparse
19
+
20
+ import numpy as np
21
+ import obstore
22
+ import zarr
23
+ import zarr.storage
24
+
25
+ from .types import ArrayGeometry
26
+
27
+
28
def store_from_url(
    url: str, *, read_only: bool = True, **kwargs: object
) -> zarr.storage.ObjectStore:
    """Build a zarr ``ObjectStore`` on top of the obstore backend for ``url``.

    ``url`` is handed to ``obstore.store.from_url`` together with any extra
    ``kwargs`` (region, credentials, client options, ...); the resulting
    backend is wrapped in ``zarr.storage.ObjectStore`` with the requested
    ``read_only`` flag. Examples: ``file:///abs/path.zarr`` for local,
    ``s3://bucket/path.zarr`` for cloud.
    """
    backend = obstore.store.from_url(url, **kwargs)
    wrapped = zarr.storage.ObjectStore(backend, read_only=read_only)
    return wrapped
39
+
40
+
41
def ensure_local_dir(url: str) -> str:
    """For a ``file://`` URL, create the target directory so writes can land.

    obstore's LocalStore will not create the prefix for you. The path component
    of a ``file://`` URL is converted to a local filesystem path first, so
    percent-encoded characters (``%20`` ...) name the real directory rather
    than a directory with a literal ``%20`` in it. A string without a scheme is
    used as a path as-is. No-op for every other scheme.

    Returns the URL unchanged for chaining.
    """
    # Function-scope import: the module only imports ``urlparse`` at the top.
    from urllib.request import url2pathname

    parsed = urlparse(url)
    if parsed.scheme == "file":
        os.makedirs(url2pathname(parsed.path), exist_ok=True)
    elif parsed.scheme == "":
        os.makedirs(parsed.path, exist_ok=True)
    return url
51
+
52
+
53
def open_geometries(
    url: str,
    variables: list[str] | None = None,
    **kwargs: object,
) -> dict[str, ArrayGeometry]:
    """Read ``{name: ArrayGeometry}`` from the zarr group stored at ``url``.

    The group is opened read-only through ``store_from_url`` (``kwargs`` are
    forwarded to it). ``variables`` selects which arrays to describe; when it
    is ``None`` every array the group lists is included. Shape, chunks and
    dtype come from each array's metadata.
    """
    group = zarr.open_group(store=store_from_url(url, **kwargs), mode="r")

    if variables is None:
        selected = [key for key, _ in group.arrays()]
    else:
        selected = variables

    geometries: dict[str, ArrayGeometry] = {}
    for var_name in selected:
        array = group[var_name]
        geometries[var_name] = ArrayGeometry(
            name=var_name,
            shape=tuple(array.shape),
            chunks=tuple(array.chunks),
            dtype=np.dtype(array.dtype),
        )
    return geometries
insitubatch/types.py ADDED
@@ -0,0 +1,108 @@
1
+ """Core data types shared across the insitubatch engine.
2
+
3
+ The central design choice (see DESIGN.md): the unit of work is neither the
4
+ *sample* nor the *chunk* in isolation, but a **read plan** that maps the samples
5
+ required for a step to a *deduplicated* set of chunk reads. This lets the same
6
+ scheduler serve the whole spectrum from fat chunks (many samples per chunk,
7
+ shared-cache wins) to the degenerate GRIB-per-timestep case (one sample per
8
+ chunk, async fan-out is everything).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from dataclasses import dataclass, field
14
+ from enum import StrEnum
15
+
16
+ import numpy as np
17
+
18
+
19
+ class SplitName(StrEnum):
20
+ TRAIN = "train"
21
+ VAL = "val"
22
+ TEST = "test"
23
+
24
+
25
+ @dataclass(frozen=True, slots=True)
26
+ class ArrayGeometry:
27
+ """The minimal geometry the engine needs about one zarr array.
28
+
29
+ We only model the *sample axis* (the outer dimension, axis 0 by convention:
30
+ time for ERA5/HRRR) explicitly, because that is the axis we split, shuffle,
31
+ and batch along. The trailing dims are carried opaquely as ``inner_shape``
32
+ and are kept contiguous to preserve partial zero-copy.
33
+ """
34
+
35
+ name: str
36
+ shape: tuple[int, ...]
37
+ chunks: tuple[int, ...]
38
+ dtype: np.dtype
39
+
40
+ @property
41
+ def n_samples(self) -> int:
42
+ """Length of the sample (outer) axis."""
43
+ return self.shape[0]
44
+
45
+ @property
46
+ def sample_chunk_size(self) -> int:
47
+ """How many samples live in one chunk along the sample axis."""
48
+ return self.chunks[0]
49
+
50
+ @property
51
+ def inner_shape(self) -> tuple[int, ...]:
52
+ """Shape of a single sample (everything past the sample axis)."""
53
+ return self.shape[1:]
54
+
55
+ @property
56
+ def n_chunks(self) -> int:
57
+ """Number of chunks along the sample axis."""
58
+ return -(-self.n_samples // self.sample_chunk_size) # ceil div
59
+
60
+ def chunk_of(self, sample_index: int) -> int:
61
+ """Which sample-axis chunk a given sample index falls in."""
62
+ return sample_index // self.sample_chunk_size
63
+
64
+ def samples_in_chunk(self, chunk_index: int) -> range:
65
+ """The half-open range of global sample indices in ``chunk_index``."""
66
+ start = chunk_index * self.sample_chunk_size
67
+ stop = min(start + self.sample_chunk_size, self.n_samples)
68
+ return range(start, stop)
69
+
70
+
71
+ @dataclass(frozen=True, slots=True)
72
+ class ChunkRead:
73
+ """A single chunk to fetch, addressed along the sample axis.
74
+
75
+ ``array`` names which zarr array (variable) this read belongs to; a training
76
+ sample that concatenates several variables produces one ``ChunkRead`` per
77
+ variable that must be co-scheduled.
78
+ """
79
+
80
+ array: str
81
+ chunk_index: int
82
+
83
+
84
+ @dataclass(slots=True)
85
+ class DecodedChunk:
86
+ """A decoded, in-memory chunk, keyed by its read.
87
+
88
+ ``data`` has shape ``(n_samples_in_chunk, *inner_shape)``. The buffer holds a
89
+ bounded number of these; memory overhead is O(in-flight chunks), independent
90
+ of batch size.
91
+ """
92
+
93
+ read: ChunkRead
94
+ data: np.ndarray
95
+ sample_offset: int # global sample index of data[0]
96
+
97
+
98
+ @dataclass(slots=True)
99
+ class Batch:
100
+ """A model-ready batch.
101
+
102
+ ``arrays`` maps variable name -> stacked array of shape ``(batch, *inner)``.
103
+ ``sample_indices`` records provenance (which global samples, in order) for
104
+ determinism checks and resumption.
105
+ """
106
+
107
+ arrays: dict[str, np.ndarray]
108
+ sample_indices: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=np.int64))
@@ -0,0 +1,90 @@
1
+ Metadata-Version: 2.4
2
+ Name: insitubatch
3
+ Version: 0.0.1
4
+ Summary: Train in place on n-dimensional cloud tensors: the data-loader orchestration layer on top of solved async IO (obstore / zarr v3 / icechunk).
5
+ Keywords: zarr,xarray,pytorch,dataloader,obstore,machine-learning,cloud,async
6
+ Author: David Stuebe
7
+ Author-email: David Stuebe <stu3b3+emfdavid@gmail.com>
8
+ License-Expression: MIT
9
+ License-File: LICENSE
10
+ Classifier: Development Status :: 2 - Pre-Alpha
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Classifier: Programming Language :: Python :: 3.13
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Topic :: Scientific/Engineering
15
+ Requires-Dist: numpy>=1.26
16
+ Requires-Dist: zarr>=3.0
17
+ Requires-Dist: xarray>=2024.9
18
+ Requires-Dist: obstore>=0.3
19
+ Requires-Dist: cupy-cuda12x>=13.0 ; extra == 'gpu'
20
+ Requires-Dist: kvikio-cu12>=24.10 ; extra == 'gpu'
21
+ Requires-Dist: torch>=2.2 ; extra == 'torch'
22
+ Requires-Dist: torchdata>=0.10 ; extra == 'torch'
23
+ Requires-Dist: virtualizarr>=1.2 ; extra == 'virtual'
24
+ Requires-Python: >=3.12
25
+ Provides-Extra: gpu
26
+ Provides-Extra: torch
27
+ Provides-Extra: virtual
28
+ Description-Content-Type: text/markdown
29
+
30
+ # insitubatch
31
+
32
+ **Train in place on n-dimensional cloud tensors.**
33
+
34
+ `insitubatch` is the data-loader orchestration layer that sits on top of
35
+ *already-solved* async cloud IO (obstore / zarr v3 / icechunk). It turns an
36
+ existing Zarr archive into a shuffled, split-aware, GPU-saturating PyTorch
37
+ source — **with no reshard** — and a Python hot path that scales with **chunks,
38
+ not samples**.
39
+
40
+ > The IO race is over (obstore/icechunk saturate the NIC). The *loader* race is
41
+ > open. `insitubatch` builds the layer that projects like light-speed-io and
42
+ > hypergrib stopped one step short of. See [DESIGN.md](DESIGN.md).
43
+
44
+ ## Why
45
+
46
+ The classic PyTorch `DataLoader` spreads work across worker **processes**, each
47
+ running a *synchronous* `__getitem__`. Against cloud Zarr that means no shared
48
+ chunk cache (every worker re-reads the same chunk), no way to drive async
49
+ obstore, and dask thread pools nested inside forked workers. `insitubatch`
50
+ **inverts** it: one async event loop drives concurrent reads; a bounded
51
+ shuffle-block buffer assembles batches; torch runs `num_workers=0`.
52
+
53
+ ## Status
54
+
55
+ 🚧 **Pre-alpha skeleton.** Abstractions and control flow are in place; the live
56
+ store read in `io.py` and the GPU path are stubbed. Not yet usable for real
57
+ training — this is the design substrate.
58
+
59
+ ## Install (dev)
60
+
61
+ ```bash
62
+ uv sync # core engine + dev tools
63
+ uv sync --extra torch # add the torch IterableDataset surface
64
+ uv sync --extra gpu # CUDA box only: cupy + kvikio zero-copy path
65
+ ```
66
+
67
+ ## Shape of the API (target)
68
+
69
+ ```python
70
+ from insitubatch import split_by_chunk, ArrayGeometry, SplitName
71
+ from insitubatch.source import InSituDataset
72
+ from torch.utils.data import DataLoader
73
+
74
+ geom = ArrayGeometry("t2m", shape=(8760, 721, 1440), chunks=(24, 721, 1440), dtype=...)
75
+ manifest = split_by_chunk(geom, fractions=(0.8, 0.1, 0.1))
76
+
77
+ ds = InSituDataset(store, {"t2m": geom}, manifest, SplitName.TRAIN,
78
+ batch_size=32, block_chunks=16)
79
+
80
+ # parallelism lives in insitubatch's event loop, not in workers:
81
+ loader = DataLoader(ds, batch_size=None, num_workers=0)
82
+ for epoch in range(n_epochs):
83
+ ds.set_epoch(epoch)
84
+ for batch in loader:
85
+ ...
86
+ ```
87
+
88
+ ## License
89
+
90
+ MIT — see [LICENSE](LICENSE).
@@ -0,0 +1,14 @@
1
+ insitubatch/__init__.py,sha256=9TTgS8anuSDAUm3QugS8EJgR6nIOCQpeEVFXUQNfZWo,1311
2
+ insitubatch/buffer.py,sha256=JcjzcFjqIvaqWS0jp6_1bvPlLgXUXRUEOP9F9HLWAS4,3419
3
+ insitubatch/io.py,sha256=kSXid_h1bu3tNoO3BbvlD1MDF5RlXh4eFtNI0fz0HyM,7090
4
+ insitubatch/plan.py,sha256=cY3iYemKLmJvaHkTD1ahOU7WsfjJzG9z6wuxbxu9kOQ,4110
5
+ insitubatch/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ insitubatch/shuffle.py,sha256=ARXoC1sAUU_xRa3YBN4adtKszSay_HqOTtBlvk1ik7Q,3136
7
+ insitubatch/source.py,sha256=qhIcLj3HAtDePGiKHBWqyi0n9lgKV6-KlfyvKdrRAh4,4486
8
+ insitubatch/split.py,sha256=xIQuzZmIuON4OANAEqVKHkdCc1BETYluv28HsztOTOg,3286
9
+ insitubatch/store.py,sha256=Ny-DrEoQD-cv4y9NIkwIdzk4_QBrXduZyN7RgkNYIX8,2605
10
+ insitubatch/types.py,sha256=6rquZxacOgekl-GKw37dfb3HOxD8muyHKRwaOqMBdQA,3455
11
+ insitubatch-0.0.1.dist-info/licenses/LICENSE,sha256=6IPhQ_YPQ0SYyy6JY62NW7zhfJhoYj8mfYeb0LKMwJU,1069
12
+ insitubatch-0.0.1.dist-info/WHEEL,sha256=8ZlpUMJ7mlDirmlHRhDirEx_nPnARrwDjeE92mlk68E,81
13
+ insitubatch-0.0.1.dist-info/METADATA,sha256=TD_A5rHYJqcPWRezqOwbG9xmS8teVvYG9pBOwFliNYA,3365
14
+ insitubatch-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: uv 0.11.21
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 David Stuebe
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.