mxalign 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. mxalign/__init__.py +36 -0
  2. mxalign/accessors/__init__.py +7 -0
  3. mxalign/accessors/space.py +205 -0
  4. mxalign/accessors/time.py +180 -0
  5. mxalign/align/__init__.py +7 -0
  6. mxalign/align/nans.py +72 -0
  7. mxalign/align/space.py +21 -0
  8. mxalign/align/time.py +62 -0
  9. mxalign/cli.py +157 -0
  10. mxalign/interpolations/__init__.py +9 -0
  11. mxalign/interpolations/base.py +29 -0
  12. mxalign/interpolations/delaunay.py +218 -0
  13. mxalign/interpolations/interpolate.py +29 -0
  14. mxalign/interpolations/registry.py +17 -0
  15. mxalign/interpolations/xarray.py +63 -0
  16. mxalign/loaders/__init__.py +11 -0
  17. mxalign/loaders/anemoi_datasets.py +92 -0
  18. mxalign/loaders/anemoi_inference.py +103 -0
  19. mxalign/loaders/base.py +103 -0
  20. mxalign/loaders/harp_obstable.py +81 -0
  21. mxalign/loaders/loader.py +8 -0
  22. mxalign/loaders/registry.py +17 -0
  23. mxalign/properties/__init__.py +0 -0
  24. mxalign/properties/properties.py +25 -0
  25. mxalign/properties/specs.py +54 -0
  26. mxalign/properties/utils.py +43 -0
  27. mxalign/properties/validation.py +48 -0
  28. mxalign/runner.py +167 -0
  29. mxalign/transformations/__init__.py +7 -0
  30. mxalign/transformations/base.py +38 -0
  31. mxalign/transformations/external.py +34 -0
  32. mxalign/transformations/registry.py +20 -0
  33. mxalign/transformations/transform.py +28 -0
  34. mxalign/utils/config.py +55 -0
  35. mxalign/utils/dates.py +76 -0
  36. mxalign/utils/projections.py +104 -0
  37. mxalign/utils/save.py +62 -0
  38. mxalign/verification.py +57 -0
  39. mxalign-0.1.0.dist-info/METADATA +136 -0
  40. mxalign-0.1.0.dist-info/RECORD +43 -0
  41. mxalign-0.1.0.dist-info/WHEEL +4 -0
  42. mxalign-0.1.0.dist-info/entry_points.txt +2 -0
  43. mxalign-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,103 @@
1
+ from pathlib import Path
2
+ import xarray as xr
3
+
4
+ from .registry import register_loader
5
+ from ..properties.properties import Space, Time, Uncertainty
6
+ from .base import BaseLoader
7
+
8
# Keyword defaults used when opening NetCDF files; a key supplied by the
# caller always takes precedence over the value listed here.
DEFAULTS_NETCDF = dict(chunks="auto", engine="h5netcdf", parallel=True)

# Keyword defaults used when opening a Zarr store.
DEFAULTS_ZARR = dict(
    chunks="auto",
    storage_options=dict(anon=True),
)
14
+
15
+
16
@register_loader
class AnemoiInferenceLoader(BaseLoader):
    """Loader registered as ``"anemoi-inference"``.

    Datasets produced by this loader are declared as gridded, deterministic
    forecasts.
    """

    name = "anemoi-inference"

    space = Space.GRID
    time = Time.FORECAST
    uncertainty = Uncertainty.DETERMINISTIC

    def _load(self):
        """Open ``self.files`` and return the resulting dataset.

        ``self.files`` may be a single path (``str`` or ``pathlib.Path``) or a
        sequence of paths. The format is decided from the suffix of the first
        path: ``.zarr`` selects the Zarr defaults, anything else the NetCDF
        defaults. Keys present in ``self.kwargs`` override the defaults.

        Raises:
            ValueError: if an empty sequence of files was given.
        """
        single = isinstance(self.files, (str, Path))
        files = [self.files] if single else list(self.files)
        if not files:
            raise ValueError("No input files given to the anemoi-inference loader")

        is_zarr = Path(files[0]).suffix.lower() == ".zarr"
        defaults = DEFAULTS_ZARR if is_zarr else DEFAULTS_NETCDF
        # Caller-supplied keyword arguments win over the format defaults.
        kwargs = {**defaults, **self.kwargs}

        if is_zarr and single:
            # A single store is opened directly rather than via open_mfdataset.
            return _open_zarr(files[0], **kwargs)

        if is_zarr:
            # Several stores go through the multi-file path with the zarr engine.
            kwargs["engine"] = "zarr"
        return _open_mf_dataset(files, **kwargs)
