sverdrup 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. sverdrup/__init__.py +0 -0
  2. sverdrup/__main__.py +49 -0
  3. sverdrup/_version.py +24 -0
  4. sverdrup/adapters/__init__.py +0 -0
  5. sverdrup/adapters/executor_dask.py +74 -0
  6. sverdrup/adapters/odc/__init__.py +0 -0
  7. sverdrup/adapters/odc/download.py +56 -0
  8. sverdrup/adapters/odc/fixtures.py +87 -0
  9. sverdrup/adapters/odc/natl60.py +29 -0
  10. sverdrup/adapters/odc/ose.py +31 -0
  11. sverdrup/adapters/storage_fsspec.py +128 -0
  12. sverdrup/application/__init__.py +0 -0
  13. sverdrup/application/config.py +24 -0
  14. sverdrup/application/pipeline.py +167 -0
  15. sverdrup/application/solve.py +152 -0
  16. sverdrup/application/splits.py +78 -0
  17. sverdrup/application/uow.py +28 -0
  18. sverdrup/core/__init__.py +0 -0
  19. sverdrup/core/derived.py +48 -0
  20. sverdrup/core/distribution.py +56 -0
  21. sverdrup/core/evaluation.py +84 -0
  22. sverdrup/core/grid.py +149 -0
  23. sverdrup/core/method.py +34 -0
  24. sverdrup/core/observations.py +153 -0
  25. sverdrup/core/parameters.py +60 -0
  26. sverdrup/core/ports.py +40 -0
  27. sverdrup/core/product.py +41 -0
  28. sverdrup/core/provenance.py +59 -0
  29. sverdrup/core/seeding.py +26 -0
  30. sverdrup/core/types.py +36 -0
  31. sverdrup/derived/__init__.py +0 -0
  32. sverdrup/derived/area_average.py +19 -0
  33. sverdrup/derived/eke.py +17 -0
  34. sverdrup/derived/firstdifference.py +113 -0
  35. sverdrup/derived/transport.py +17 -0
  36. sverdrup/derived/velocity.py +17 -0
  37. sverdrup/distributions/__init__.py +0 -0
  38. sverdrup/distributions/adapters.py +128 -0
  39. sverdrup/distributions/ensemble.py +59 -0
  40. sverdrup/distributions/gaussian.py +66 -0
  41. sverdrup/distributions/persisted.py +129 -0
  42. sverdrup/eval/__init__.py +0 -0
  43. sverdrup/eval/accuracy.py +31 -0
  44. sverdrup/eval/calibration.py +70 -0
  45. sverdrup/eval/groundtrack.py +35 -0
  46. sverdrup/methods/__init__.py +0 -0
  47. sverdrup/methods/kernel.py +55 -0
  48. sverdrup/methods/oi.py +129 -0
  49. sverdrup/methods/registry.py +8 -0
  50. sverdrup/methods/solver.py +71 -0
  51. sverdrup/methods/trivial.py +65 -0
  52. sverdrup/py.typed +0 -0
  53. sverdrup-0.1.0.dist-info/METADATA +120 -0
  54. sverdrup-0.1.0.dist-info/RECORD +56 -0
  55. sverdrup-0.1.0.dist-info/WHEEL +4 -0
  56. sverdrup-0.1.0.dist-info/licenses/LICENSE +202 -0
