sverdrup 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sverdrup/__init__.py +0 -0
- sverdrup/__main__.py +49 -0
- sverdrup/_version.py +24 -0
- sverdrup/adapters/__init__.py +0 -0
- sverdrup/adapters/executor_dask.py +74 -0
- sverdrup/adapters/odc/__init__.py +0 -0
- sverdrup/adapters/odc/download.py +56 -0
- sverdrup/adapters/odc/fixtures.py +87 -0
- sverdrup/adapters/odc/natl60.py +29 -0
- sverdrup/adapters/odc/ose.py +31 -0
- sverdrup/adapters/storage_fsspec.py +128 -0
- sverdrup/application/__init__.py +0 -0
- sverdrup/application/config.py +24 -0
- sverdrup/application/pipeline.py +167 -0
- sverdrup/application/solve.py +152 -0
- sverdrup/application/splits.py +78 -0
- sverdrup/application/uow.py +28 -0
- sverdrup/core/__init__.py +0 -0
- sverdrup/core/derived.py +48 -0
- sverdrup/core/distribution.py +56 -0
- sverdrup/core/evaluation.py +84 -0
- sverdrup/core/grid.py +149 -0
- sverdrup/core/method.py +34 -0
- sverdrup/core/observations.py +153 -0
- sverdrup/core/parameters.py +60 -0
- sverdrup/core/ports.py +40 -0
- sverdrup/core/product.py +41 -0
- sverdrup/core/provenance.py +59 -0
- sverdrup/core/seeding.py +26 -0
- sverdrup/core/types.py +36 -0
- sverdrup/derived/__init__.py +0 -0
- sverdrup/derived/area_average.py +19 -0
- sverdrup/derived/eke.py +17 -0
- sverdrup/derived/firstdifference.py +113 -0
- sverdrup/derived/transport.py +17 -0
- sverdrup/derived/velocity.py +17 -0
- sverdrup/distributions/__init__.py +0 -0
- sverdrup/distributions/adapters.py +128 -0
- sverdrup/distributions/ensemble.py +59 -0
- sverdrup/distributions/gaussian.py +66 -0
- sverdrup/distributions/persisted.py +129 -0
- sverdrup/eval/__init__.py +0 -0
- sverdrup/eval/accuracy.py +31 -0
- sverdrup/eval/calibration.py +70 -0
- sverdrup/eval/groundtrack.py +35 -0
- sverdrup/methods/__init__.py +0 -0
- sverdrup/methods/kernel.py +55 -0
- sverdrup/methods/oi.py +129 -0
- sverdrup/methods/registry.py +8 -0
- sverdrup/methods/solver.py +71 -0
- sverdrup/methods/trivial.py +65 -0
- sverdrup/py.typed +0 -0
- sverdrup-0.1.0.dist-info/METADATA +120 -0
- sverdrup-0.1.0.dist-info/RECORD +56 -0
- sverdrup-0.1.0.dist-info/WHEEL +4 -0
- sverdrup-0.1.0.dist-info/licenses/LICENSE +202 -0
sverdrup/__init__.py
ADDED
|
File without changes
|
sverdrup/__main__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Runnable entry point: ``python -m sverdrup <config.json>``."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
from sverdrup.adapters.executor_dask import ExecutorConfig
|
|
10
|
+
from sverdrup.adapters.odc.fixtures import FixtureSource
|
|
11
|
+
from sverdrup.application.pipeline import PipelineInputs, run_pipeline
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def main(argv: list[str]) -> int:
    """Run a config-driven pipeline, or print usage when no config is given.

    The config file is a JSON object. ``obs_path``, ``mode``, ``method``,
    ``out_url``, ``lon_range``, ``lat_range``, ``time_range``, ``output_times``
    and ``params`` are required keys; ``ref_path``, ``grid_resolution_deg``,
    ``executor`` and ``rank`` are optional.

    Args:
        argv: The process argv (``argv[1]`` is the config path).

    Returns:
        Process exit code (0 on success, and also 0 when only usage is printed).
    """
    if len(argv) < 2:
        print("usage: python -m sverdrup <config.json>")
        return 0

    # JSON text is UTF-8; decode explicitly instead of relying on the locale's
    # preferred encoding, which differs between platforms.
    cfg = json.loads(Path(argv[1]).read_text(encoding="utf-8"))
    src = FixtureSource(cfg["obs_path"], cfg.get("ref_path"))
    inp = PipelineInputs(
        mode=cfg["mode"],
        method_name=cfg["method"],
        source=src,
        out_url=cfg["out_url"],
        lon_range=tuple(cfg["lon_range"]),
        lat_range=tuple(cfg["lat_range"]),
        time_range=tuple(cfg["time_range"]),
        output_times=cfg["output_times"],
        params=cfg["params"],
        grid_resolution_deg=cfg.get("grid_resolution_deg", 1.0),
        executor=ExecutorConfig(**cfg.get("executor", {})),
        rank=cfg.get("rank", 20),
    )
    product, scores = run_pipeline(inp)
    # "context_keys" is bookkeeping about the evaluation context, not a score.
    reported = {k: v for k, v in scores.items() if k != "context_keys"}
    print(f"wrote {len(product.per_time)} time(s); scores={reported}")
    return 0
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
if __name__ == "__main__":
|
|
49
|
+
raise SystemExit(main(sys.argv))
|
sverdrup/_version.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# file generated by vcs-versioning
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"__version__",
|
|
7
|
+
"__version_tuple__",
|
|
8
|
+
"version",
|
|
9
|
+
"version_tuple",
|
|
10
|
+
"__commit_id__",
|
|
11
|
+
"commit_id",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
version: str
|
|
15
|
+
__version__: str
|
|
16
|
+
__version_tuple__: tuple[int | str, ...]
|
|
17
|
+
version_tuple: tuple[int | str, ...]
|
|
18
|
+
commit_id: str | None
|
|
19
|
+
__commit_id__: str | None
|
|
20
|
+
|
|
21
|
+
__version__ = version = '0.1.0'
|
|
22
|
+
__version_tuple__ = version_tuple = (0, 1, 0)
|
|
23
|
+
|
|
24
|
+
__commit_id__ = commit_id = None
|
|
File without changes
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Executor adapter: dask.distributed LocalCluster with a per-run BLAS/OpenMP knob (spec 5.9)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Any, cast
|
|
8
|
+
|
|
9
|
+
from sverdrup.application.solve import solve_unit
|
|
10
|
+
from sverdrup.application.uow import UnitOfWork
|
|
11
|
+
from sverdrup.core.product import Product
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True)
class ExecutorConfig:
    """Immutable executor sizing: worker count, per-worker thread cap, scheduler seam."""

    # Number of worker processes started when a local cluster is created.
    n_processes: int = 4
    # Thread cap exported to each worker through the BLAS/OpenMP environment variables.
    threads_per_process: int = 1
    # Address of an existing scheduler to connect to; None -> spin up a LocalCluster.
    scheduler_address: str | None = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _thread_env(threads: int) -> dict[str, str]:
|
|
24
|
+
"""Return the BLAS/OpenMP thread-cap environment for one worker."""
|
|
25
|
+
t = str(threads)
|
|
26
|
+
return {"OMP_NUM_THREADS": t, "OPENBLAS_NUM_THREADS": t, "MKL_NUM_THREADS": t}
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class DaskExecutor:
    """The sole Phase-1 executor adapter. Scaling out changes only scheduler_address.

    Use as a context manager: entering connects to ``config.scheduler_address``
    when it is set, otherwise starts a process-based ``LocalCluster``; exiting
    closes the client and any cluster this object started.
    """

    def __init__(self, config: ExecutorConfig) -> None:
        """Store config; the cluster/client are created on context entry.

        Args:
            config: The executor configuration.
        """
        self.config = config
        self._cluster: Any = None
        self._client: Any = None

    def __enter__(self) -> DaskExecutor:
        """Start (or connect to) the cluster and open a client.

        Returns:
            This executor, ready for ``submit``.
        """
        from distributed import Client, LocalCluster

        if self.config.scheduler_address:
            self._client = Client(self.config.scheduler_address)  # type: ignore[no-untyped-call]
            return self

        cluster = LocalCluster(  # type: ignore[no-untyped-call]
            n_workers=self.config.n_processes,
            threads_per_worker=1,
            processes=True,
            env=_thread_env(self.config.threads_per_process),
        )
        try:
            client = Client(cluster)  # type: ignore[no-untyped-call]
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so the freshly
            # started worker processes must be shut down here.
            cluster.close()
            raise
        self._cluster, self._client = cluster, client
        return self

    def __exit__(self, *exc: object) -> None:
        """Tear down the client and cluster on context exit.

        The cluster is closed even when closing the client raises, and both
        handles are cleared so a repeated exit is a no-op.
        """
        client, cluster = self._client, self._cluster
        self._client = None
        self._cluster = None
        try:
            # Compare against None rather than relying on object truthiness.
            if client is not None:
                client.close()
        finally:
            if cluster is not None:
                cluster.close()

    def worker_env_sample(self) -> dict[str, str]:
        """Return one worker's BLAS/OpenMP environment (proves the cap is applied).

        Returns:
            The thread-cap variables as seen by one worker; unset variables
            are reported as empty strings.
        """
        keys = _thread_env(self.config.threads_per_process)
        # Client.run executes the callable on every worker and returns a
        # per-worker mapping; any single entry is representative.
        result = self._client.run(lambda: {k: os.environ.get(k, "") for k in keys})
        return cast(dict[str, str], result.popitem()[1])

    def submit(self, unit_of_work: UnitOfWork) -> Product:
        """Run ``solve_unit`` on a worker and return the resulting Product.

        Args:
            unit_of_work: The unit of work passed to ``solve_unit``.

        Returns:
            The Product computed by the worker (blocks until it is ready).
        """
        future = self._client.submit(solve_unit, unit_of_work, pure=False)
        return cast(Product, future.result())