58
+
59
+
60
def _open_mf_dataset(files, **kwargs):
    """Open several files as one dataset indexed by ``lead_time``.

    The ``time`` values of the first file are read to derive the lead times
    (offset of each step from the first step); the same lead times are then
    attached to the combined dataset.
    NOTE(review): this presumes every file shares the first file's time
    layout relative to its own start -- confirm against the producers.

    Args:
        files: Sequence of paths; ``files[0]`` is probed for the time axis.
        **kwargs: Forwarded to ``xr.open_mfdataset``. ``engine`` and
            ``chunks``, when present, are also used for the probe.

    Returns:
        The combined dataset with ``values`` renamed to ``grid_index`` and
        ``time`` swapped for ``lead_time``.
    """
    # Open the probe inside a context manager so its file handle is released;
    # ``.values`` materialises the time axis before the dataset is closed.
    with xr.open_dataset(
        files[0], engine=kwargs.get("engine"), chunks=kwargs.get("chunks")
    ) as first:
        times = first["time"].values
    lead_times = times - times[0]

    ds = xr.open_mfdataset(files, preprocess=_preprocess, **kwargs)

    ds_out = (
        ds.assign_coords({"lead_time": ("time", lead_times)})
        .rename_dims({"values": "grid_index"})
        .swap_dims({"time": "lead_time"})
    )

    return ds_out
76
+
77
+
78
def _open_zarr(files, **kwargs):
    """Open one Zarr store and return it indexed by ``lead_time``.

    Args:
        files: Store location handed to ``xr.open_zarr``.
        **kwargs: Forwarded to ``xr.open_zarr``.

    Returns:
        The preprocessed dataset with ``values`` renamed to ``grid_index``
        and ``time`` swapped for ``lead_time``.
    """
    store = xr.open_zarr(files, **kwargs)

    # Lead time is the offset of every step from the first step.
    time_axis = store["time"].values
    offsets = time_axis - time_axis[0]

    prepared = _preprocess(store)
    with_lead = prepared.assign_coords({"lead_time": ("time", offsets)})
    regridded = with_lead.rename_dims({"values": "grid_index"})
    return regridded.swap_dims({"time": "lead_time"})
93
+
94
+
95
def _preprocess(ds):
    """Tag a dataset with its reference time and drop the ``time`` variable.

    ``longitude`` and ``latitude`` are promoted to coordinates, a length-one
    ``reference_time`` dimension holding the first ``time`` value is added,
    and the ``time`` variable itself is removed.
    """
    start = ds["time"].values[0]

    located = ds.set_coords(["longitude", "latitude"])
    expanded = located.expand_dims("reference_time")
    stamped = expanded.assign_coords(
        {"reference_time": ("reference_time", [start])}
    )
    return stamped.drop_vars("time")