sverdrup/__init__.py ADDED
File without changes
sverdrup/__main__.py ADDED
@@ -0,0 +1,49 @@
1
+ """Runnable entry point: ``python -m sverdrup <config.json>``."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ from sverdrup.adapters.executor_dask import ExecutorConfig
10
+ from sverdrup.adapters.odc.fixtures import FixtureSource
11
+ from sverdrup.application.pipeline import PipelineInputs, run_pipeline
12
+
13
+
14
def main(argv: list[str]) -> int:
    """Run a config-driven pipeline, or print usage when no config is given.

    The config file is a JSON object. The keys ``obs_path``, ``mode``,
    ``method``, ``out_url``, ``lon_range``, ``lat_range``, ``time_range``,
    ``output_times`` and ``params`` are required (a missing one raises
    ``KeyError``); ``ref_path``, ``grid_resolution_deg``, ``executor`` and
    ``rank`` are optional.

    Args:
        argv: The process argv (``argv[1]`` is the config path).

    Returns:
        Process exit code (0 on success, and also 0 after printing usage).
    """
    if len(argv) < 2:
        print("usage: python -m sverdrup <config.json>")
        return 0
    # JSON is UTF-8 by specification; decoding explicitly avoids depending on
    # the platform's locale encoding.
    cfg = json.loads(Path(argv[1]).read_text(encoding="utf-8"))
    src = FixtureSource(cfg["obs_path"], cfg.get("ref_path"))
    inp = PipelineInputs(
        mode=cfg["mode"],
        method_name=cfg["method"],
        source=src,
        out_url=cfg["out_url"],
        # JSON arrays arrive as lists; the pipeline inputs are declared as tuples.
        lon_range=tuple(cfg["lon_range"]),
        lat_range=tuple(cfg["lat_range"]),
        time_range=tuple(cfg["time_range"]),
        output_times=cfg["output_times"],
        params=cfg["params"],
        grid_resolution_deg=cfg.get("grid_resolution_deg", 1.0),
        executor=ExecutorConfig(**cfg.get("executor", {})),
        rank=cfg.get("rank", 20),
    )
    product, scores = run_pipeline(inp)
    # "context_keys" is bookkeeping added by the pipeline, not a score.
    reported = {k: v for k, v in scores.items() if k != "context_keys"}
    print(f"wrote {len(product.per_time)} time(s); scores={reported}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main(sys.argv))
sverdrup/_version.py ADDED
@@ -0,0 +1,24 @@
1
+ # file generated by vcs-versioning
2
+ # don't change, don't track in version control
3
+ from __future__ import annotations
4
+
5
+ __all__ = [
6
+ "__version__",
7
+ "__version_tuple__",
8
+ "version",
9
+ "version_tuple",
10
+ "__commit_id__",
11
+ "commit_id",
12
+ ]
13
+
14
+ version: str
15
+ __version__: str
16
+ __version_tuple__: tuple[int | str, ...]
17
+ version_tuple: tuple[int | str, ...]
18
+ commit_id: str | None
19
+ __commit_id__: str | None
20
+
21
+ __version__ = version = '0.1.0'
22
+ __version_tuple__ = version_tuple = (0, 1, 0)
23
+
24
+ __commit_id__ = commit_id = None
File without changes
@@ -0,0 +1,74 @@
1
+ """Executor adapter: dask.distributed LocalCluster with a per-run BLAS/OpenMP knob (spec 5.9)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from dataclasses import dataclass
7
+ from typing import Any, cast
8
+
9
+ from sverdrup.application.solve import solve_unit
10
+ from sverdrup.application.uow import UnitOfWork
11
+ from sverdrup.core.product import Product
12
+
13
+
14
@dataclass(frozen=True)
class ExecutorConfig:
    """Immutable sizing knobs for the dask executor.

    Attributes:
        n_processes: Number of worker processes for a locally started cluster.
        threads_per_process: Thread cap exported to each worker through the
            BLAS/OpenMP environment variables.
        scheduler_address: Address of an existing scheduler to connect to;
            ``None`` means a ``LocalCluster`` is started instead.
    """

    n_processes: int = 4
    threads_per_process: int = 1
    # None -> spin up a LocalCluster rather than connecting to a scheduler.
    scheduler_address: str | None = None