|
|
File without changes
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""ODC THREDDS cache: fetch whole files and OPeNDAP-subset, into ./data/cache/."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import requests
|
|
9
|
+
import xarray as xr
|
|
10
|
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
|
11
|
+
|
|
12
|
+
CACHE = Path("data/cache")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ODCCache:
    """On-disk cache of ODC THREDDS downloads (``./data/cache/`` by default)."""

    def __init__(self, root: Path = CACHE) -> None:
        """Remember the cache directory and create it when missing.

        Args:
            root: The cache directory.
        """
        self.root = root
        root.mkdir(parents=True, exist_ok=True)

    def path_for(self, url: str) -> Path:
        """Map ``url`` to its deterministic location inside the cache.

        The file name is a short BLAKE2b digest of the full URL followed by
        the URL's last path component.
        """
        digest = hashlib.blake2b(url.encode(), digest_size=8).hexdigest()
        _, _, filename = url.rpartition("/")
        return self.root / f"{digest}_{filename}"

    @retry(stop=stop_after_attempt(4), wait=wait_exponential(multiplier=1, max=30))
    def fetch_file(self, url: str) -> Path:
        """Ensure ``url`` is present in the cache, downloading it when absent.

        Args:
            url: The file URL.

        Returns:
            The local cache path.
        """
        target = self.path_for(url)
        if not target.exists():
            with requests.get(url, stream=True, timeout=120) as response:
                response.raise_for_status()
                # Stream into a sibling ".part" file and rename at the end so a
                # truncated download never appears under the final name.
                partial = target.with_suffix(".part")
                with partial.open("wb") as sink:
                    for block in response.iter_content(1 << 20):
                        sink.write(block)
                partial.replace(target)
        return target

    def open_dodsC(self, opendap_url: str) -> xr.Dataset:
        """Open an OPeNDAP dataset through xarray without downloading it to the cache."""
        return xr.open_dataset(opendap_url)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Offline fixture data-source for deterministic CI (wraps the same interface as ODC)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