@@ -0,0 +1,103 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from .registry import register_loader
4
+ from ..properties.properties import Properties, Space, Time, Uncertainty
5
+ from ..properties.validation import validate_dataset
6
+ from ..properties.utils import properties_to_attrs
7
+
8
+
9
class BaseLoader(ABC):
    """Abstract loader; subclasses implement ``_load`` and declare properties."""

    name: str = "base"

    # Dataset properties declared by the loader; subclasses override these.
    space: Space | None = None
    time: Time | None = None
    uncertainty: Uncertainty | None = None

    def __init__(self, files, variables=None, grid_mapping=None, **kwargs):
        """Store the loader configuration.

        A single variable name given as a string is wrapped in a list; any
        extra keyword arguments are kept in ``self.kwargs`` for ``_load``.
        """
        if isinstance(variables, str):
            variables = [variables]

        self.files = files
        self.variables = variables
        self.grid_mapping = grid_mapping
        self.kwargs = kwargs

    def load(self):
        """Load, subset, validate and annotate the dataset, then return it."""
        dataset = self._load()
        if self.variables:
            dataset = self._select_variables(dataset)

        props = self._get_properties(dataset)
        validate_dataset(dataset, props)
        dataset.attrs["properties"] = properties_to_attrs(props)

        if self.grid_mapping:
            dataset = self._add_grid_mapping(dataset)

        # Make sure all the coordinates are loaded
        for coord in dataset.coords:
            dataset[coord] = dataset[coord].compute()

        return dataset

    @abstractmethod
    def _load(self): ...

    def _select_variables(self, ds):
        """Return ``ds`` restricted to the requested variables."""
        return ds[self.variables]

    def _add_grid_mapping(self, ds):
        """Attach the CRS and the grid mapping through the ``space`` accessor."""
        with_crs = ds.space.add_crs(self.grid_mapping)
        return with_crs.space.add_grid_mapping(self.grid_mapping)

    def _get_properties(self, ds):
        """Build the ``Properties`` from the class-level declarations."""
        return Properties(
            space=self.space, time=self.time, uncertainty=self.uncertainty
        )
59
+
60
+
61
@register_loader
class MxAlignLoader(BaseLoader):
    """Loader registered as ``"mxalign"``; properties are inferred from dims."""

    name = "mxalign"

    # Nothing is declared up front: ``_get_properties`` derives all three.
    space = None
    time = None
    uncertainty = None

    def _load(self):
        """Open the file(s) with ``xr.open_mfdataset`` using auto chunking.

        When a ``code`` dimension is present it is renamed to ``point_index``
        and the dataset is transposed to ``(valid_time, point_index)``.
        """
        import xarray as xr

        paths = self.files
        if isinstance(paths, str):
            paths = [paths]

        dataset = xr.open_mfdataset(paths, chunks="auto", **self.kwargs)
        if "code" not in dataset.dims:
            return dataset

        renamed = dataset.rename_dims({"code": "point_index"})
        return renamed.transpose("valid_time", "point_index")

    def _get_properties(self, ds):
        """Infer the time, space and uncertainty properties from ``ds.dims``.

        Raises:
            ValueError: if the temporal or the spatial dimensions are not
                recognised.
        """
        dims = ds.dims

        # Temporal axis: forecast needs both dims, observation only one.
        if "reference_time" in dims and "lead_time" in dims:
            time = Time.FORECAST
        elif "valid_time" in dims:
            time = Time.OBSERVATION
        else:
            raise ValueError("Unknown temporal dimensions")

        # Spatial axis: any of the grid markers wins over the point marker.
        if any(marker in dims for marker in ("grid_index", "xc", "latitude")):
            space = Space.GRID
        elif "point_index" in dims:
            space = Space.POINT
        else:
            raise ValueError("Unknown spatial dimensions")

        # Uncertainty axis: deterministic unless an ensemble/quantile dim exists.
        uncertainty = Uncertainty.DETERMINISTIC
        if "member" in dims:
            uncertainty = Uncertainty.ENSEMBLE
        elif "quantile" in dims:
            uncertainty = Uncertainty.QUANTILE

        return Properties(space=space, time=time, uncertainty=uncertainty)