21
+
22
+
23
+ def _thread_env(threads: int) -> dict[str, str]:
24
+ """Return the BLAS/OpenMP thread-cap environment for one worker."""
25
+ t = str(threads)
26
+ return {"OMP_NUM_THREADS": t, "OPENBLAS_NUM_THREADS": t, "MKL_NUM_THREADS": t}
27
+
28
+
29
class DaskExecutor:
    """The sole Phase-1 executor adapter. Scaling out changes only scheduler_address.

    Use it as a context manager: the dask client (and, when no scheduler
    address is configured, a ``LocalCluster``) exists only between
    ``__enter__`` and ``__exit__``.
    """

    def __init__(self, config: ExecutorConfig) -> None:
        """Store config; the cluster/client are created on context entry.

        Args:
            config: The executor configuration.
        """
        self.config = config
        self._cluster: Any = None
        self._client: Any = None

    def __enter__(self) -> DaskExecutor:
        """Start (or connect to) the cluster and open a client.

        With a ``scheduler_address`` only a client is opened. Otherwise a
        process-based ``LocalCluster`` is started with ``n_processes`` workers,
        one dask thread per worker, and the BLAS/OpenMP cap passed through the
        worker environment.
        """
        # Imported lazily so importing this module does not require distributed.
        from distributed import Client, LocalCluster

        if self.config.scheduler_address:
            self._client = Client(self.config.scheduler_address)  # type: ignore[no-untyped-call]
            return self

        self._cluster = LocalCluster(  # type: ignore[no-untyped-call]
            n_workers=self.config.n_processes,
            threads_per_worker=1,
            processes=True,
            env=_thread_env(self.config.threads_per_process),
        )
        try:
            self._client = Client(self._cluster)  # type: ignore[no-untyped-call]
        except BaseException:
            # __exit__ is never called when __enter__ raises, so the freshly
            # started cluster must be shut down here or it would leak.
            self._cluster.close()
            self._cluster = None
            raise
        return self

    def __exit__(self, *exc: object) -> None:
        """Tear down the client and cluster on context exit.

        The cluster is closed even if closing the client raises, and both
        handles are cleared so a stale client cannot be used afterwards.
        """
        client, cluster = self._client, self._cluster
        self._client = None
        self._cluster = None
        try:
            if client is not None:
                client.close()
        finally:
            if cluster is not None:
                cluster.close()

    def _require_client(self) -> Any:
        """Return the open client, or raise if the executor has not been entered."""
        if self._client is None:
            raise RuntimeError(
                "DaskExecutor has no open client; use it inside a 'with' block"
            )
        return self._client

    def worker_env_sample(self) -> dict[str, str]:
        """Return one worker's BLAS/OpenMP environment (proves the cap is applied).

        Raises:
            RuntimeError: If called outside the ``with`` block.
        """
        keys = _thread_env(self.config.threads_per_process)
        client = self._require_client()
        # client.run returns one entry per worker; any single one is a valid sample.
        result = client.run(lambda: {k: os.environ.get(k, "") for k in keys})
        return cast(dict[str, str], result.popitem()[1])

    def submit(self, unit_of_work: UnitOfWork) -> Product:
        """Run ``solve_unit`` on a worker and return the resulting Product.

        Raises:
            RuntimeError: If called outside the ``with`` block.
        """
        client = self._require_client()
        # pure=False: each submission is treated as a distinct task by dask.
        future = client.submit(solve_unit, unit_of_work, pure=False)
        return cast(Product, future.result())
File without changes
@@ -0,0 +1,56 @@
1
+ """ODC THREDDS cache: fetch whole files and OPeNDAP-subset, into ./data/cache/."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ from pathlib import Path
7
+
8
+ import requests
9
+ import xarray as xr
10
+ from tenacity import retry, stop_after_attempt, wait_exponential
11
+
12
CACHE = Path("data/cache")


class ODCCache:
    """On-disk cache of ODC THREDDS downloads, rooted at ``./data/cache/`` by default."""

    def __init__(self, root: Path = CACHE) -> None:
        """Remember the cache directory and create it if it does not exist.

        Args:
            root: The cache directory.
        """
        self.root = root
        root.mkdir(parents=True, exist_ok=True)

    def path_for(self, url: str) -> Path:
        """Map ``url`` to its cache path: ``<8-byte blake2b hex of url>_<last path segment>``."""
        digest = hashlib.blake2b(url.encode(), digest_size=8).hexdigest()
        basename = url.rsplit("/", 1)[-1]
        return self.root / f"{digest}_{basename}"

    @retry(stop=stop_after_attempt(4), wait=wait_exponential(multiplier=1, max=30))
    def fetch_file(self, url: str) -> Path:
        """Download ``url`` into the cache unless a cached copy already exists.

        The body is streamed in 1 MiB chunks to a ``.part`` sibling and renamed
        onto the final path only after the whole body has been written, so a
        partially written file never sits at the final cache path. The call is
        retried (up to 4 attempts, exponential back-off capped at 30 s).

        Args:
            url: The file URL.

        Returns:
            The local cache path.
        """
        target = self.path_for(url)
        if not target.exists():
            with requests.get(url, stream=True, timeout=120) as response:
                response.raise_for_status()
                partial = target.with_suffix(".part")
                with partial.open("wb") as sink:
                    for block in response.iter_content(1 << 20):
                        sink.write(block)
                partial.replace(target)
        return target

    def open_dodsC(self, opendap_url: str) -> xr.Dataset:
        """Open an OPeNDAP dataset through xarray; nothing is written to the cache."""
        dataset = xr.open_dataset(opendap_url)
        return dataset
@@ -0,0 +1,87 @@
1
+ """Offline fixture data-source for deterministic CI (wraps the same interface as ODC)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import dask.array as da
8
+ import numpy as np
9
+ import xarray as xr
10
+
11
+ from sverdrup.core.observations import DiagonalErrorModel, ObsWindow
12
+ from sverdrup.core.types import Field
13
+
14
+ if TYPE_CHECKING:
15
+ from sverdrup.core.grid import GridSpec
16
+
17
Range = tuple[float, float]


