mxalign 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxalign/__init__.py +36 -0
- mxalign/accessors/__init__.py +7 -0
- mxalign/accessors/space.py +205 -0
- mxalign/accessors/time.py +180 -0
- mxalign/align/__init__.py +7 -0
- mxalign/align/nans.py +72 -0
- mxalign/align/space.py +21 -0
- mxalign/align/time.py +62 -0
- mxalign/cli.py +157 -0
- mxalign/interpolations/__init__.py +9 -0
- mxalign/interpolations/base.py +29 -0
- mxalign/interpolations/delaunay.py +218 -0
- mxalign/interpolations/interpolate.py +29 -0
- mxalign/interpolations/registry.py +17 -0
- mxalign/interpolations/xarray.py +63 -0
- mxalign/loaders/__init__.py +11 -0
- mxalign/loaders/anemoi_datasets.py +92 -0
- mxalign/loaders/anemoi_inference.py +103 -0
- mxalign/loaders/base.py +103 -0
- mxalign/loaders/harp_obstable.py +81 -0
- mxalign/loaders/loader.py +8 -0
- mxalign/loaders/registry.py +17 -0
- mxalign/properties/__init__.py +0 -0
- mxalign/properties/properties.py +25 -0
- mxalign/properties/specs.py +54 -0
- mxalign/properties/utils.py +43 -0
- mxalign/properties/validation.py +48 -0
- mxalign/runner.py +167 -0
- mxalign/transformations/__init__.py +7 -0
- mxalign/transformations/base.py +38 -0
- mxalign/transformations/external.py +34 -0
- mxalign/transformations/registry.py +20 -0
- mxalign/transformations/transform.py +28 -0
- mxalign/utils/config.py +55 -0
- mxalign/utils/dates.py +76 -0
- mxalign/utils/projections.py +104 -0
- mxalign/utils/save.py +62 -0
- mxalign/verification.py +57 -0
- mxalign-0.1.0.dist-info/METADATA +136 -0
- mxalign-0.1.0.dist-info/RECORD +43 -0
- mxalign-0.1.0.dist-info/WHEEL +4 -0
- mxalign-0.1.0.dist-info/entry_points.txt +2 -0
- mxalign-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
import xarray as xr
|
|
3
|
+
|
|
4
|
+
from .registry import register_loader
|
|
5
|
+
from ..properties.properties import Space, Time, Uncertainty
|
|
6
|
+
from .base import BaseLoader
|
|
7
|
+
|
|
8
|
+
# Default keyword arguments used when opening NetCDF output with
# ``xr.open_mfdataset``; user-supplied loader kwargs take precedence.
DEFAULTS_NETCDF = dict(chunks="auto", engine="h5netcdf", parallel=True)