@@ -0,0 +1,81 @@
1
+ import sqlite3
2
+ import pandas as pd
3
+
4
+ from .registry import register_loader
5
+ from ..properties.properties import Space, Time, Uncertainty
6
+ from .base import BaseLoader
7
+
8
# Coordinate name used by this package -> column name in the SYNOP table.
COORDS = dict(
    longitude="lon",
    latitude="lat",
    valid_time="validdate",
    code="SID",
    altitude="elev",
)
15
+
16
+
17
@register_loader
class ObstableLoader(BaseLoader):
    """Loader registered as ``"harp-obstable"``.

    Reads the ``SYNOP`` table of a single SQLite file and returns it as a
    dataset declared as deterministic point observations.
    """

    name = "harp-obstable"

    space = Space.POINT
    time = Time.OBSERVATION
    uncertainty = Uncertainty.DETERMINISTIC

    def _load(self):
        """Read the ``SYNOP`` table and return a ``(valid_time, point_index)`` dataset.

        ``self.files`` may be a single path or a one-element list/tuple.

        Raises:
            ValueError: if an empty list of files was given.
            NotImplementedError: if more than one file was given.
        """
        import logging
        from contextlib import closing

        logger = logging.getLogger(__name__)

        # Callers may hand over a list even for a single file, so unwrap it
        # before passing the path to sqlite3.
        path = self.files
        if isinstance(path, (list, tuple)):
            if not path:
                raise ValueError("No SQLite file given to the harp-obstable loader")
            if len(path) > 1:
                raise NotImplementedError(
                    "Reading from multiple SQLite-files not implemented"
                )
            path = path[0]

        # closing() guarantees the connection is released even if a query fails.
        with closing(sqlite3.connect(path)) as conn:
            if self.variables is None:
                # Retrieve all variables: every column that is not a coordinate.
                coord_columns = set(COORDS.values())
                variables = [
                    var
                    for var in pd.read_sql_query(
                        "SELECT * FROM SYNOP LIMIT 0", conn
                    ).columns
                    if var not in coord_columns
                ]
                logger.debug("Variables found in SYNOP: %s", variables)
            else:
                variables = self.variables

            # Read the SIDs with one location per station.
            codes = pd.read_sql(
                "SELECT SID as code, MIN(lat) AS latitude, MIN(lon) AS longitude, elev as altitude FROM SYNOP GROUP BY SID",
                conn,
                index_col="code",
            ).to_xarray()
            logger.debug("Station table: %s", codes)

            # Read the data.
            # NOTE(review): column names are interpolated into the SQL text
            # because identifiers cannot be bound as parameters; `variables`
            # must therefore come from a trusted configuration.
            query = f"""
            SELECT SID as code, validdate as valid_time, {", ".join(variables)}
            FROM SYNOP
            """
            logger.debug("Query: %s", query)
            df = pd.read_sql(
                query,
                conn,
                index_col=["code", "valid_time"],
                parse_dates={"valid_time": {"unit": "s"}},
            )
            logger.debug("Observations: %s", df)

        ds = df.to_xarray()

        # Look up each station's location in the order of the data's codes.
        lon_values = codes["longitude"].sel(code=ds["code"]).values
        lat_values = codes["latitude"].sel(code=ds["code"]).values
        alt_values = codes["altitude"].sel(code=ds["code"]).values

        ds = ds.assign_coords(
            longitude=("code", lon_values),
            latitude=("code", lat_values),
            altitude=("code", alt_values),
        )

        return ds.rename_dims({"code": "point_index"}).transpose(
            "valid_time", "point_index"
        )
@@ -0,0 +1,8 @@
1
+ from .registry import get_loader
2
+
3
+
4
def load(name, files, variables=None, grid_mapping=None, **kwargs):
    """Instantiate the loader registered under ``name`` and run it.

    Args:
        name: Registry name of the loader class.
        files: File or files handed to the loader.
        variables: Optional variable name(s) to keep.
        grid_mapping: Optional grid mapping handed to the loader.
        **kwargs: Extra keyword arguments forwarded to the loader.

    Returns:
        Whatever the loader's ``load()`` returns.
    """
    return get_loader(name)(files, variables, grid_mapping, **kwargs).load()