import dask.array as da
|
|
8
|
+
import numpy as np
|
|
9
|
+
import xarray as xr
|
|
10
|
+
|
|
11
|
+
from sverdrup.core.observations import DiagonalErrorModel, ObsWindow
|
|
12
|
+
from sverdrup.core.types import Field
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from sverdrup.core.grid import GridSpec
|
|
16
|
+
|
|
17
|
+
Range = tuple[float, float]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class FixtureSource:
    """A NetCDF-backed data source matching the ODC ``DataSource``/truth interface."""

    def __init__(
        self, obs_path: str, ref_path: str | None = None, noise: float = 0.01
    ) -> None:
        """Open the observation dataset and, when given, the reference dataset.

        Args:
            obs_path: Path to the along-track observation NetCDF.
            ref_path: Optional path to the gridded reference NetCDF (OSSE truth).
            noise: Error variance assigned to every observation in the
                diagonal error model.
        """
        self._obs = xr.open_dataset(obs_path)
        self._ref = xr.open_dataset(ref_path) if ref_path else None
        self._noise = noise

    def window(
        self, *, lon_range: Range, lat_range: Range, time_range: Range
    ) -> ObsWindow:
        """Select the observations inside a space-time box.

        Args:
            lon_range: Inclusive longitude bounds in degrees.
            lat_range: Inclusive latitude bounds in degrees.
            time_range: Inclusive time bounds in days.

        Returns:
            An ``ObsWindow`` over the selected observations; the SLA values
            are handed over wrapped in a dask array.
        """
        ds = self._obs
        in_lon = (ds.longitude >= lon_range[0]) & (ds.longitude <= lon_range[1])
        in_lat = (ds.latitude >= lat_range[0]) & (ds.latitude <= lat_range[1])
        in_time = (ds.time >= time_range[0]) & (ds.time <= time_range[1])
        selected = ds.where(in_lon & in_lat & in_time, drop=True)

        count = int(selected.sizes["t"])
        # Split the values into (at most) two chunks, never using a zero chunk size.
        chunk = max(1, count // 2)
        sla = da.from_array(selected.sla.values, chunks=chunk)  # type: ignore[no-untyped-call]
        errors = DiagonalErrorModel(np.full(count, self._noise))
        return ObsWindow.from_arrays(
            selected.longitude.values,
            selected.latitude.values,
            selected.time.values,
            sla,
            errors,
            mission=selected.mission.values,
        )

    def truth(self, time_days: float, grid: GridSpec) -> Field | None:
        """Interpolate the reference field onto the grid nodes at one time.

        Args:
            time_days: The output time in days.
            grid: The output grid.

        Returns:
            The reference field reshaped to ``grid.shape``, or ``None`` when no
            reference dataset was opened (the OSE case).
        """
        reference = self._ref
        if reference is None:
            return None
        at_time = reference.ssh.interp(time=time_days)
        node_lon, node_lat = grid._lonlat_nodes()
        # Interpolate point-wise along a shared "z" dimension, one point per node.
        on_nodes = at_time.interp(
            longitude=("z", node_lon.ravel()), latitude=("z", node_lat.ravel())
        )
        return np.asarray(on_nodes.values).reshape(grid.shape)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""OSSE NATL60 source: nadir obs (whole) + daily CJM165 reference clipped to the eval window."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from sverdrup.adapters.odc.download import ODCCache
|
|
6
|
+
from sverdrup.adapters.odc.fixtures import FixtureSource
|
|
7
|
+
|
|
8
|
+
WINDOW = ("2012-10-22", "2012-12-02") # 42-day eval window
|
|
9
|
+
OBS_URL = (
|
|
10
|
+
"https://tds.../2020a_SSH_mapping_NATL60/dc_obs/...tar.gz" # documented endpoint
|
|
11
|
+
)
|
|
12
|
+
REF_DAILY_URL = "https://tds.../NATL60-CJM165/...daily...nc"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Natl60Source(FixtureSource):
    """Phase-1 OSSE source. Until cached data is present, behaves as a FixtureSource."""

    def __init__(
        self, obs_path: str, ref_path: str, cache: ODCCache | None = None
    ) -> None:
        """Open the OSSE observations and daily reference via the fixture interface.

        Args:
            obs_path: Path to the nadir observation dataset.
            ref_path: Path to the clipped daily reference dataset.
            cache: Optional ODC cache; a default ``ODCCache`` is created when
                none is supplied.
        """
        super().__init__(obs_path, ref_path)
        self.cache = cache if cache else ODCCache()
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""OSE source: real along-track inputs; withheld CryoSat-2 as the independent eval signal."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from sverdrup.adapters.odc.fixtures import FixtureSource
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OseSource(FixtureSource):
    """Phase-1 OSE source. Truth is absent; CryoSat-2 is withheld for evaluation."""

    def __init__(self, obs_path: str) -> None:
        """Open the OSE along-track inputs with no reference dataset.

        Args:
            obs_path: Path to the along-track observation dataset.
        """
        super().__init__(obs_path, ref_path=None)

    def withheld(self) -> tuple[np.ndarray, np.ndarray]:
        """Extract the observations whose mission is ``"c2"`` (CryoSat-2).

        Returns:
            A ``((k, 3), (k,))`` tuple of ``(lon, lat, time)`` locations and SLA values.
        """
        is_c2 = self._obs.mission == "c2"
        cryosat = self._obs.where(is_c2, drop=True)
        columns = [cryosat.longitude.values, cryosat.latitude.values, cryosat.time.values]
        locations = np.column_stack(columns)
        return locations, np.asarray(cryosat.sla.values)
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""Local fsspec result-sink writing the persisted Product bundle + provenance (spec 5.8)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import fsspec # type: ignore[import-untyped]
|
|
9
|
+
import numpy as np
|
|
10
|
+
from fsspec import AbstractFileSystem
|
|
11
|
+
|
|
12
|
+
from sverdrup.core.grid import GridSpec
|
|
13
|
+
from sverdrup.core.product import PerTimeProduct, Product
|
|
14
|
+
from sverdrup.core.provenance import ProductProvenance, UncertaintyProvenance
|
|
15
|
+
from sverdrup.core.types import UncertaintyCapability
|
|
16
|
+
from sverdrup.distributions.persisted import PersistedDistribution, PersistedFields
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _prov_to_json(p: ProductProvenance) -> dict[str, Any]:
|
|
20
|
+
"""Serialise product provenance to a JSON-safe dict."""
|
|
21
|
+
return {
|
|
22
|
+
"method": p.method,
|
|
23
|
+
"params_key": p.params_key,
|
|
24
|
+
"seed": p.seed,
|
|
25
|
+
"split_id": p.split_id,
|
|
26
|
+
"code_version": p.code_version,
|
|
27
|
+
"native_capability": p.uncertainty.native_capability.name,
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _save_array(fs: AbstractFileSystem, path: str, arr: np.ndarray) -> None:
|
|
32
|
+
"""Write a numpy array to ``path`` on filesystem ``fs``."""
|
|
33
|
+
with fs.open(path, "wb") as f:
|
|
34
|
+
np.save(f, arr)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _load_array(fs: AbstractFileSystem, path: str) -> np.ndarray:
|
|
38
|
+
"""Read a numpy array from ``path`` on filesystem ``fs``."""
|
|
39
|
+
with fs.open(path, "rb") as f:
|
|
40
|
+
return np.asarray(np.load(f))
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class FsspecResultSink:
    """Persists a Product bundle to any fsspec URL as per-time arrays + a JSON manifest."""

    def write(self, product: Product, path: str) -> None:
        """Persist ``product`` under ``path`` (an fsspec URL).

        Each time step ``i`` gets a ``t{i}`` directory holding one ``.npy``
        file per array; a ``manifest.json`` at the root records the times, the
        run manifest and the per-time scalar metadata and provenance.

        Args:
            product: The persisted Product bundle.
            path: The destination fsspec URL (e.g. ``file://.../prod.zarr``).
        """
        fs, root = fsspec.core.url_to_fs(path)
        fs.makedirs(root, exist_ok=True)

        entries: list[dict[str, Any]] = []
        for index, step in enumerate(product.per_time):
            dist = step.base
            group = f"{root}/t{index}"
            fs.makedirs(group, exist_ok=True)

            named_arrays = (
                ("mean", dist.fields.mean),
                ("marginal_variance", dist.fields.marginal_variance),
                ("factor", dist.fields.factor),
                ("residual", dist.fields.residual),
                ("x", dist.grid.x),
                ("y", dist.grid.y),
            )
            for stem, data in named_arrays:
                _save_array(fs, f"{group}/{stem}.npy", data)

            entry: dict[str, Any] = {
                "time_days": step.time_days,
                "rank": dist.fields.rank,
                "captured_energy": dist.fields.captured_energy,
                "provenance": _prov_to_json(step.provenance),
            }
            entries.append(entry)

        # Assemble the manifest before opening the file so a failure while
        # building it does not leave an empty manifest behind.
        manifest = {
            "times": product.times(),
            "run": product.run_manifest,
            "per_time": entries,
        }
        with fs.open(f"{root}/manifest.json", "w") as out:
            json.dump(manifest, out)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def read_product(path: str) -> Product:
    """Rebuild a Product bundle from the layout written by ``FsspecResultSink``.

    Only what the manifest stores is restored: each per-time provenance gets an
    empty input manifest and an empty second ``UncertaintyProvenance``
    argument, and each ``PerTimeProduct`` is rebuilt with ``{}`` and ``None``
    for its third and fourth arguments.

    Args:
        path: The fsspec URL the product was written to.

    Returns:
        The reconstructed Product (persisted representation, not sample maps).
    """
    fs, root = fsspec.core.url_to_fs(path)
    with fs.open(f"{root}/manifest.json") as handle:
        manifest = json.load(handle)

    restored: list[PerTimeProduct] = []
    for index, entry in enumerate(manifest["per_time"]):
        group = f"{root}/t{index}"

        def load(stem: str, group: str = group) -> np.ndarray:
            """Read one named array from the current time-step directory."""
            return _load_array(fs, f"{group}/{stem}.npy")

        recorded = entry["provenance"]
        fields = PersistedFields(
            mean=load("mean"),
            marginal_variance=load("marginal_variance"),
            factor=load("factor"),
            residual=load("residual"),
            rank=entry["rank"],
            seed=recorded["seed"],
            captured_energy=entry["captured_energy"],
        )
        grid = GridSpec.lonlat(load("x"), load("y"))
        capability = UncertaintyCapability[recorded["native_capability"]]
        provenance = ProductProvenance(
            method=recorded["method"],
            params_key=recorded["params_key"],
            seed=recorded["seed"],
            split_id=recorded["split_id"],
            code_version=recorded["code_version"],
            input_manifest={},
            uncertainty=UncertaintyProvenance(capability, []),
        )
        time_days = entry["time_days"]
        dist = PersistedDistribution(grid, fields, provenance.uncertainty, time_days)
        restored.append(PerTimeProduct(time_days, dist, {}, None, provenance))
    return Product(per_time=restored, run_manifest=manifest["run"])
|
|
File without changes
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Run configuration value objects."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
from sverdrup.adapters.executor_dask import ExecutorConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
class RunConfig:
    """Immutable description of a single pipeline run (OSSE or OSE)."""

    # Run mode: "OSSE" | "OSE".
    mode: str
    # Name of the mapping method to run.
    method_name: str
    # Method parameters, keyed by parameter name.
    params: dict[str, float]
    # (min, max) bounds of the spatial and temporal window.
    lon_range: tuple[float, float]
    lat_range: tuple[float, float]
    time_range: tuple[float, float]
    # Times at which outputs are requested.
    output_times: list[float]
    # Output grid spacing in degrees.
    grid_resolution_deg: float = 0.25
    # Executor sizing; a default ExecutorConfig is built per instance.
    executor: ExecutorConfig = field(default_factory=ExecutorConfig)
    # Attribute used to split the observations.
    split_by: str = "mission"
    # NOTE(review): presumably the rank of the persisted low-rank factor — confirm.
    rank: int = 40
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
"""End-to-end pipeline wiring: source -> executor.solve -> evaluate -> sink (spec 7)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from sverdrup.adapters.executor_dask import DaskExecutor, ExecutorConfig
|
|
11
|
+
from sverdrup.adapters.storage_fsspec import FsspecResultSink
|
|
12
|
+
from sverdrup.application.splits import make_splits
|
|
13
|
+
from sverdrup.application.uow import UnitOfWork
|
|
14
|
+
from sverdrup.core.evaluation import ContextKey, EvalContext, Registry
|
|
15
|
+
from sverdrup.core.grid import GridSpec
|
|
16
|
+
from sverdrup.core.observations import DiagonalErrorModel, ObsWindow
|
|
17
|
+
from sverdrup.core.parameters import ConstantProvider
|
|
18
|
+
from sverdrup.core.product import Product
|
|
19
|
+
from sverdrup.core.seeding import derive_seed
|
|
20
|
+
from sverdrup.eval.accuracy import Accuracy
|
|
21
|
+
from sverdrup.eval.calibration import Calibration
|
|
22
|
+
from sverdrup.eval.groundtrack import GroundTrack
|
|
23
|
+
|
|
24
|
+
Range = tuple[float, float]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class PipelineInputs:
    """Everything one end-to-end pipeline run (OSSE or OSE) needs."""

    # Run mode; the pipeline branches on "OSSE" and "OSE".
    mode: str
    # Name of the mapping method, forwarded to the unit of work.
    method_name: str
    # Data source; the pipeline calls its ``window(...)`` and, in OSSE mode, ``truth(...)``.
    source: object
    # fsspec URL the resulting product is written to.
    out_url: str
    # Bounds of the observation window; lon/lat also span the output grid.
    lon_range: Range
    lat_range: Range
    time_range: Range
    # Output times; evaluation uses the first entry.
    output_times: list[float]
    # Method parameters, wrapped in a ConstantProvider by the pipeline.
    params: dict[str, float]
    # Output grid spacing in degrees.
    grid_resolution_deg: float = 1.0
    # Executor sizing; a default ExecutorConfig is built per instance.
    executor: ExecutorConfig = field(default_factory=ExecutorConfig)
    # Forwarded to the unit of work as ``rank``.
    rank: int = 20
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _grid(inp: PipelineInputs) -> GridSpec:
    """Construct the regular lon/lat output grid spanning the run's bounds."""
    step = inp.grid_resolution_deg

    def axis(bounds: Range) -> np.ndarray:
        """Return evenly spaced nodes from ``bounds[0]`` up to ``bounds[1]``."""
        # The tiny epsilon keeps the upper bound when it lands on a grid node.
        return np.arange(bounds[0], bounds[1] + 1e-9, step)

    return GridSpec.lonlat(axis(inp.lon_range), axis(inp.lat_range))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def run_pipeline(inp: PipelineInputs) -> tuple[Product, dict[str, Any]]:
    """Execute one run: window the source, solve on dask, persist, then evaluate.

    Args:
        inp: The pipeline inputs.

    Returns:
        The persisted Product and the evaluator score dictionary.
    """
    grid = _grid(inp)
    source = cast(Any, inp.source)
    window = source.window(
        lon_range=inp.lon_range, lat_range=inp.lat_range, time_range=inp.time_range
    )
    train_obs, eval_locs, withheld_vals = _prepare(inp, window)

    provider = ConstantProvider(inp.params)
    # The whole domain is solved as a single tile with a seed derived for it.
    tile_id = "tile0"
    seed = derive_seed(inp.method_name, provider.params_key(), tile_id, 0)
    unit = UnitOfWork(
        tile_id,
        inp.method_name,
        provider,
        "train",
        seed,
        inp.output_times,
        train_obs,
        grid,
        eval_locations=eval_locs,
        derived_names=["firstdifference"],
        rank=inp.rank,
    )

    with DaskExecutor(inp.executor) as executor:
        product = executor.submit(unit)

    # Persist first, then score the product that was just written.
    FsspecResultSink().write(product, inp.out_url)
    return product, _evaluate(inp, product, grid, eval_locs, withheld_vals)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _subset_obs(obs: ObsWindow, idx: np.ndarray) -> ObsWindow:
    """Build a new window holding only the observations selected by ``idx``.

    Coordinates, values, per-observation error variances and (when present)
    mission labels are all subset with the same indices.
    """
    coords = obs.coords()
    # Per-observation variances are the diagonal of the error matrix.
    variances = np.diag(obs.error_model.as_matrix(len(obs)))
    if obs.mission is None:
        picked_mission = None
    else:
        picked_mission = obs.mission[idx]
    lon, lat, time = (coords[idx, column] for column in range(3))
    return ObsWindow.from_arrays(
        lon,
        lat,
        time,
        obs.values()[idx],
        DiagonalErrorModel(variances[idx]),
        mission=picked_mission,
    )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _prepare(
    inp: PipelineInputs, obs: ObsWindow
) -> tuple[ObsWindow, np.ndarray | None, np.ndarray | None]:
    """Split the window into ``(train_obs, eval_locations, withheld_values)``.

    OSSE (and any window without mission labels) trains on every observation
    and returns no evaluation points, since truth supplies the evaluation. OSE
    locks the ``"c2"`` (CryoSat-2) mission out of training and returns its
    locations and values, so the evaluation signal is independent of the
    training data.
    """
    if inp.mode != "OSE" or obs.mission is None:
        return obs, None, None

    split = make_splits(obs, by="mission", locked_missions=["c2"])
    held_out = split.locked_test_idx
    train_obs = _subset_obs(obs, split.train_idx)

    eval_locs = obs.coords()[held_out].copy()
    # Every evaluation point is placed at the first output time.
    eval_locs[:, 2] = inp.output_times[0]
    return train_obs, eval_locs, obs.values()[held_out]
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _evaluate(
|
|
127
|
+
inp: PipelineInputs,
|
|
128
|
+
product: Product,
|
|
129
|
+
grid: GridSpec,
|
|
130
|
+
eval_locs: np.ndarray | None,
|
|
131
|
+
withheld_vals: np.ndarray | None,
|
|
132
|
+
) -> dict[str, Any]:
|
|
133
|
+
"""Assemble the evaluation context and run every applicable evaluator.
|
|
134
|
+
|
|
135
|
+
OSSE calibrates/scores against the gridded truth; OSE scores against the
|
|
136
|
+
withheld CryoSat-2 along-track at the exact eval-point predictions. The
|
|
137
|
+
evaluator spine is identical — only the source and context differ.
|
|
138
|
+
"""
|
|
139
|
+
pt = product.per_time[0]
|
|
140
|
+
base = pt.base
|
|
141
|
+
items: dict[ContextKey, object] = {
|
|
142
|
+
ContextKey.ORBIT_GEOMETRY: {"track_spacing_nodes": 4}
|
|
143
|
+
}
|
|
144
|
+
result: dict[str, np.ndarray] = {
|
|
145
|
+
"field": base.fields.mean,
|
|
146
|
+
"grid_mean": base.fields.mean,
|
|
147
|
+
}
|
|
148
|
+
if inp.mode == "OSSE":
|
|
149
|
+
truth = cast(Any, inp.source).truth(inp.output_times[0], grid)
|
|
150
|
+
truth = np.asarray(truth)
|
|
151
|
+
items[ContextKey.TRUTH] = {"field": truth}
|
|
152
|
+
items[ContextKey.WITHHELD_OBS] = {"values": truth.ravel()}
|
|
153
|
+
result["eval_mean"] = base.fields.mean.ravel()
|
|
154
|
+
result["eval_var"] = base.marginal_variance().ravel()
|
|
155
|
+
elif (
|
|
156
|
+
eval_locs is not None
|
|
157
|
+
and withheld_vals is not None
|
|
158
|
+
and pt.eval_points is not None
|
|
159
|
+
):
|
|
160
|
+
items[ContextKey.WITHHELD_OBS] = {"values": withheld_vals}
|
|
161
|
+
result["eval_mean"] = pt.eval_points.mean
|
|
162
|
+
result["eval_var"] = pt.eval_points.variance
|
|
163
|
+
ctx = EvalContext(items)
|
|
164
|
+
reg = Registry([Accuracy(), Calibration(), GroundTrack(track_wavenumber=4)])
|
|
165
|
+
scores: dict[str, Any] = dict(reg.run(result, ctx))
|
|
166
|
+
scores["context_keys"] = {k.name for k in ctx.keys()}
|
|
167
|
+
return scores
|