class FixtureSource:
    """A NetCDF-backed data source matching the ODC ``DataSource``/truth interface."""

    def __init__(
        self, obs_path: str, ref_path: str | None = None, noise: float = 0.01
    ) -> None:
        """Open the observation dataset and, when given, the reference dataset.

        Args:
            obs_path: Path to the along-track observation NetCDF.
            ref_path: Optional path to the gridded reference NetCDF (OSSE truth).
            noise: Per-observation error variance for the diagonal error model.
        """
        self._noise = noise
        self._obs = xr.open_dataset(obs_path)
        self._ref = xr.open_dataset(ref_path) if ref_path else None

    def window(
        self, *, lon_range: Range, lat_range: Range, time_range: Range
    ) -> ObsWindow:
        """Return the observations inside the requested space-time box.

        Args:
            lon_range: Inclusive longitude bounds in degrees.
            lat_range: Inclusive latitude bounds in degrees.
            time_range: Inclusive time bounds in days.

        Returns:
            An ``ObsWindow`` whose SLA values are wrapped in a dask array; every
            observation gets the same error variance (``noise``).
        """
        ds = self._obs
        in_lon = (ds.longitude >= lon_range[0]) & (ds.longitude <= lon_range[1])
        in_lat = (ds.latitude >= lat_range[0]) & (ds.latitude <= lat_range[1])
        in_time = (ds.time >= time_range[0]) & (ds.time <= time_range[1])
        selected = ds.where(in_lon & in_lat & in_time, drop=True)

        count = int(selected.sizes["t"])
        # Roughly two chunks; chunk size never drops below 1 for tiny windows.
        chunk = max(1, count // 2)
        sla = da.from_array(selected.sla.values, chunks=chunk)  # type: ignore[no-untyped-call]
        errors = DiagonalErrorModel(np.full(count, self._noise))
        return ObsWindow.from_arrays(
            selected.longitude.values,
            selected.latitude.values,
            selected.time.values,
            sla,
            errors,
            mission=selected.mission.values,
        )

    def truth(self, time_days: float, grid: GridSpec) -> Field | None:
        """Return the reference field interpolated to grid nodes, or ``None`` for OSE.

        Args:
            time_days: The output time in days.
            grid: The output grid.

        Returns:
            The reference field reshaped to ``grid.shape``, or ``None`` when no
            reference dataset was opened.
        """
        reference = self._ref
        if reference is None:
            return None
        # Interpolate in time first, then pointwise at every grid node along a
        # shared flat "z" dimension.
        at_time = reference.ssh.interp(time=time_days)
        lon, lat = grid._lonlat_nodes()
        sampled = at_time.interp(
            longitude=("z", lon.ravel()), latitude=("z", lat.ravel())
        )
        return np.asarray(sampled.values).reshape(grid.shape)
@@ -0,0 +1,29 @@
1
+ """OSSE NATL60 source: nadir obs (whole) + daily CJM165 reference clipped to the eval window."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from sverdrup.adapters.odc.download import ODCCache
6
+ from sverdrup.adapters.odc.fixtures import FixtureSource
7
+
8
WINDOW = ("2012-10-22", "2012-12-02")  # 42-day eval window
OBS_URL = (
    "https://tds.../2020a_SSH_mapping_NATL60/dc_obs/...tar.gz"  # documented endpoint
)
REF_DAILY_URL = "https://tds.../NATL60-CJM165/...daily...nc"


class Natl60Source(FixtureSource):
    """Phase-1 OSSE source. Until cached data is present, behaves as a FixtureSource."""

    def __init__(
        self, obs_path: str, ref_path: str, cache: ODCCache | None = None
    ) -> None:
        """Open the OSSE obs + daily reference (delegating to the fixture interface).

        Args:
            obs_path: Path to the nadir observation dataset.
            ref_path: Path to the clipped daily reference dataset.
            cache: Optional ODC cache (created on demand if omitted).
        """
        super().__init__(obs_path, ref_path)
        # Test identity, not truthiness: a caller-supplied cache object must be
        # kept even if it happens to evaluate as falsy.
        self.cache = ODCCache() if cache is None else cache
@@ -0,0 +1,31 @@
1
+ """OSE source: real along-track inputs; withheld CryoSat-2 as the independent eval signal."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+ from sverdrup.adapters.odc.fixtures import FixtureSource
8
+
9
+
10
class OseSource(FixtureSource):
    """Phase-1 OSE source. Truth is absent; CryoSat-2 is withheld for evaluation."""

    def __init__(self, obs_path: str) -> None:
        """Open the OSE along-track inputs without any reference dataset.

        Args:
            obs_path: Path to the along-track observation dataset.
        """
        super().__init__(obs_path, ref_path=None)

    def withheld(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the observations whose ``mission`` equals ``"c2"``.

        Returns:
            ``(locations, values)``: locations stack ``(lon, lat, time)`` as
            columns, one row per observation; values are the matching SLA.
        """
        obs = self._obs
        cryosat = obs.where(obs.mission == "c2", drop=True)
        columns = (
            cryosat.longitude.values,
            cryosat.latitude.values,
            cryosat.time.values,
        )
        locations = np.column_stack(columns)
        values = np.asarray(cryosat.sla.values)
        return locations, values
@@ -0,0 +1,128 @@
1
+ """Local fsspec result-sink writing the persisted Product bundle + provenance (spec 5.8)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from typing import Any
7
+
8
+ import fsspec # type: ignore[import-untyped]
9
+ import numpy as np
10
+ from fsspec import AbstractFileSystem
11
+
12
+ from sverdrup.core.grid import GridSpec
13
+ from sverdrup.core.product import PerTimeProduct, Product
14
+ from sverdrup.core.provenance import ProductProvenance, UncertaintyProvenance
15
+ from sverdrup.core.types import UncertaintyCapability
16
+ from sverdrup.distributions.persisted import PersistedDistribution, PersistedFields
17
+
18
+
19
+ def _prov_to_json(p: ProductProvenance) -> dict[str, Any]:
20
+ """Serialise product provenance to a JSON-safe dict."""
21
+ return {
22
+ "method": p.method,
23
+ "params_key": p.params_key,
24
+ "seed": p.seed,
25
+ "split_id": p.split_id,
26
+ "code_version": p.code_version,
27
+ "native_capability": p.uncertainty.native_capability.name,
28
+ }
29
+
30
+
31
+ def _save_array(fs: AbstractFileSystem, path: str, arr: np.ndarray) -> None:
32
+ """Write a numpy array to ``path`` on filesystem ``fs``."""
33
+ with fs.open(path, "wb") as f:
34
+ np.save(f, arr)
35
+
36
+
37
+ def _load_array(fs: AbstractFileSystem, path: str) -> np.ndarray:
38
+ """Read a numpy array from ``path`` on filesystem ``fs``."""
39
+ with fs.open(path, "rb") as f:
40
+ return np.asarray(np.load(f))
41
+
42
+
43
class FsspecResultSink:
    """Persists a Product bundle to any fsspec URL as per-time arrays + a JSON manifest."""

    def write(self, product: Product, path: str) -> None:
        """Write ``product`` under ``path`` (an fsspec URL).

        Layout: one ``t<i>/`` directory per entry of ``product.per_time``
        holding ``mean``, ``marginal_variance``, ``factor``, ``residual``,
        ``x`` and ``y`` as ``.npy`` files, plus a top-level ``manifest.json``
        that is written last.

        Args:
            product: The persisted Product bundle.
            path: The destination fsspec URL (e.g. ``file://.../prod.zarr``).
        """
        fs, root = fsspec.core.url_to_fs(path)
        fs.makedirs(root, exist_ok=True)

        entries: list[dict[str, Any]] = []
        for index, step in enumerate(product.per_time):
            dist = step.base
            group = f"{root}/t{index}"
            fs.makedirs(group, exist_ok=True)
            named_arrays = (
                ("mean", dist.fields.mean),
                ("marginal_variance", dist.fields.marginal_variance),
                ("factor", dist.fields.factor),
                ("residual", dist.fields.residual),
                ("x", dist.grid.x),
                ("y", dist.grid.y),
            )
            for name, array in named_arrays:
                _save_array(fs, f"{group}/{name}.npy", array)
            # Scalars and provenance go into the manifest rather than .npy files.
            entry = {
                "time_days": step.time_days,
                "rank": dist.fields.rank,
                "captured_energy": dist.fields.captured_energy,
                "provenance": _prov_to_json(step.provenance),
            }
            entries.append(entry)

        manifest = {
            "times": product.times(),
            "run": product.run_manifest,
            "per_time": entries,
        }
        with fs.open(f"{root}/manifest.json", "w") as sink:
            json.dump(manifest, sink)
85
+
86
+
87
def read_product(path: str) -> Product:
    """Rebuild a Product bundle from the layout written by ``FsspecResultSink``.

    Only what the manifest and ``.npy`` files carry is restored: each
    ``PerTimeProduct`` gets an empty dict and ``None`` for its remaining
    positional slots, ``input_manifest`` comes back empty, and the uncertainty
    provenance is rebuilt from the capability name with an empty list.

    Args:
        path: The fsspec URL the product was written to.

    Returns:
        The reconstructed Product (persisted representation, not sample maps).
    """
    fs, root = fsspec.core.url_to_fs(path)
    with fs.open(f"{root}/manifest.json") as source:
        manifest = json.load(source)

    def load(group: str, name: str) -> np.ndarray:
        # Every array of a time step lives at <group>/<name>.npy.
        return _load_array(fs, f"{group}/{name}.npy")

    restored: list[PerTimeProduct] = []
    for index, entry in enumerate(manifest["per_time"]):
        group = f"{root}/t{index}"
        prov_entry = entry["provenance"]
        time_days = entry["time_days"]

        fields = PersistedFields(
            mean=load(group, "mean"),
            marginal_variance=load(group, "marginal_variance"),
            factor=load(group, "factor"),
            residual=load(group, "residual"),
            rank=entry["rank"],
            seed=prov_entry["seed"],
            captured_energy=entry["captured_energy"],
        )
        grid = GridSpec.lonlat(load(group, "x"), load(group, "y"))
        capability = UncertaintyCapability[prov_entry["native_capability"]]
        provenance = ProductProvenance(
            method=prov_entry["method"],
            params_key=prov_entry["params_key"],
            seed=prov_entry["seed"],
            split_id=prov_entry["split_id"],
            code_version=prov_entry["code_version"],
            input_manifest={},
            uncertainty=UncertaintyProvenance(capability, []),
        )
        distribution = PersistedDistribution(
            grid, fields, provenance.uncertainty, time_days
        )
        restored.append(PerTimeProduct(time_days, distribution, {}, None, provenance))

    return Product(per_time=restored, run_manifest=manifest["run"])
File without changes
@@ -0,0 +1,24 @@
1
+ """Run configuration value objects."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+
7
+ from sverdrup.adapters.executor_dask import ExecutorConfig
8
+
9
+
10
@dataclass(frozen=True)
class RunConfig:
    """Immutable description of a single pipeline run (OSSE or OSE).

    Attributes:
        mode: Run mode, ``"OSSE"`` or ``"OSE"``.
        method_name: Name of the mapping method.
        params: Method parameters by name.
        lon_range: Longitude bounds of the run.
        lat_range: Latitude bounds of the run.
        time_range: Time bounds of the run.
        output_times: Times at which output is requested.
        grid_resolution_deg: Output grid spacing in degrees.
        executor: Executor sizing; a default ``ExecutorConfig`` if omitted.
        split_by: Key used for splitting observations.
        rank: Rank setting of the run.
    """

    mode: str  # "OSSE" | "OSE"
    method_name: str
    params: dict[str, float]
    lon_range: tuple[float, float]
    lat_range: tuple[float, float]
    time_range: tuple[float, float]
    output_times: list[float]
    # NOTE(review): PipelineInputs defaults the same knobs to 1.0 and 20 —
    # confirm which pair of defaults is intended.
    grid_resolution_deg: float = 0.25
    executor: ExecutorConfig = field(default_factory=ExecutorConfig)
    split_by: str = "mission"
    rank: int = 40
@@ -0,0 +1,167 @@
1
+ """End-to-end pipeline wiring: source -> executor.solve -> evaluate -> sink (spec 7)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, cast
7
+
8
+ import numpy as np
9
+
10
+ from sverdrup.adapters.executor_dask import DaskExecutor, ExecutorConfig
11
+ from sverdrup.adapters.storage_fsspec import FsspecResultSink
12
+ from sverdrup.application.splits import make_splits
13
+ from sverdrup.application.uow import UnitOfWork
14
+ from sverdrup.core.evaluation import ContextKey, EvalContext, Registry
15
+ from sverdrup.core.grid import GridSpec
16
+ from sverdrup.core.observations import DiagonalErrorModel, ObsWindow
17
+ from sverdrup.core.parameters import ConstantProvider
18
+ from sverdrup.core.product import Product
19
+ from sverdrup.core.seeding import derive_seed
20
+ from sverdrup.eval.accuracy import Accuracy
21
+ from sverdrup.eval.calibration import Calibration
22
+ from sverdrup.eval.groundtrack import GroundTrack
23
+
24
Range = tuple[float, float]


@dataclass
class PipelineInputs:
    """Everything one end-to-end pipeline run (OSSE or OSE) needs.

    Attributes:
        mode: Run mode; the pipeline branches on ``"OSSE"`` and ``"OSE"``.
        method_name: Name of the mapping method to solve with.
        source: Data source; the pipeline calls its ``window`` method and, in
            OSSE mode, its ``truth`` method.
        out_url: fsspec URL the product is written to.
        lon_range: Longitude bounds of the window and output grid.
        lat_range: Latitude bounds of the window and output grid.
        time_range: Time bounds of the observation window.
        output_times: Output times; evaluation uses the first one.
        params: Method parameters by name.
        grid_resolution_deg: Output grid spacing in degrees.
        executor: Executor sizing; a default ``ExecutorConfig`` if omitted.
        rank: Rank forwarded to the unit of work.
    """

    mode: str
    method_name: str
    # Typed as ``object``; the pipeline casts it to ``Any`` before calling it.
    source: object
    out_url: str
    lon_range: Range
    lat_range: Range
    time_range: Range
    output_times: list[float]
    params: dict[str, float]
    grid_resolution_deg: float = 1.0
    executor: ExecutorConfig = field(default_factory=ExecutorConfig)
    rank: int = 20
43
+
44
+
45
def _grid(inp: PipelineInputs) -> GridSpec:
    """Build the regular lon/lat output grid spanning the run's bounds."""

    def axis(bounds: tuple[float, float]) -> np.ndarray:
        # arange excludes its stop value; the 1e-9 pad keeps the upper bound in
        # the axis when it lands exactly on a grid step.
        return np.arange(bounds[0], bounds[1] + 1e-9, inp.grid_resolution_deg)

    longitudes = axis(inp.lon_range)
    latitudes = axis(inp.lat_range)
    return GridSpec.lonlat(longitudes, latitudes)
50
+
51
+
52
def run_pipeline(inp: PipelineInputs) -> tuple[Product, dict[str, Any]]:
    """Run source -> dask solve -> sink -> evaluate and return ``(product, scores)``.

    The whole domain is solved as a single unit of work (``"tile0"``) on the
    ``"train"`` split; the product is written to ``inp.out_url`` before it is
    evaluated.

    Args:
        inp: The pipeline inputs.

    Returns:
        The persisted Product and the evaluator score dictionary.
    """
    tile_id = "tile0"
    grid = _grid(inp)

    source = cast(Any, inp.source)
    window = source.window(
        lon_range=inp.lon_range, lat_range=inp.lat_range, time_range=inp.time_range
    )
    train_obs, eval_locs, withheld_vals = _prepare(inp, window)

    provider = ConstantProvider(inp.params)
    # The seed is derived from method, parameter key and tile.
    seed = derive_seed(inp.method_name, provider.params_key(), tile_id, 0)
    unit = UnitOfWork(
        tile_id,
        inp.method_name,
        provider,
        "train",
        seed,
        inp.output_times,
        train_obs,
        grid,
        eval_locations=eval_locs,
        derived_names=["firstdifference"],
        rank=inp.rank,
    )

    # The executor only needs to live for the duration of the solve.
    with DaskExecutor(inp.executor) as executor:
        product = executor.submit(unit)

    sink = FsspecResultSink()
    sink.write(product, inp.out_url)
    return product, _evaluate(inp, product, grid, eval_locs, withheld_vals)
89
+
90
+
91
def _subset_obs(obs: ObsWindow, idx: np.ndarray) -> ObsWindow:
    """Build a new ``ObsWindow`` from the rows of ``obs`` selected by ``idx``.

    Coordinates, values, error variances and (when present) mission labels are
    all indexed with the same ``idx``.
    """
    xyz = obs.coords()
    # The variances are read off the diagonal of the full error matrix.
    # NOTE(review): this builds the whole matrix for len(obs) observations —
    # confirm whether the error model exposes its variances directly.
    variances = np.diag(obs.error_model.as_matrix(len(obs)))
    labels = None if obs.mission is None else obs.mission[idx]
    lon, lat, time = xyz[idx, 0], xyz[idx, 1], xyz[idx, 2]
    values = obs.values()[idx]
    errors = DiagonalErrorModel(variances[idx])
    return ObsWindow.from_arrays(lon, lat, time, values, errors, mission=labels)
104
+
105
+
106
def _prepare(
    inp: PipelineInputs, obs: ObsWindow
) -> tuple[ObsWindow, np.ndarray | None, np.ndarray | None]:
    """Split the window into ``(train_obs, eval_locations, withheld_values)``.

    Outside OSE mode, or when the window has no mission labels, all
    observations are used for training and both evaluation slots are ``None``.
    In OSE mode the ``"c2"`` (CryoSat-2) mission is locked out of training; its
    coordinates, with the time column overwritten by the first output time,
    and its values are returned for evaluation.
    """
    if inp.mode != "OSE" or obs.mission is None:
        return obs, None, None

    split = make_splits(obs, by="mission", locked_missions=["c2"])
    training = _subset_obs(obs, split.train_idx)

    # Copy before editing so the window's own coordinates are left untouched.
    locations = obs.coords()[split.locked_test_idx].copy()
    locations[:, 2] = inp.output_times[0]
    held_out = obs.values()[split.locked_test_idx]
    return training, locations, held_out
124
+
125
+
126
def _evaluate(
    inp: PipelineInputs,
    product: Product,
    grid: GridSpec,
    eval_locs: np.ndarray | None,
    withheld_vals: np.ndarray | None,
) -> dict[str, Any]:
    """Assemble the evaluation context and run every applicable evaluator.

    OSSE calibrates/scores against the gridded truth; OSE scores against the
    withheld CryoSat-2 along-track at the exact eval-point predictions. The
    evaluator spine is identical — only the source and context differ. Only
    the first entry of ``product.per_time`` is evaluated.

    When the evaluation signal is missing (OSSE source without a reference
    field, or OSE without withheld data / eval points), no truth or
    withheld-observation context is added, and the returned ``context_keys``
    shows which keys were available.

    Args:
        inp: The pipeline inputs (mode, source and output times are used).
        product: The solved product.
        grid: The output grid, passed to the source's ``truth`` in OSSE mode.
        eval_locs: OSE evaluation locations, or ``None``.
        withheld_vals: OSE withheld values, or ``None``.

    Returns:
        The evaluator scores plus a ``"context_keys"`` set naming the context
        keys that were supplied.
    """
    pt = product.per_time[0]
    base = pt.base
    items: dict[ContextKey, object] = {
        ContextKey.ORBIT_GEOMETRY: {"track_spacing_nodes": 4}
    }
    result: dict[str, np.ndarray] = {
        "field": base.fields.mean,
        "grid_mean": base.fields.mean,
    }
    if inp.mode == "OSSE":
        truth = cast(Any, inp.source).truth(inp.output_times[0], grid)
        # The source returns None when it has no reference dataset; wrapping
        # that in np.asarray would pass a 0-d object array off as truth.
        if truth is not None:
            truth = np.asarray(truth)
            items[ContextKey.TRUTH] = {"field": truth}
            # In OSSE every grid node doubles as a withheld observation.
            items[ContextKey.WITHHELD_OBS] = {"values": truth.ravel()}
            result["eval_mean"] = base.fields.mean.ravel()
            result["eval_var"] = base.marginal_variance().ravel()
    elif (
        eval_locs is not None
        and withheld_vals is not None
        and pt.eval_points is not None
    ):
        items[ContextKey.WITHHELD_OBS] = {"values": withheld_vals}
        result["eval_mean"] = pt.eval_points.mean
        result["eval_var"] = pt.eval_points.variance
    ctx = EvalContext(items)
    reg = Registry([Accuracy(), Calibration(), GroundTrack(track_wavenumber=4)])
    scores: dict[str, Any] = dict(reg.run(result, ctx))
    scores["context_keys"] = {k.name for k in ctx.keys()}
    return scores