@@ -0,0 +1,17 @@
1
# Registry of loader classes, keyed by each class's ``name`` attribute.
_LOADERS = {}


def register_loader(cls):
    """Class decorator: record ``cls`` under ``cls.name`` and return it unchanged.

    Registering a second class under an existing name replaces the first.
    """
    _LOADERS[cls.name] = cls
    return cls


def available_loaders():
    """Return the registered loader names as a list, in registration order."""
    return list(_LOADERS)


def get_loader(name):
    """Return the loader class registered under ``name``.

    Raises:
        ValueError: if no loader is registered under ``name``.
    """
    try:
        return _LOADERS[name]
    except KeyError:
        # ``from None`` keeps the internal KeyError out of the traceback; the
        # ValueError already carries everything the caller needs.
        raise ValueError(f"Unknown loader: {name}") from None
File without changes
@@ -0,0 +1,25 @@
1
+ from enum import Enum
2
+ from dataclasses import dataclass
3
+
4
+
5
class Space(str, Enum):
    """Spatial layout of a dataset; members compare equal to their string value."""

    GRID = "grid"
    POINT = "point"
8
+
9
+
10
class Time(str, Enum):
    """Temporal layout of a dataset; members compare equal to their string value."""

    FORECAST = "forecast"
    OBSERVATION = "observation"
13
+
14
+
15
class Uncertainty(str, Enum):
    """Uncertainty representation; members compare equal to their string value."""

    DETERMINISTIC = "deterministic"
    ENSEMBLE = "ensemble"
    QUANTILE = "quantile"
19
+
20
+
21
@dataclass(frozen=True)
class Properties:
    """Immutable triple of the space, time and uncertainty properties."""

    space: Space
    time: Time
    # Optional; a dataset is deterministic unless stated otherwise.
    uncertainty: Uncertainty = Uncertainty.DETERMINISTIC
@@ -0,0 +1,54 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable
3
+ from .properties import Space, Time, Uncertainty
4
+
5
+
6
@dataclass
class PropertySpec:
    """Structural requirements attached to one property value.

    Every field defaults to a fresh empty container, so instances never
    share mutable state.
    """

    # Alternative sets of dimension names.
    dim_variants: list[set[str]] = field(default_factory=list)
    # Coordinates that have to be present.
    required_coords: set[str] = field(default_factory=set)
    # Dimensions that may be present.
    optional_dims: set[str] = field(default_factory=set)
    # Coordinates that may be present.
    optional_coords: set[str] = field(default_factory=set)
    # Extra callables associated with the spec.
    validators: list[Callable] = field(default_factory=list)
13
+
14
+
15
# --- Spatial specs ---------------------------------------------------------
_GRID_SPEC = PropertySpec(
    dim_variants=[
        {"xc", "yc"},
        {"grid_index"},
        {"longitude", "latitude"},
    ],
    required_coords={"longitude", "latitude"},
    optional_coords={"xc", "yc"},
    optional_dims={"member"},
)
_POINT_SPEC = PropertySpec(
    dim_variants=[
        {"point_index"},
    ],
    required_coords={"longitude", "latitude"},
    optional_coords={"code", "elevation", "name", "country"},
)
SPACE_SPECS = {
    Space.GRID: _GRID_SPEC,
    Space.POINT: _POINT_SPEC,
}

# --- Temporal specs --------------------------------------------------------
_FORECAST_SPEC = PropertySpec(
    dim_variants=[{"reference_time", "lead_time"}],
    required_coords={"reference_time", "lead_time"},
    optional_coords={"valid_time"},
)
_OBSERVATION_SPEC = PropertySpec(
    dim_variants=[{"valid_time"}],
    required_coords={"valid_time"},
)
TIME_SPECS = {
    Time.FORECAST: _FORECAST_SPEC,
    Time.OBSERVATION: _OBSERVATION_SPEC,
}