# Default keyword arguments used when opening Zarr stores; user-supplied
# loader kwargs take precedence. ``storage_options`` is handed to the storage
# layer (presumably fsspec, where ``anon=True`` means unauthenticated access).
DEFAULTS_ZARR = dict(
    chunks="auto",
    storage_options=dict(anon=True),
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@register_loader
class AnemoiInferenceLoader(BaseLoader):
    """Loader for forecast output written by anemoi-inference.

    ``files`` may be a single path (``str`` or ``Path``) or a list of paths.
    The format is chosen from the suffix of the (first) path: ``.zarr`` is
    treated as a Zarr store, anything else as NetCDF. Loader kwargs given by
    the user take precedence over ``DEFAULTS_ZARR`` / ``DEFAULTS_NETCDF``.
    """

    name = "anemoi-inference"

    space = Space.GRID
    time = Time.FORECAST
    uncertainty = Uncertainty.DETERMINISTIC

    def _load(self):
        """Open ``self.files`` and return the forecast dataset.

        Raises:
            ValueError: if an empty list of files is given.
        """
        kwargs = self.kwargs.copy()

        if isinstance(self.files, (str, Path)):
            if Path(self.files).suffix.lower() == ".zarr":
                # A single Zarr store is opened directly.
                for key, value in DEFAULTS_ZARR.items():
                    kwargs.setdefault(key, value)
                return _open_zarr(self.files, **kwargs)
            # A single non-Zarr path is read as a one-element NetCDF list.
            files = [self.files]
        else:
            files = list(self.files)
            if not files:
                # Without this guard the first-file probe fails with a bare
                # IndexError.
                raise ValueError(f"{self.name}: no files to load")
            if Path(files[0]).suffix.lower() == ".zarr":
                # Several Zarr stores are combined via open_mfdataset, which
                # needs the engine spelled out explicitly.
                for key, value in DEFAULTS_ZARR.items():
                    kwargs.setdefault(key, value)
                kwargs["engine"] = "zarr"
                return _open_mf_dataset(files, **kwargs)

        for key, value in DEFAULTS_NETCDF.items():
            kwargs.setdefault(key, value)
        return _open_mf_dataset(files, **kwargs)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _open_mf_dataset(files, **kwargs):
    """Open a list of anemoi-inference files/stores as one forecast dataset.

    Lead times are derived from the ``time`` coordinate of the first file
    (offsets from its first step) and attached along the ``time`` dimension
    of the combined dataset, so every file must have the same number of
    steps. Each file is turned into one ``reference_time`` entry by
    ``_preprocess``. The ``values`` dimension is renamed to ``grid_index``
    and ``time`` is swapped for ``lead_time``.

    Args:
        files: list of paths/stores accepted by ``xr.open_mfdataset``.
        **kwargs: forwarded to ``xr.open_mfdataset``. It must contain
            ``engine`` and ``chunks``. Those, plus ``storage_options`` when
            present, are also used for the probe read of the first file.
    """
    probe_kwargs = {"engine": kwargs["engine"], "chunks": kwargs["chunks"]}
    if "storage_options" in kwargs:
        # Remote stores need the same access options for the probe read as
        # for the full open below.
        probe_kwargs["storage_options"] = kwargs["storage_options"]

    # Only the time axis is needed from the probe; close it afterwards so the
    # file handle is not leaked.
    with xr.open_dataset(files[0], **probe_kwargs) as first:
        times = first["time"].values
    lead_times = times - times[0]

    ds = xr.open_mfdataset(files, preprocess=_preprocess, **kwargs)

    ds_out = (
        ds.assign_coords({"lead_time": ("time", lead_times)})
        .rename_dims({"values": "grid_index"})
        .swap_dims({"time": "lead_time"})
    )

    return ds_out
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _open_zarr(files, **kwargs):
    """Open one anemoi-inference Zarr store as a forecast dataset.

    The store's ``time`` values become lead times (offsets from its first
    step). ``_preprocess`` adds the ``reference_time`` dimension. Then
    ``values`` is renamed to ``grid_index`` and ``time`` is swapped for
    ``lead_time``.

    Args:
        files: path/URL of the Zarr store.
        **kwargs: forwarded to ``xr.open_zarr``.
    """
    raw = xr.open_zarr(files, **kwargs)
    valid = raw["time"].values
    offsets = valid - valid[0]

    forecast = _preprocess(raw)
    forecast = forecast.assign_coords({"lead_time": ("time", offsets)})
    forecast = forecast.rename_dims({"values": "grid_index"})
    return forecast.swap_dims({"time": "lead_time"})
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _preprocess(ds):
    """Turn one anemoi-inference run into a single-``reference_time`` slice.

    ``longitude``/``latitude`` are promoted to coordinates. A new length-1
    ``reference_time`` dimension is added, labelled with the run's first
    ``time`` value. The ``time`` variable is dropped, while the ``time``
    dimension stays on the data variables.
    """
    first_valid_time = ds["time"].values[0]
    out = ds.set_coords(["longitude", "latitude"])
    out = out.expand_dims("reference_time")
    out = out.assign_coords(
        {"reference_time": ("reference_time", [first_valid_time])}
    )
    return out.drop_vars("time")
|
mxalign/loaders/base.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from .registry import register_loader
|
|
4
|
+
from ..properties.properties import Properties, Space, Time, Uncertainty
|
|
5
|
+
from ..properties.validation import validate_dataset
|
|
6
|
+
from ..properties.utils import properties_to_attrs
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseLoader(ABC):
    """Base class for all loaders.

    Subclasses implement :meth:`_load` to open their files as a dataset and
    declare ``space``/``time``/``uncertainty`` as class attributes, or
    override :meth:`_get_properties`. :meth:`load` then:

    1. selects the requested variables;
    2. validates the dataset against the properties;
    3. records the properties in ``ds.attrs["properties"]``;
    4. optionally attaches a grid mapping;
    5. computes all coordinates.
    """

    name: str = "base"

    space: Space | None = None
    time: Time | None = None
    uncertainty: Uncertainty | None = None

    def __init__(self, files, variables=None, grid_mapping=None, **kwargs):
        """
        Args:
            files: path(s) handed to :meth:`_load`.
            variables: one variable name or an iterable of names to keep.
                ``None`` (or empty) keeps every variable.
            grid_mapping: when truthy, passed to the ``space`` accessor's
                ``add_crs`` and ``add_grid_mapping``.
            **kwargs: loader-specific options, stored as ``self.kwargs``.
        """
        self.files = files
        if isinstance(variables, str):
            variables = [variables]
        elif variables is not None:
            # Tuples (and generators) are hashable, so xarray would look them
            # up as a single variable name; a list selects several variables.
            variables = list(variables)
        self.variables = variables
        self.grid_mapping = grid_mapping
        self.kwargs = kwargs

    def load(self):
        """Load, select, validate and annotate the dataset.

        Raises:
            ValueError: if the dataset does not satisfy its declared
                properties (raised by ``validate_dataset``).
        """
        ds = self._load()
        if self.variables:
            ds = self._select_variables(ds)

        properties = self._get_properties(ds)
        validate_dataset(ds, properties)

        ds.attrs["properties"] = properties_to_attrs(properties)

        if self.grid_mapping:
            ds = self._add_grid_mapping(ds)

        # Make sure all the coordinates are loaded. Iterate over a snapshot of
        # the names because the loop assigns back into ``ds``.
        for coord in list(ds.coords):
            ds[coord] = ds[coord].compute()

        return ds

    @abstractmethod
    def _load(self):
        """Open ``self.files`` and return the raw dataset."""
        ...

    def _select_variables(self, ds):
        """Keep only ``self.variables`` (a list of names)."""
        return ds[self.variables]

    def _add_grid_mapping(self, ds):
        """Attach the CRS and grid-mapping metadata via the ``space`` accessor."""
        ds = ds.space.add_crs(self.grid_mapping)
        ds = ds.space.add_grid_mapping(self.grid_mapping)
        return ds

    def _get_properties(self, ds):
        """Return the class-level properties; ``ds`` is ignored here."""
        properties = Properties(
            space=self.space, time=self.time, uncertainty=self.uncertainty
        )
        return properties
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@register_loader
class MxAlignLoader(BaseLoader):
    """Loader for datasets already stored in mxalign's own layout.

    The properties are inferred from the dimension names in
    :meth:`_get_properties`, so the class-level ``space``/``time``/
    ``uncertainty`` are ``None``.
    """

    name = "mxalign"

    space = None
    time = None
    uncertainty = None

    def _load(self):
        """Open the files with ``xr.open_mfdataset``.

        A ``code`` (station) dimension is renamed to ``point_index``, and
        ``valid_time``/``point_index`` are moved to the front.
        """
        import xarray as xr

        files = [self.files] if isinstance(self.files, str) else self.files

        # ``chunks="auto"`` is only a default. A user-supplied ``chunks``
        # overrides it instead of raising a duplicate-keyword TypeError.
        kwargs = {"chunks": "auto", **self.kwargs}
        ds = xr.open_mfdataset(files, **kwargs)
        if "code" in ds.dims:
            # ``...`` keeps any additional dims (e.g. ``member``) at the end;
            # without it ``transpose`` rejects datasets with extra dims.
            ds = ds.rename_dims({"code": "point_index"}).transpose(
                "valid_time", "point_index", ...
            )
        return ds

    def _get_properties(self, ds):
        """Infer the dataset properties from its dimension names.

        Raises:
            ValueError: if neither the temporal nor the spatial layout can be
                recognised.
        """
        if "reference_time" in ds.dims and "lead_time" in ds.dims:
            time = Time.FORECAST
        elif "valid_time" in ds.dims:
            time = Time.OBSERVATION
        else:
            raise ValueError("Unknown temporal dimensions")

        if "grid_index" in ds.dims or "xc" in ds.dims or "latitude" in ds.dims:
            space = Space.GRID
        elif "point_index" in ds.dims:
            space = Space.POINT
        else:
            raise ValueError("Unknown spatial dimensions")

        if "member" in ds.dims:
            uncertainty = Uncertainty.ENSEMBLE
        elif "quantile" in ds.dims:
            uncertainty = Uncertainty.QUANTILE
        else:
            uncertainty = Uncertainty.DETERMINISTIC

        return Properties(space=space, time=time, uncertainty=uncertainty)
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import sqlite3
|
|
2
|
+
import pandas as pd
|
|
3
|
+
|
|
4
|
+
from .registry import register_loader
|
|
5
|
+
from ..properties.properties import Space, Time, Uncertainty
|
|
6
|
+
from .base import BaseLoader
|
|
7
|
+
|
|
8
|
+
# mxalign coordinate name -> column name in the HARP ``SYNOP`` table. The
# values are excluded when all variables are read from the table.
COORDS = dict(
    longitude="lon",
    latitude="lat",
    valid_time="validdate",
    code="SID",
    altitude="elev",
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@register_loader
class ObstableLoader(BaseLoader):
    """Loader for HARP observation tables stored in a SQLite file.

    Reads the ``SYNOP`` table and returns a point dataset with dims
    ``(valid_time, point_index)``. Per-station ``longitude``, ``latitude`` and
    ``altitude`` are attached as coordinates along ``point_index``.
    """

    name = "harp-obstable"

    space = Space.POINT
    time = Time.OBSERVATION
    uncertainty = Uncertainty.DETERMINISTIC

    def _load(self):
        """Open the SQLite file, read the observations and close the connection.

        Raises:
            ValueError: if an empty list of files is given.
            NotImplementedError: if more than one file is given.
        """
        conn = sqlite3.connect(self._sqlite_path())
        try:
            return self._read_synop(conn)
        finally:
            # pandas reads everything eagerly, so the connection can be closed
            # before the dataset is returned.
            conn.close()

    def _sqlite_path(self):
        """Return the single SQLite path held by ``self.files`` (a path or a 1-element list)."""
        if not isinstance(self.files, (list, tuple)):
            return self.files
        if len(self.files) > 1:
            raise NotImplementedError(
                "Reading from multiple SQLite-files not implemented"
            )
        if not self.files:
            raise ValueError(f"{self.name}: no SQLite file given")
        return self.files[0]

    def _read_synop(self, conn):
        """Build the observation dataset from the ``SYNOP`` table behind ``conn``."""
        if self.variables:
            variables = self.variables
        else:
            # No (or an empty) selection: take every column that is not one of
            # the coordinate/index columns listed in COORDS.
            coord_columns = set(COORDS.values())
            columns = pd.read_sql_query("SELECT * FROM SYNOP LIMIT 0", conn).columns
            variables = [col for col in columns if col not in coord_columns]

        # One row per station. Every non-grouped column is aggregated, so the
        # values SQLite returns for a station do not depend on row order.
        codes = pd.read_sql(
            "SELECT SID AS code, MIN(lat) AS latitude, MIN(lon) AS longitude, "
            "MIN(elev) AS altitude FROM SYNOP GROUP BY SID",
            conn,
            index_col="code",
        ).to_xarray()

        # NOTE(review): column names are interpolated into the SQL text
        # because identifiers cannot be bound as parameters. They come from
        # the table schema or the user's configuration.
        query = (
            "SELECT SID AS code, validdate AS valid_time, "
            f"{', '.join(variables)} FROM SYNOP"
        )
        df = pd.read_sql(
            query,
            conn,
            index_col=["code", "valid_time"],
            parse_dates={"valid_time": {"unit": "s"}},
        )

        ds = df.to_xarray()
        station = ds["code"]
        ds = ds.assign_coords(
            longitude=("code", codes["longitude"].sel(code=station).values),
            latitude=("code", codes["latitude"].sel(code=station).values),
            altitude=("code", codes["altitude"].sel(code=station).values),
        )

        return ds.rename_dims({"code": "point_index"}).transpose(
            "valid_time", "point_index"
        )
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
_LOADERS = {}
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def register_loader(cls):
    """Class decorator that records ``cls`` in the registry under ``cls.name``.

    ``cls`` is returned unchanged. Registering another class with the same
    ``name`` replaces the earlier entry.
    """
    _LOADERS.update({cls.name: cls})
    return cls
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def available_loaders():
    """Return a new list of the registered loader names, in first-registration order."""
    return [*_LOADERS]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_loader(name):
    """Return the loader class registered under ``name``.

    Raises:
        ValueError: if no loader with that name is registered. The message
            lists the registered names.
    """
    try:
        return _LOADERS[name]
    except KeyError:
        # ``from None`` drops the internal KeyError from the traceback; the
        # message below already says everything useful.
        raise ValueError(
            f"Unknown loader: {name}. "
            f"Available loaders: {', '.join(sorted(_LOADERS))}"
        ) from None
|
|
File without changes
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Space(str, Enum):
    """Spatial layout of a dataset.

    ``GRID`` is for gridded fields and ``POINT`` for values at individual
    locations (e.g. observation stations). As a ``str`` mixin, each member
    compares equal to its string value.
    """

    GRID = "grid"
    POINT = "point"
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Time(str, Enum):
    """Temporal layout of a dataset.

    ``FORECAST`` data is indexed by reference and lead time. ``OBSERVATION``
    data is indexed by valid time. Each member compares equal to its string
    value.
    """

    FORECAST = "forecast"
    OBSERVATION = "observation"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Uncertainty(str, Enum):
    """How (if at all) a dataset represents uncertainty.

    ``DETERMINISTIC`` is a single value, ``ENSEMBLE`` has one value per
    member and ``QUANTILE`` one value per quantile. Each member compares
    equal to its string value.
    """

    DETERMINISTIC = "deterministic"
    ENSEMBLE = "ensemble"
    QUANTILE = "quantile"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True)
class Properties:
    """Immutable description of a dataset's space, time and uncertainty layout.

    Datasets carry it as plain strings in ``ds.attrs["properties"]``.
    ``uncertainty`` defaults to deterministic.
    """

    space: Space
    time: Time
    uncertainty: Uncertainty = Uncertainty.DETERMINISTIC
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Callable
|
|
3
|
+
from .properties import Space, Time, Uncertainty
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class PropertySpec:
    """Validation requirements attached to one property value.

    All fields default to fresh empty containers, so an argument-less
    ``PropertySpec()`` imposes no requirements.
    """

    # Alternative sets of dimension names; a dataset must contain one of them.
    dim_variants: list[set[str]] = field(default_factory=list)
    # Coordinate names that must all be present.
    required_coords: set[str] = field(default_factory=set)
    # Dimensions / coordinates that may appear. They are not consulted by the
    # checks in ``validation.py``.
    optional_dims: set[str] = field(default_factory=set)
    optional_coords: set[str] = field(default_factory=set)
    # Extra validation callables; also not invoked by ``validation.py``.
    validators: list[Callable] = field(default_factory=list)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Requirements per spatial layout. A dataset must contain one of the
# ``dim_variants`` dimension sets and every ``required_coords`` coordinate.
SPACE_SPECS = {
    Space.GRID: PropertySpec(
        required_coords={"longitude", "latitude"},
        dim_variants=[
            {"xc", "yc"},
            {"grid_index"},
            {"longitude", "latitude"},
        ],
        optional_dims={"member"},
        optional_coords={"xc", "yc"},
    ),
    Space.POINT: PropertySpec(
        required_coords={"longitude", "latitude"},
        optional_coords={"code", "elevation", "name", "country"},
        dim_variants=[{"point_index"}],
    ),
}
|
|
34
|
+
# Requirements per temporal layout. Forecasts are indexed by
# (reference_time, lead_time) and observations by valid_time.
TIME_SPECS = {
    Time.FORECAST: PropertySpec(
        required_coords={"reference_time", "lead_time"},
        dim_variants=[{"reference_time", "lead_time"}],
        optional_coords={"valid_time"},
    ),
    Time.OBSERVATION: PropertySpec(
        required_coords={"valid_time"},
        dim_variants=[{"valid_time"}],
    ),
}
|
|
45
|
+
|
|
46
|
+
# Requirements per uncertainty representation. Deterministic data needs
# nothing extra, ensembles need ``member`` and quantile data needs ``quantile``
# (each as both a dimension and a coordinate).
UNCERTAINTY_SPECS = {
    Uncertainty.DETERMINISTIC: PropertySpec(),
    Uncertainty.ENSEMBLE: PropertySpec(
        required_coords={"member"},
        dim_variants=[{"member"}],
    ),
    Uncertainty.QUANTILE: PropertySpec(
        required_coords={"quantile"},
        dim_variants=[{"quantile"}],
    ),
}
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from .properties import Properties, Space, Time, Uncertainty
|
|
2
|
+
from .validation import validate_time_dataset, validate_space_dataset
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def properties_to_attrs(prop: Properties) -> dict:
    """Serialise ``prop`` into a plain dict of enum string values for ``Dataset.attrs``."""
    space_value = prop.space.value
    time_value = prop.time.value
    uncertainty_value = prop.uncertainty.value
    return dict(space=space_value, time=time_value, uncertainty=uncertainty_value)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def properties_from_attrs(ds) -> Properties:
    """Rebuild :class:`Properties` from ``ds.attrs["properties"]``.

    A missing ``uncertainty`` entry falls back to deterministic.

    Raises:
        KeyError: if ``space`` or ``time`` is not recorded.
        ValueError: if a stored value is not a member of its enum.
    """
    stored = ds.attrs.get("properties", {})
    space_prop = Space(stored["space"])
    time_prop = Time(stored["time"])
    uncertainty_prop = Uncertainty(
        stored.get("uncertainty", Uncertainty.DETERMINISTIC)
    )
    return Properties(space=space_prop, time=time_prop, uncertainty=uncertainty_prop)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def update_space_property(ds, prop: Space):
    """Replace the recorded ``space`` property of ``ds`` and re-validate.

    ``time`` and ``uncertainty`` are carried over from
    ``ds.attrs["properties"]``. ``ds.attrs`` is updated in place and ``ds``
    is returned.

    Raises:
        ValueError: if ``ds`` does not satisfy the new spec.
    """
    current = properties_from_attrs(ds)
    updated = Properties(
        uncertainty=current.uncertainty,
        time=current.time,
        space=prop,
    )
    validate_space_dataset(ds, updated)
    ds.attrs["properties"] = properties_to_attrs(updated)
    return ds
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def update_time_property(ds, prop: Time):
    """Replace the recorded ``time`` property of ``ds`` and re-validate.

    ``space`` and ``uncertainty`` are carried over from
    ``ds.attrs["properties"]``. ``ds.attrs`` is updated in place and ``ds``
    is returned.

    Raises:
        ValueError: if ``ds`` does not satisfy the new temporal spec.
    """
    current = properties_from_attrs(ds)
    updated = Properties(
        uncertainty=current.uncertainty,
        time=prop,
        space=current.space,
    )
    validate_time_dataset(ds, updated)
    ds.attrs["properties"] = properties_to_attrs(updated)
    return ds
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from .specs import SPACE_SPECS, TIME_SPECS, UNCERTAINTY_SPECS
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _validate_dims(ds, variants):
|
|
5
|
+
if not variants:
|
|
6
|
+
return
|
|
7
|
+
|
|
8
|
+
ds_dims = set(ds.dims)
|
|
9
|
+
|
|
10
|
+
for variant in variants:
|
|
11
|
+
if variant.issubset(ds_dims):
|
|
12
|
+
return
|
|
13
|
+
|
|
14
|
+
raise ValueError(f"Dataset dims {ds_dims} do not match allowed variants {variants}")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _validate_coords(ds, required_coords, axis):
|
|
18
|
+
missing = required_coords - set(ds.coords)
|
|
19
|
+
if missing:
|
|
20
|
+
raise ValueError(f"{axis}: missing required coordinates {missing}")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# TIME
|
|
24
|
+
def validate_time_dataset(ds, properties):
    """Validate ``ds`` against the temporal spec selected by ``properties.time``.

    Raises:
        ValueError: from the dimension or coordinate check.
    """
    spec = TIME_SPECS[properties.time.value]
    variants = spec.dim_variants
    _validate_dims(ds, variants)
    required = spec.required_coords
    _validate_coords(ds, required, "time")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# SPACE
|
|
31
|
+
def validate_space_dataset(ds, properties):
    """Validate ``ds`` against the spatial spec selected by ``properties.space``.

    Raises:
        ValueError: from the dimension or coordinate checks.
    """
    spec = SPACE_SPECS[properties.space.value]
    variants = spec.dim_variants
    _validate_dims(ds, variants)
    required = spec.required_coords
    _validate_coords(ds, required, "space")
    # NOTE(review): this also validates the temporal spec, so
    # ``validate_dataset`` checks time twice. Kept for behavioural parity;
    # confirm whether callers such as ``update_space_property`` rely on it.
    validate_time_dataset(ds, properties)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# UNCERTAINTY
|
|
39
|
+
def validate_uncertainty_dataset(ds, properties):
    """Validate ``ds`` against the spec selected by ``properties.uncertainty``.

    Raises:
        ValueError: from the dimension or coordinate check.
    """
    spec = UNCERTAINTY_SPECS[properties.uncertainty.value]
    variants = spec.dim_variants
    _validate_dims(ds, variants)
    required = spec.required_coords
    _validate_coords(ds, required, "uncertainty")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def validate_dataset(ds, properties):
    """Run the time, space and uncertainty validators on ``ds``, in that order.

    Raises:
        ValueError: from the first check that fails.
    """
    validators = (
        validate_time_dataset,
        validate_space_dataset,
        validate_uncertainty_dataset,
    )
    for validator in validators:
        validator(ds, properties)
|
mxalign/runner.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import xarray as xr
|
|
3
|
+
|
|
4
|
+
from .utils.config import Config
|
|
5
|
+
from .loaders.loader import load
|
|
6
|
+
from .transformations.transform import transform
|
|
7
|
+
from .align.time import align_time
|
|
8
|
+
from .align.space import align_space
|
|
9
|
+
from .align.nans import broadcast_nans
|
|
10
|
+
from .utils.save import save_dataset, save_metrics
|
|
11
|
+
from .verification import Metric
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Runner:
    """Run the mxalign pipeline described by a configuration.

    The configuration (a path or a dict, wrapped in ``Config``) is read
    section by section: ``datasets``, ``transformations``, ``alignment`` and
    ``verification``. Working datasets are kept in ``self.datasets``
    (name -> dataset), and each stage replaces them in place.
    """

    def __init__(self, config: str | dict):
        self.config = Config(config)
        self.datasets = {}

    def run(self):
        """Execute load -> transform -> align -> verify."""
        # 1. Load the datasets
        self.load_datasets()

        # 2. Transform the datasets
        self.transform_datasets()

        # 3. Align the datasets (time, space, NaN broadcasting, saving)
        self.align()

        # 4. Verify the datasets against the reference
        self.verify()

    def load_datasets(self):
        """Load every dataset listed in the ``datasets`` section.

        Each entry needs ``loader`` and ``files`` (a path or a list of paths).
        ``variables`` and ``grid_mapping`` are optional, and every remaining
        key is forwarded to the loader. Local files that do not exist are
        skipped with a message. URLs (containing ``://``) are passed through
        unchecked, because ``os.path.exists`` cannot see remote stores.

        Raises:
            ValueError: if the section is missing or a dataset has no usable
                files.
        """
        config = self.config["datasets"]
        if config is None:
            raise ValueError("No datasets section in the config.")
        for name, config_ds in config.items():
            config_ds = config_ds.copy()
            loader = config_ds.pop("loader")
            variables = config_ds.pop("variables", None)
            grid_mapping = config_ds.pop("grid_mapping", None)

            candidates = config_ds.pop("files")
            if isinstance(candidates, str):
                # A single path must not be iterated character by character.
                candidates = [candidates]

            # Check if all the files exist
            files = []
            for file in candidates:
                if "://" in str(file) or os.path.exists(file):
                    files.append(file)
                else:
                    print(f"File: {file} is missing, skipping.")
            if not files:
                raise ValueError(f"No existing files for dataset '{name}'.")

            self.datasets[name] = load(
                name=loader,
                files=files,
                variables=variables,
                grid_mapping=grid_mapping,
                **config_ds,
            )

    def transform_datasets(self):
        """Apply every transformation in the ``transformations`` section.

        Each entry maps a transformation name to its options (or to nothing).
        The optional ``datasets`` key (a name or a list of names) restricts
        which datasets it applies to; otherwise it applies to all of them.
        A missing section is a no-op.
        """
        config = self.config["transformations"]
        if config is None:
            return
        for transformation, config_trans in config.items():
            config_trans = dict(config_trans or {})
            # if no datasets specified, apply to all datasets
            names_ds = config_trans.pop("datasets", list(self.datasets))
            if isinstance(names_ds, str):
                names_ds = [names_ds]
            for name in names_ds:
                self.datasets[name] = transform(
                    name=transformation, datasets=self.datasets[name], **config_trans
                )

    def align(self):
        """Align the datasets in time and space, broadcast NaNs and optionally save.

        Reads the ``alignment`` section. ``reference`` names the dataset the
        others are aligned to, ``broadcast_nans`` defaults to ``True``, and
        ``time``/``space``/``save`` are optional sub-sections. A missing
        section skips alignment entirely.
        """
        config = self.config["alignment"]
        if config is None:
            print("Skipping alignment")
            return
        # Work on a copy in case ``Config`` hands back its stored dict;
        # popping from it would break a second call.
        config = config.copy()
        reference = config.pop("reference")
        brdcst_nans = config.pop("broadcast_nans", True)
        config_align_time = config.get("time", None)
        config_align_space = config.get("space", None)
        config_align_save = config.get("save", None)

        # align in time
        if config_align_time:
            self.align_time(config_align_time)
        else:
            print("Skipping temporal alignment")

        # align in space
        if config_align_space:
            self.align_space(reference=reference, config=config_align_space)
        else:
            print("Skipping spatial alignment")

        # broadcast NaNs
        if brdcst_nans:
            self.datasets = broadcast_nans(self.datasets)

        # Save aligned datasets
        if config_align_save:
            self._save_aligned(config_align_save)

    def _save_aligned(self, config_save):
        """Save the aligned datasets according to the ``alignment.save`` section.

        ``method`` selects the writer. ``datasets`` is ``"all"`` (one output
        per dataset, the default) or ``"merge"`` (all datasets concatenated
        along a new ``model`` dimension).

        Raises:
            ValueError: for any other ``datasets`` option.
        """
        config = config_save.copy()
        method = config.pop("method")
        datasets = config.pop("datasets", "all")
        if datasets == "all":
            for name, ds in self.datasets.items():
                save_dataset(method, name, ds, **config)
        elif datasets == "merge":
            # The merged output has no single source dataset name; take it
            # from the save config, defaulting to "merged".
            merged_name = config.pop("name", "merged")
            ds = xr.concat(
                list(self.datasets.values()),
                dim=xr.Variable("model", list(self.datasets)),
            )
            save_dataset(method, merged_name, ds, **config)
        else:
            raise ValueError("Unknown option for dataset saving.")

    def verify(self):
        """Compute the configured metrics of every dataset against the reference.

        Reads the ``verification`` section:

        - ``reference`` names the dataset used as truth.
        - ``metrics`` maps metric names to ``function``/``inputs`` plus extra
          ``Metric`` options.
        - An optional ``save`` section (with ``method``) is forwarded to
          ``save_metrics``.

        Only data variables shared by all datasets are compared. The result
        is stored in ``self.metrics`` with dims ``(model, metric, ...)``.
        """
        config = self.config["verification"]
        if config is None:
            print("Skipping verification")
            return
        ref_name = config["reference"]
        reference = self.datasets[ref_name]
        config_metrics = config.get("metrics", None)
        config_save_metrics = config.get("save", None)

        if not config_metrics:
            # Without metrics there is nothing to save either.
            print("Skipping verification: no metrics configured")
            return

        common_vars = set(reference.data_vars)
        for ds in self.datasets.values():
            common_vars.intersection_update(ds.data_vars)
        common_vars = list(common_vars)

        metrics = {}
        for metric_name, config_metric in config_metrics.items():
            config_metric = config_metric.copy()
            func_path = config_metric.pop("function")
            inputs = config_metric.pop("inputs")

            metric = Metric(
                name=metric_name,
                func_path=func_path,
                ds_ref=reference[common_vars],
                inputs=inputs,
                **config_metric,
            )
            models = {
                ds_name: metric.compute(ds[common_vars])
                for ds_name, ds in self.datasets.items()
                if ds_name != ref_name
            }
            metrics[metric.name] = xr.concat(
                list(models.values()), dim=xr.Variable("model", list(models))
            )

        metrics = xr.concat(
            list(metrics.values()), dim=xr.Variable("metric", list(metrics))
        )
        self.metrics = metrics.transpose("model", "metric", ...).compute()

        if config_save_metrics:
            save_config = config_save_metrics.copy()
            method = save_config.pop("method")
            save_metrics(method, self.metrics, **save_config)

    def align_time(self, config):
        """Align all datasets in time with the options in ``config``."""
        self.datasets = align_time(self.datasets, **config)

    def align_space(self, reference, config):
        """Align every non-reference dataset in space onto ``reference``.

        The options are looked up in ``config`` under the key returned by
        ``get_spatial_alignment`` (e.g. ``"interpolation"`` or ``"regrid"``).
        """
        ds_ref = self.datasets[reference]
        for name, ds in self.datasets.items():
            if name != reference:
                options = config.get(get_spatial_alignment(ds, ds_ref), {})
                self.datasets[name] = align_space(ds, ds_ref, **options)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def get_spatial_alignment(ds, reference):
    """Choose the config key for aligning ``ds`` onto ``reference`` in space.

    Returns ``"interpolation"`` for grid -> point and ``"regrid"`` for
    grid -> grid. Every other combination (e.g. point sources) returns
    ``"null"``.
    """
    if not ds.space.is_grid():
        return "null"
    if reference.space.is_point():
        return "interpolation"
    return "regrid" if reference.space.is_grid() else "null"
|