# --- Uncertainty specs -----------------------------------------------------
# Deterministic data carries no extra dimension, hence the empty spec.
UNCERTAINTY_SPECS = {
    Uncertainty.DETERMINISTIC: PropertySpec(),
    Uncertainty.ENSEMBLE: PropertySpec(
        dim_variants=[{"member"}], required_coords={"member"}
    ),
    Uncertainty.QUANTILE: PropertySpec(
        dim_variants=[{"quantile"}], required_coords={"quantile"}
    ),
}
@@ -0,0 +1,43 @@
1
+ from .properties import Properties, Space, Time, Uncertainty
2
+ from .validation import validate_time_dataset, validate_space_dataset
3
+
4
+
5
def properties_to_attrs(prop: Properties) -> dict:
    """Serialise ``prop`` to a dict holding the string value of each property."""
    return {
        axis: getattr(prop, axis).value
        for axis in ("space", "time", "uncertainty")
    }
11
+
12
+
13
def properties_from_attrs(ds) -> Properties:
    """Rebuild the ``Properties`` stored under ``ds.attrs["properties"]``.

    ``space`` and ``time`` must be present in the stored mapping (a missing
    key raises ``KeyError``); ``uncertainty`` falls back to deterministic.
    """
    stored = ds.attrs.get("properties", {})

    space = Space(stored["space"])
    time = Time(stored["time"])
    uncertainty = Uncertainty(
        stored.get("uncertainty", Uncertainty.DETERMINISTIC)
    )

    return Properties(space=space, time=time, uncertainty=uncertainty)
20
+
21
+
22
def update_space_property(ds, prop: Space):
    """Replace the stored space property of ``ds`` with ``prop``.

    The dataset is validated against the new properties before its
    ``attrs["properties"]`` entry is overwritten in place; ``ds`` is returned.
    """
    current = properties_from_attrs(ds)
    updated = Properties(
        space=prop, time=current.time, uncertainty=current.uncertainty
    )

    validate_space_dataset(ds, updated)

    ds.attrs["properties"] = properties_to_attrs(updated)
    return ds
32
+
33
+
34
def update_time_property(ds, prop: Time):
    """Replace the stored time property of ``ds`` with ``prop``.

    The dataset is validated against the new properties before its
    ``attrs["properties"]`` entry is overwritten in place; ``ds`` is returned.
    """
    current = properties_from_attrs(ds)
    updated = Properties(
        space=current.space, time=prop, uncertainty=current.uncertainty
    )

    validate_time_dataset(ds, updated)

    ds.attrs["properties"] = properties_to_attrs(updated)
    return ds
@@ -0,0 +1,48 @@
1
+ from .specs import SPACE_SPECS, TIME_SPECS, UNCERTAINTY_SPECS
2
+
3
+
4
+ def _validate_dims(ds, variants):
5
+ if not variants:
6
+ return
7
+
8
+ ds_dims = set(ds.dims)
9
+
10
+ for variant in variants:
11
+ if variant.issubset(ds_dims):
12
+ return
13
+
14
+ raise ValueError(f"Dataset dims {ds_dims} do not match allowed variants {variants}")
15
+
16
+
17
+ def _validate_coords(ds, required_coords, axis):
18
+ missing = required_coords - set(ds.coords)
19
+ if missing:
20
+ raise ValueError(f"{axis}: missing required coordinates {missing}")
21
+
22
+
23
+ # TIME
24
def validate_time_dataset(ds, properties):
    """Validate dims and required coords of ``ds`` for its time property."""
    spec = TIME_SPECS[properties.time.value]

    _validate_dims(ds, spec.dim_variants)
    _validate_coords(ds, spec.required_coords, "time")
28
+
29
+
30
+ # SPACE
31
def validate_space_dataset(ds, properties):
    """Validate ``ds`` for its space property, then for its time property."""
    spec = SPACE_SPECS[properties.space.value]

    _validate_dims(ds, spec.dim_variants)
    _validate_coords(ds, spec.required_coords, "space")

    # The temporal layout is re-checked as part of the spatial validation.
    validate_time_dataset(ds, properties)
36
+
37
+
38
+ # UNCERTAINTY
39
def validate_uncertainty_dataset(ds, properties):
    """Validate dims and required coords of ``ds`` for its uncertainty property."""
    spec = UNCERTAINTY_SPECS[properties.uncertainty.value]

    _validate_dims(ds, spec.dim_variants)
    _validate_coords(ds, spec.required_coords, "uncertainty")
43
+
44
+
45
def validate_dataset(ds, properties):
    """Run the time, space and uncertainty validations, in that order."""
    checks = (
        validate_time_dataset,
        validate_space_dataset,
        validate_uncertainty_dataset,
    )
    for check in checks:
        check(ds, properties)
mxalign/runner.py ADDED
@@ -0,0 +1,167 @@
1
+ import os
2
+ import xarray as xr
3
+
4
+ from .utils.config import Config
5
+ from .loaders.loader import load
6
+ from .transformations.transform import transform
7
+ from .align.time import align_time
8
+ from .align.space import align_space
9
+ from .align.nans import broadcast_nans
10
+ from .utils.save import save_dataset, save_metrics
11
+ from .verification import Metric
12
+
13
+
14
+ class Runner:
15
+ def __init__(self, config: str | dict):
16
+ self.config = Config(config)
17
+ self.datasets = {}
18
+
19
+ def run(self):
20
+ # 1. Load the datasets
21
+ self.load_datasets()
22
+
23
+ # 2. Transform the datasets
24
+ self.transform_datasets()
25
+ self.align()
26
+ self.verify()
27
+
28
+ def load_datasets(self):
29
+ config = self.config["datasets"]
30
+ if config is None:
31
+ return ValueError("No datasets section in the config.")
32
+ for name, config_ds in config.items():
33
+ config_ds = config_ds.copy()
34
+ # Check if all the files exist
35
+ loader = config_ds.pop("loader")
36
+ variables = config_ds.pop("variables", None)
37
+ grid_mapping = config_ds.pop("grid_mapping", None)
38
+ files = []
39
+ # Check if all the files exist
40
+ for file in config_ds.pop("files"):
41
+ if os.path.exists(file):
42
+ files.append(file)
43
+ else:
44
+ print(f"File: {file} is missing, skipping.")
45
+ self.datasets[name] = load(
46
+ name=loader,
47
+ files=files,
48
+ variables=variables,
49
+ grid_mapping=grid_mapping,
50
+ **config_ds,
51
+ )
52
+
53
+ def transform_datasets(self):
54
+ config = self.config["transformations"]
55
+ if config is None:
56
+ pass
57
+ for transformation, config_trans in config.items():
58
+ config_trans = config_trans.copy()
59
+ # if no datasets specified, apply to all datasets
60
+ names_ds = config_trans.pop("datasets", self.datasets.keys())
61
+ for name in names_ds:
62
+ ds = self.datasets[name]
63
+ self.datasets[name] = transform(
64
+ name=transformation, datasets=ds, **config_trans
65
+ )
66
+
67
+ def align(self):
68
+ config = self.config["alignment"]
69
+ reference = config.pop("reference")
70
+ brdcst_nans = config.pop("broadcast_nans", True)
71
+ config_align_time = config.get("time", None)
72
+ config_align_space = config.get("space", None)
73
+ config_align_save = config.get("save", None)
74
+
75
+ # align in time
76
+ if config_align_time:
77
+ self.align_time(config_align_time)
78
+ else:
79
+ print("Skipping temporal alignment")
80
+
81
+ # align in space
82
+ if config_align_space:
83
+ self.align_space(reference=reference, config=config_align_space)
84
+ else:
85
+ print("Skipping spatial alignment")
86
+
87
+ # broadcast NaNs
88
+ if brdcst_nans:
89
+ self.datasets = broadcast_nans(self.datasets)
90
+
91
+ # Save aligned datasets
92
+ if config_align_save:
93
+ config = config_align_save.copy()
94
+ method = config.pop("method")
95
+ datasets = config.pop("datasets", "all")
96
+ if datasets == "all":
97
+ for name, ds in self.datasets.items():
98
+ save_dataset(method, name, ds, **config)
99
+ elif datasets == "merge":
100
+ ds = xr.concat(
101
+ self.datasets.values(),
102
+ dim=xr.Variable("model", list(self.datasets.keys())),
103
+ )
104
+ save_dataset(method, name, ds, **config)
105
+ else:
106
+ raise ValueError("Unknown option for dataset saving.")
107
+
108
+ def verify(self):
109
+ config = self.config["verification"]
110
+ reference = self.datasets[config["reference"]]
111
+ config_metrics = config.get("metrics", None)
112
+ config_save_metrics = config.get("save", None)
113
+
114
+ common_vars = set(reference.data_vars)
115
+ for ds in self.datasets.values():
116
+ common_vars.intersection_update(set(ds.data_vars))
117
+ common_vars = list(common_vars)
118
+
119
+ if config_metrics:
120
+ metrics = {}
121
+ for metric_name, config_metric in config["metrics"].items():
122
+ config_metric = config_metric.copy()
123
+ func_path = config_metric.pop("function")
124
+ inputs = config_metric.pop("inputs")
125
+
126
+ metric = Metric(
127
+ name=metric_name,
128
+ func_path=func_path,
129
+ ds_ref=reference[common_vars],
130
+ inputs=inputs,
131
+ **config_metric,
132
+ )
133
+ models = {}
134
+ for ds_name, ds in self.datasets.items():
135
+ if ds_name != config["reference"]:
136
+ models[ds_name] = metric.compute(ds[common_vars])
137
+ models = xr.concat(
138
+ models.values(), dim=xr.Variable("model", list(models.keys()))
139
+ )
140
+ metrics[metric.name] = models
141
+ metrics = xr.concat(
142
+ metrics.values(), dim=xr.Variable("metric", list(metrics.keys()))
143
+ )
144
+ self.metrics = metrics.transpose("model", "metric", ...).compute()
145
+
146
+ if config_save_metrics:
147
+ config = config_save_metrics.copy()
148
+ method = config.pop("method")
149
+ save_metrics(method, self.metrics, **config)
150
+
151
+ def align_time(self, config):
152
+ self.datasets = align_time(self.datasets, **config)
153
+
154
+ def align_space(self, reference, config):
155
+ ds_ref = self.datasets[reference]
156
+ for name, ds in self.datasets.items():
157
+ if name != reference:
158
+ options = config.get(get_spatial_alignment(ds, ds_ref), {})
159
+ self.datasets[name] = align_space(ds, ds_ref, **options)
160
+
161
+
162
def get_spatial_alignment(ds, reference):
    """Name the spatial alignment needed to bring ``ds`` onto ``reference``.

    Returns:
        ``"interpolation"`` for a gridded ``ds`` and a point reference,
        ``"regrid"`` when both are gridded, ``"null"`` otherwise.
    """
    candidates = (
        ("interpolation", reference.space.is_point),
        ("regrid", reference.space.is_grid),
    )
    for label, reference_matches in candidates:
        if reference_matches() and ds.space.is_grid():
            return label
    return "null"
@@ -0,0 +1,7 @@
1
+ from . import base
2
+ from . import external
3
+
4
+ __all__ = [
5
+ "base",
6
+ "external",
7
+ ]