mxalign 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. mxalign/__init__.py +36 -0
  2. mxalign/accessors/__init__.py +7 -0
  3. mxalign/accessors/space.py +205 -0
  4. mxalign/accessors/time.py +180 -0
  5. mxalign/align/__init__.py +7 -0
  6. mxalign/align/nans.py +72 -0
  7. mxalign/align/space.py +21 -0
  8. mxalign/align/time.py +62 -0
  9. mxalign/cli.py +157 -0
  10. mxalign/interpolations/__init__.py +9 -0
  11. mxalign/interpolations/base.py +29 -0
  12. mxalign/interpolations/delaunay.py +218 -0
  13. mxalign/interpolations/interpolate.py +29 -0
  14. mxalign/interpolations/registry.py +17 -0
  15. mxalign/interpolations/xarray.py +63 -0
  16. mxalign/loaders/__init__.py +11 -0
  17. mxalign/loaders/anemoi_datasets.py +92 -0
  18. mxalign/loaders/anemoi_inference.py +103 -0
  19. mxalign/loaders/base.py +103 -0
  20. mxalign/loaders/harp_obstable.py +81 -0
  21. mxalign/loaders/loader.py +8 -0
  22. mxalign/loaders/registry.py +17 -0
  23. mxalign/properties/__init__.py +0 -0
  24. mxalign/properties/properties.py +25 -0
  25. mxalign/properties/specs.py +54 -0
  26. mxalign/properties/utils.py +43 -0
  27. mxalign/properties/validation.py +48 -0
  28. mxalign/runner.py +167 -0
  29. mxalign/transformations/__init__.py +7 -0
  30. mxalign/transformations/base.py +38 -0
  31. mxalign/transformations/external.py +34 -0
  32. mxalign/transformations/registry.py +20 -0
  33. mxalign/transformations/transform.py +28 -0
  34. mxalign/utils/config.py +55 -0
  35. mxalign/utils/dates.py +76 -0
  36. mxalign/utils/projections.py +104 -0
  37. mxalign/utils/save.py +62 -0
  38. mxalign/verification.py +57 -0
  39. mxalign-0.1.0.dist-info/METADATA +136 -0
  40. mxalign-0.1.0.dist-info/RECORD +43 -0
  41. mxalign-0.1.0.dist-info/WHEEL +4 -0
  42. mxalign-0.1.0.dist-info/entry_points.txt +2 -0
  43. mxalign-0.1.0.dist-info/licenses/LICENSE +21 -0
mxalign/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ from .properties.properties import Properties, Time, Space, Uncertainty
2
+ from .loaders.loader import load
3
+ from .loaders.registry import available_loaders, register_loader
4
+ from .transformations.transform import transform
5
+ from .transformations.registry import available_transformations, register_transformation
6
+ from .interpolations.interpolate import interpolate
7
+ from .interpolations.registry import available_interpolations, register_interpolator
8
+ from .align.time import align_time
9
+ from .align.space import align_space
10
+
11
+ from . import accessors
12
+ from . import loaders
13
+ from . import transformations
14
+ from . import interpolations
15
+
16
+ __all__ = [
17
+ "Properties",
18
+ "Time",
19
+ "Space",
20
+ "Uncertainty",
21
+ "load",
22
+ "available_loaders",
23
+ "register_loader",
24
+ "transform",
25
+ "available_transformations",
26
+ "register_transformation",
27
+ "interpolate",
28
+ "available_interpolations",
29
+ "register_interpolator",
30
+ "align_time",
31
+ "align_space",
32
+ "accessors",
33
+ "loaders",
34
+ "transformations",
35
+ "interpolations",
36
+ ]
@@ -0,0 +1,7 @@
1
+ from . import space
2
+ from . import time
3
+
4
+ __all__ = [
5
+ "space",
6
+ "time",
7
+ ]
@@ -0,0 +1,205 @@
1
+ import xarray as xr
2
+ import cartopy.crs as ccrs
3
+ import numpy as np
4
+
5
+ from ..properties.properties import Space
6
+ from ..properties.utils import properties_from_attrs
7
+
8
+ from ..utils.projections import create_cartopy_crs, BUILTIN
9
+
10
+ # Tolerance in degrees that the coordinates of two grids can differ while still being interpreted as the same grid.
11
+ # 0.0001 degrees ~ 10m at 45 deg latitude
12
+ COORD_TOLERANCE = 0.0001
13
+
14
+
15
@xr.register_dataset_accessor("space")
class SpaceAccessor:
    """Dataset accessor (``ds.space``) bundling the spatial helpers.

    The spatial kind of the dataset is read once, at construction, from the
    dataset attributes through ``properties_from_attrs`` and compared against
    ``Space.GRID`` / ``Space.POINT`` by the ``is_*`` predicates.
    """

    def __init__(self, ds):
        self._space = properties_from_attrs(ds).space
        self._ds = ds

    def is_grid(self):
        """Return True when the dataset's space property is ``Space.GRID``."""
        return self._space == Space.GRID

    def is_point(self):
        """Return True when the dataset's space property is ``Space.POINT``."""
        return self._space == Space.POINT

    def add_crs(self, crs):
        """Return a copy of the dataset with ``attrs["crs"]`` set.

        Parameters
        ----------
        crs : str | dict | object
            * ``str``: looked up (lower-cased) in ``BUILTIN``.
            * ``dict``: must hold ``"projection"`` and ``"kws_projection"``
              (optionally ``"kws_globe"``); it is turned into a CRS with
              ``create_cartopy_crs``.
            * anything else is stored unchanged (presumably an already built
              cartopy CRS -- not checked here).

        Raises
        ------
        ValueError
            If the dataset is a point dataset, or if a string ``crs`` is not
            a key of ``BUILTIN``.
        """
        if self.is_point():
            raise ValueError("Cannot add CRS to a point dataset")
        if isinstance(crs, str):
            try:
                crs = BUILTIN[crs.lower()]
            except KeyError:
                # f-string so that the offending name actually shows up.
                raise ValueError(
                    f"crs: {crs} not found in supported projections"
                ) from None
        if isinstance(crs, dict):
            crs = create_cartopy_crs(
                projection=crs["projection"],
                kws_projection=crs["kws_projection"],
                kws_globe=crs.get("kws_globe", None),
            )
        return self._ds.assign_attrs({"crs": crs})

    def add_grid_mapping(self, grid_mapping: str | dict):
        """Return a copy of the dataset with ``attrs["grid_mapping"]`` set.

        A string is resolved to ``BUILTIN[name.lower()]["kws_grid"]``; any
        other value is stored as given.

        Raises
        ------
        ValueError
            If the dataset is a point dataset, or if the string lookup fails
            (unknown name or entry without ``"kws_grid"``).
        """
        if self.is_point():
            raise ValueError("Cannot add grid mapping to a point dataset")
        if isinstance(grid_mapping, str):
            try:
                grid_mapping = BUILTIN[grid_mapping.lower()]["kws_grid"]
            except KeyError:
                raise ValueError(
                    f"grid mapping: {grid_mapping} not found in supported mappings"
                ) from None
        return self._ds.assign_attrs({"grid_mapping": grid_mapping})

    def add_xy(self, crs=None):
        """Return the dataset with projected ``xc`` / ``yc`` coordinates.

        The ``longitude`` / ``latitude`` values are transformed from
        ``ccrs.PlateCarree()`` into the dataset CRS and attached along
        ``grid_index`` (grid datasets) or ``point_index`` (point datasets).
        If ``xc`` and ``yc`` are already coordinates the dataset is returned
        untouched.

        Parameters
        ----------
        crs : optional
            Forwarded to ``add_crs`` first when given; otherwise the CRS is
            taken from ``attrs["crs"]``.

        Raises
        ------
        ValueError
            If no CRS is available, if ``longitude`` and ``latitude`` are
            dimensions of the dataset, or if the dataset is neither grid nor
            point.
        """
        if crs is not None:
            # NOTE: this rebinds the dataset held by the accessor (existing
            # behaviour, kept for backward compatibility).
            self._ds = self.add_crs(crs)

        crs = self._ds.attrs.get("crs", None)
        if crs is None:
            raise ValueError("No CRS provided and no CRS found in dataset attributes")

        if {"longitude", "latitude"}.issubset(self._ds.dims):
            raise ValueError(
                "Cannot add x/y coordinates to a GRID dataset that has longitude/latitude dimensions"
            )
        if {"xc", "yc"}.issubset(self._ds.coords):
            return self._ds

        xyz = crs.transform_points(
            x=self._ds["longitude"].values,
            y=self._ds["latitude"].values,
            src_crs=ccrs.PlateCarree(),
        )

        if self.is_grid():
            index_dim = "grid_index"
        elif self.is_point():
            index_dim = "point_index"
        else:
            raise ValueError("Dataset does not have expected spatial properties")

        return self._ds.assign_coords(
            xc=(index_dim, xyz[:, 0]), yc=(index_dim, xyz[:, 1])
        )

    def is_stacked(self):
        """Tell whether the grid is stored along a single ``grid_index`` dim.

        Returns False when ``xc``/``yc`` or ``longitude``/``latitude`` are
        dimensions, True when ``grid_index`` is a dimension.

        Raises
        ------
        ValueError
            If none of those dimension layouts is found.
        """
        dims = self._ds.dims
        if {"xc", "yc"}.issubset(dims) or {"longitude", "latitude"}.issubset(dims):
            return False
        if "grid_index" in dims:
            return True
        raise ValueError("Dataset does not have expected dimensions for GRID")

    def stack(self):
        """Return the dataset with its two horizontal dims stacked.

        The pair ``(yc, xc)``, ``(latitude, longitude)`` or ``(lat, lon)`` is
        stacked into ``grid_index`` and the resulting multi-index is reset.
        An already stacked dataset is returned unchanged.

        Raises
        ------
        ValueError
            For point datasets, or when no known pair of dimensions exists.
        """
        if self.is_point():
            raise ValueError("POINT datasets cannot be stacked")
        if self.is_stacked():
            return self._ds

        dims = self._ds.dims
        if {"xc", "yc"}.issubset(dims):
            dims_to_stack = ["yc", "xc"]
        elif {"latitude", "longitude"}.issubset(dims):
            # Same names is_stacked()/add_xy() use for regular lat/lon grids;
            # without this branch such grids could never be stacked.
            dims_to_stack = ["latitude", "longitude"]
        elif {"lat", "lon"}.issubset(dims):
            dims_to_stack = ["lat", "lon"]
        else:
            raise ValueError("Could not find correct dimensions to stack")
        return self._ds.stack({"grid_index": dims_to_stack}).reset_index("grid_index")

    def unstack(self, crs=None, **kwargs):
        """Return the dataset unstacked from ``grid_index`` to ``(yc, xc)``.

        The grid description (``nx``, ``ny``, ``lon_ll``, ``lat_ll``, ``dx``,
        ``dy``) is taken from ``kwargs`` and, for every key that is missing or
        None there, from ``attrs["grid_mapping"]``. The values actually used
        are written back to ``attrs["grid_mapping"]`` of the result. A dataset
        that is not stacked is returned unchanged.

        Parameters
        ----------
        crs : optional
            When given, it is attached with ``add_crs`` before the projected
            axes are built.

        Raises
        ------
        ValueError
            For point datasets, or when ``nx * ny`` does not match the size
            of ``grid_index``.
        KeyError
            When a grid parameter is found neither in ``kwargs`` nor in the
            dataset attributes.
        """
        if self.is_point():
            raise ValueError("POINT datasets cannot be unstacked")
        if not self.is_stacked():
            return self._ds

        if crs is not None:
            # add_crs returns a new dataset; keep it (the result used to be
            # discarded, so the CRS passed here never had any effect).
            self._ds = self.add_crs(crs)

        kws_mindex = dict.fromkeys(["nx", "ny", "lon_ll", "lat_ll", "dx", "dy"])
        for key in kws_mindex:
            value = kwargs.get(key, None)
            if value is None:
                try:
                    value = self._ds.attrs["grid_mapping"][key]
                except KeyError:
                    raise KeyError(
                        f"Did not find a value for {key} in the dataset attributes, please provide it as an argument"
                    ) from None
            kws_mindex[key] = value

        mindex = self._create_multiindex(**kws_mindex)
        mcoords = xr.Coordinates.from_pandas_multiindex(mindex, "grid_index")
        ds_mindex = self._ds.assign_coords(mcoords)
        ds_mindex.attrs["grid_mapping"] = kws_mindex
        return ds_mindex.unstack()

    def _create_multiindex(self, nx, ny, lon_ll, lat_ll, dx, dy, **kwargs):
        """Build the ``(yc, xc)`` pandas MultiIndex of a regular projected grid.

        The lower-left corner ``(lon_ll, lat_ll)`` is projected from
        ``ccrs.PlateCarree()`` into ``attrs["crs"]``; the axes are then
        ``x_ll + i * dx`` for ``i < nx`` and ``y_ll + j * dy`` for ``j < ny``.

        Raises
        ------
        ValueError
            If ``grid_index`` does not have ``nx * ny`` entries.
        """
        from pandas import MultiIndex

        if self._ds.sizes["grid_index"] != nx * ny:
            raise ValueError(
                f"Size of grid_index ({self._ds.sizes['grid_index']}) does not match product of nx and ny ({nx * ny})"
            )

        crs = self._ds.attrs["crs"]
        x_ll, y_ll = crs.transform_point(x=lon_ll, y=lat_ll, src_crs=ccrs.PlateCarree())

        xc = x_ll + np.arange(nx) * dx
        yc = y_ll + np.arange(ny) * dy

        return MultiIndex.from_product([yc, xc], names=["yc", "xc"])

    def align_with(self, ds, **kwargs):
        """Spatially align this dataset with ``ds``.

        Dispatches on the grid/point kind of both datasets to the matching
        module-level ``_align_*`` helper and returns its result.

        Raises
        ------
        ValueError
            If either dataset is neither grid nor point (previously the
            method silently returned None when only ``ds`` was of unknown
            kind).
        """
        if self.is_grid():
            if ds.space.is_grid():
                return _align_grid_grid(self._ds, ds, **kwargs)
            if ds.space.is_point():
                return _align_grid_point(self._ds, ds, **kwargs)
        elif self.is_point():
            if ds.space.is_point():
                return _align_point_point(self._ds, ds, **kwargs)
            if ds.space.is_grid():
                return _align_point_grid(self._ds, ds, **kwargs)
        raise ValueError("Datasets do not have compatible spatial properties")
171
+
172
+
173
def _align_grid_grid(ds1, ds2, **kwargs):
    """Check that two gridded datasets share the same lat/lon coordinates.

    Both datasets are returned unchanged when their ``longitude`` and
    ``latitude`` values are identical, or when they have the same shape and
    differ by at most ``COORD_TOLERANCE`` degrees (a notice is printed in
    that case). ``kwargs`` is accepted for signature parity with the other
    ``_align_*`` helpers and is not used.

    Raises
    ------
    NotImplementedError
        If the grids differ (regridding is not implemented).
    """
    lon1, lon2 = ds1["longitude"].values, ds2["longitude"].values
    lat1, lat2 = ds1["latitude"].values, ds2["latitude"].values

    if np.array_equal(lon1, lon2) and np.array_equal(lat1, lat2):
        return ds1, ds2

    # Guard against broadcasting: grids of different shape are never equal.
    same_shape = lon1.shape == lon2.shape and lat1.shape == lat2.shape
    # rtol=0 keeps the comparison purely absolute; numpy's default relative
    # tolerance (1e-5) would allow ~0.002 degrees at 180 deg longitude, i.e.
    # far more than the advertised COORD_TOLERANCE.
    if (
        same_shape
        and np.allclose(lon1, lon2, rtol=0, atol=COORD_TOLERANCE)
        and np.allclose(lat1, lat2, rtol=0, atol=COORD_TOLERANCE)
    ):
        print(
            f"Some lat-lon coordinates differ. But the difference is smaller than {COORD_TOLERANCE} degrees, considering both grids as equal"
        )
        return ds1, ds2

    raise NotImplementedError("Regridding not implemented")
189
+
190
+
191
def _align_grid_point(ds1, ds2, **kwargs):
    """Align a gridded dataset with a point dataset through interpolation.

    ``ds1`` is passed to ``interpolate`` together with ``ds2`` and the
    interpolation method (keyword ``method``, default ``"xarray"``); every
    remaining keyword is forwarded as is. ``ds2`` is returned untouched.

    Returns
    -------
    tuple
        ``(interpolated ds1, ds2)``.
    """
    # Imported lazily, as in the original (presumably to avoid a circular
    # import between the accessors and the interpolation package).
    from ..interpolations.interpolate import interpolate

    interpolation_method = kwargs.pop("method", "xarray")
    interpolated = interpolate(ds1, ds2, interpolation_method, **kwargs)
    return interpolated, ds2
198
+
199
+
200
def _align_point_point(ds1, ds2, **kwargs):
    """Placeholder for point-to-point alignment; always raises.

    Raises
    ------
    NotImplementedError
        Unconditionally.
    """
    message = "Point selection not implemented"
    raise NotImplementedError(message)
202
+
203
+
204
def _align_point_grid(ds1, ds2, **kwargs):
    """Placeholder for point-to-grid alignment; always raises.

    Raises
    ------
    NotImplementedError
        Unconditionally.
    """
    # Message fixed: the original read "Point datanot implemented".
    raise NotImplementedError("Gridding of Point data not implemented")
@@ -0,0 +1,180 @@
1
+ import xarray as xr
2
+ import numpy as np
3
+
4
+ from ..properties.properties import Time
5
+ from ..properties.utils import properties_from_attrs, update_time_property
6
+
7
+
8
@xr.register_dataset_accessor("time")
class TimeAccessor:
    """Dataset accessor (``ds.time``) bundling the temporal helpers.

    The temporal kind of the dataset is read once, at construction, from the
    dataset attributes through ``properties_from_attrs`` and compared against
    ``Time.FORECAST`` / ``Time.OBSERVATION`` by the ``is_*`` predicates.
    """

    def __init__(self, ds):
        self._time = properties_from_attrs(ds).time
        self._ds = ds

    def is_forecast(self):
        """Return True when the dataset's time property is ``Time.FORECAST``."""
        return self._time == Time.FORECAST

    def is_observation(self):
        """Return True when the dataset's time property is ``Time.OBSERVATION``."""
        return self._time == Time.OBSERVATION

    def add_valid_time(self):
        """Return the dataset with a ``valid_time`` coordinate.

        For forecast datasets ``valid_time`` is the outer sum of
        ``reference_time`` and ``lead_time`` and is attached as a coordinate
        over ``(reference_time, lead_time)``. Any other dataset is returned
        unchanged.
        """
        if not self.is_forecast():
            return self._ds

        # Broadcast (n_ref, 1) + (n_lead,) -> (n_ref, n_lead).
        valid_time = (
            self._ds["reference_time"].values[:, np.newaxis]
            + self._ds["lead_time"].values
        )
        return self._ds.assign_coords(
            {"valid_time": (["reference_time", "lead_time"], valid_time)}
        )

    def align_with(self, ds, **kwargs):
        """Temporally align this dataset with ``ds``.

        Dispatches on the forecast/observation kind of both datasets to the
        matching module-level ``_align_*`` helper and returns its result.

        Raises
        ------
        ValueError
            If either dataset is neither a forecast nor an observation
            (previously the method silently returned None when only ``ds``
            was of unknown kind).
        """
        if self.is_forecast():
            if ds.time.is_forecast():
                return _align_forecast_forecast(self._ds, ds, **kwargs)
            if ds.time.is_observation():
                return _align_forecast_observation(self._ds, ds, **kwargs)
        elif self.is_observation():
            if ds.time.is_observation():
                return _align_observation_observation(self._ds, ds, **kwargs)
            if ds.time.is_forecast():
                return _align_observation_forecast(self._ds, ds, **kwargs)
        raise ValueError("Datasets do not have compatible temporal properties")
46
+
47
+
48
def _align_forecast_forecast(ds1, ds2, only_common=False):
    """Align two forecast datasets on reference time and lead time.

    Reference times are always restricted to those present in both
    datasets. Lead times are restricted to the common ones when
    ``only_common`` is true; otherwise both datasets are outer-joined along
    ``lead_time`` only. ``valid_time`` is (re)computed on both results.

    Returns
    -------
    tuple
        ``(aligned ds1, aligned ds2)``.
    """
    shared_reference_times = ds1.indexes["reference_time"].intersection(
        ds2.indexes["reference_time"]
    )
    first = ds1.sel(reference_time=shared_reference_times)
    second = ds2.sel(reference_time=shared_reference_times)

    if not only_common:
        # Every dimension except lead_time is kept out of the outer join.
        excluded = (set(ds1.dims) | set(ds2.dims)) - set(["lead_time"])
        first, second = xr.align(first, second, join="outer", exclude=excluded)
    else:
        shared_lead_times = first.indexes["lead_time"].intersection(
            second.indexes["lead_time"]
        )
        first = first.sel(lead_time=shared_lead_times)
        second = second.sel(lead_time=shared_lead_times)

    first = first.time.add_valid_time()
    second = second.time.add_valid_time()
    return first, second
71
+
72
+
73
def _align_forecast_observation(
    ds_forecast, ds_observation, only_common=False, lead_time="start-min"
):
    """Flatten a forecast onto ``valid_time`` and align it with an observation.

    The forecast is first reduced (see ``lead_time``), then its
    ``(reference_time, lead_time)`` dimensions are stacked and replaced by a
    single ``valid_time`` dimension, and finally both datasets are aligned
    along ``valid_time``. The returned forecast has its time property set to
    ``Time.OBSERVATION``.

    Parameters
    ----------
    ds_forecast, ds_observation : xr.Dataset
        Forecast and observation datasets.
    only_common : bool
        Inner join on ``valid_time`` when true, outer join otherwise.
    lead_time : {"start-min", "start-max"}
        ``"start-min"`` keeps, for every reference time, only the lead times
        shorter than the spacing between reference times. ``"start-max"``
        keeps only reference times spaced by the maximum lead time, starting
        at the first reference time. With a single reference time no
        reduction is applied.

    Raises
    ------
    NotImplementedError
        If the reference times are not evenly spaced.
    ValueError
        If ``lead_time`` is not one of the accepted values.
    """
    ds_forecast = ds_forecast.time.add_valid_time()

    # Check if reference_times are continuous (evenly spaced).
    reference_time_diff = ds_forecast.reference_time.diff("reference_time").values
    # A single reference time yields an empty diff; indexing [0] would raise.
    has_spacing = reference_time_diff.size > 0
    if has_spacing and not (reference_time_diff[0] == reference_time_diff).all():
        raise NotImplementedError(
            "Aligning a forecast with non-continuous reference times with an observation is not implemented."
        )

    if lead_time not in ("start-min", "start-max"):
        raise ValueError(
            "Invalid value for lead_time. Expected 'start-min' or 'start-max'."
        )

    if not has_spacing:
        ds_forecast_reduced = ds_forecast
    elif lead_time == "start-min":
        min_diff = reference_time_diff[0]
        ds_forecast_reduced = ds_forecast.where(
            ds_forecast.lead_time < min_diff, drop=True
        )
    else:  # "start-max"
        max_diff = ds_forecast.lead_time.max().values
        reference_times = np.arange(
            ds_forecast.reference_time.min().values,
            ds_forecast.reference_time.max().values,
            max_diff,
            dtype="datetime64[ns]",
        )
        ds_forecast_reduced = ds_forecast.sel(reference_time=reference_times)

    ds_forecast_stacked = (
        ds_forecast_reduced.stack(time=["reference_time", "lead_time"])
        .reset_index("time")
        .swap_dims({"time": "valid_time"})
        .transpose("valid_time", ...)
    )

    # Parenthesised on purpose: ``a | b - c`` parses as ``a | (b - c)``, which
    # left "valid_time" inside the exclude set and disabled the alignment.
    exclude = (set(ds_forecast_stacked.coords) | set(ds_observation.coords)) - set(
        ["valid_time"]
    )
    ds_forecast_aligned, ds_observation_aligned = xr.align(
        ds_forecast_stacked,
        ds_observation,
        join="inner" if only_common else "outer",
        exclude=exclude,
    )
    ds_forecast_aligned = update_time_property(ds_forecast_aligned, Time.OBSERVATION)
    return ds_forecast_aligned, ds_observation_aligned
127
+
128
+
129
def _align_observation_observation(ds1, ds2, only_common=False):
    """Align two observation datasets along ``valid_time`` only.

    An inner join is used when ``only_common`` is true, an outer join
    otherwise; every other dimension of either dataset is excluded from the
    alignment.

    Returns
    -------
    tuple
        ``(aligned ds1, aligned ds2)``.
    """
    untouched_dims = (set(ds1.dims) | set(ds2.dims)) - set(["valid_time"])
    join_mode = "inner" if only_common else "outer"
    first, second = xr.align(ds1, ds2, join=join_mode, exclude=untouched_dims)
    return first, second
136
+
137
+
138
def _align_observation_forecast(ds_observation, ds_forecast, only_common=False):
    """Reshape an observation onto the forecast's (reference, lead) layout.

    The forecast is trimmed to the reference times whose valid times fall
    inside the observation period, the observation is then selected at the
    forecast valid times and tagged ``Time.FORECAST``. With ``only_common``
    false, the reshaped observation is additionally outer-joined with the
    untrimmed forecast along ``reference_time`` and ``lead_time``.

    Returns
    -------
    tuple
        ``(reshaped observation, forecast)`` -- the forecast is the trimmed
        one when ``only_common`` is true, the aligned full one otherwise.
    """
    forecast_cut = ds_forecast.time.add_valid_time()

    # Drop reference times that start before the first observation.
    if forecast_cut.reference_time.min().values < ds_observation.valid_time.min().values:
        first_observed = ds_observation.valid_time.min().values
        forecast_cut = forecast_cut.sel(reference_time=slice(first_observed, None))

    # Drop reference times whose last lead time runs past the last observation.
    if forecast_cut.valid_time.max().values > ds_observation.valid_time.max().values:
        # The forecast time-step/lead times might not always align with the
        # maximum observation time, hence the search for the closest
        # non-positive offset instead of an exact match.
        offset_at_last_lead = (
            forecast_cut["valid_time"] - (ds_observation["valid_time"].max())
        ).isel(lead_time=-1)
        not_after_end = offset_at_last_lead.where(offset_at_last_lead <= 0, drop=True)
        last_valid_index = np.abs(not_after_end).argmin().values
        max_reference_time = forecast_cut.isel(reference_time=last_valid_index)[
            "reference_time"
        ].values
        forecast_cut = forecast_cut.sel(reference_time=slice(None, max_reference_time))

    observation_aligned = ds_observation.sel(valid_time=forecast_cut.valid_time)
    observation_aligned = observation_aligned.transpose(
        "reference_time", "lead_time", ...
    )
    observation_aligned = update_time_property(observation_aligned, Time.FORECAST)

    if only_common:
        return observation_aligned, forecast_cut

    untouched = (set(observation_aligned.coords) | set(forecast_cut.coords)) - set(
        ["reference_time", "lead_time"]
    )
    observation_aligned, forecast_aligned = xr.align(
        observation_aligned,
        ds_forecast.time.add_valid_time(),
        join="outer",
        exclude=untouched,
    )
    observation_aligned["valid_time"] = forecast_aligned["valid_time"]
    return observation_aligned, forecast_aligned
@@ -0,0 +1,7 @@
1
+ from . import time
2
+ from . import space
3
+
4
+ __all__ = [
5
+ "time",
6
+ "space",
7
+ ]
mxalign/align/nans.py ADDED
@@ -0,0 +1,72 @@
1
+ import xarray as xr
2
+ import itertools
3
+
4
+
5
def broadcast_nans(datasets: dict | list) -> "xr.Dataset | dict | list":
    """
    Broadcast NaN values across several xarray Datasets: if a value is NaN in
    one dataset at a specific coordinate, it becomes NaN in all datasets at
    that coordinate.

    The inputs are deep-copied first; the caller's datasets are not modified.

    Parameters
    ----------
    datasets : list[xr.Dataset] | dict[str, xr.Dataset]
        The datasets to process. They should share some common coordinates
        and variables. A single ``xr.Dataset`` is returned unchanged.

    Returns
    -------
    list[xr.Dataset] | dict[str, xr.Dataset]
        The masked copies, in the same container type (and with the same
        keys) as the input.

    Notes
    -----
    - The function operates on pairs of datasets, comparing each dataset with
      every other dataset.
    - Only coordinate values that exist in both datasets of a pair are
      considered.
    - Only variables that exist in both datasets of a pair are processed.
    - The NaN broadcasting is performed at the intersection of coordinates
      between each pair of datasets.

    Examples
    --------
    >>> ds1 = xr.Dataset(...)
    >>> ds2 = xr.Dataset(...)
    >>> ds3 = xr.Dataset(...)
    >>> masked = broadcast_nans([ds1, ds2, ds3])
    """

    if isinstance(datasets, xr.Dataset):
        return datasets

    if isinstance(datasets, dict):
        keys = list(datasets.keys())
        working = [ds.copy(deep=True) for ds in datasets.values()]
    else:
        keys = None
        working = [ds.copy(deep=True) for ds in datasets]

    # Iterate over all pairs of datasets
    for dsA, dsB in itertools.combinations(working, 2):
        # Shared coordinate values for every dimension of the first dataset
        common_coords = {
            dim: sorted(set(dsA[dim].values) & set(dsB[dim].values)) for dim in dsA.dims
        }

        for var in dsA.data_vars:
            if var not in dsB:  # both datasets must have the variable
                continue

            # Index only along the dimensions this variable actually has;
            # selecting along a dimension a variable lacks raises in xarray.
            # Passed as a dict (not **kwargs) so that dimension names cannot
            # collide with keyword arguments of ``sel``.
            indexers = {dim: common_coords[dim] for dim in dsA[var].dims}

            selA = dsA[var].sel(indexers)
            selB = dsB[var].sel(indexers)

            # NaN in either dataset -> NaN in both
            nan_mask = selA.isnull() | selB.isnull()

            dsA[var].loc[indexers] = selA.where(~nan_mask)
            dsB[var].loc[indexers] = selB.where(~nan_mask)

    # ``keys is not None`` (not truthiness) so an empty dict stays a dict.
    return dict(zip(keys, working)) if keys is not None else working
mxalign/align/space.py ADDED
@@ -0,0 +1,21 @@
1
+ import xarray as xr
2
+
3
+
4
def align_space(datasets, reference, **kwargs):
    """Spatially align one or several datasets with ``reference``.

    Each dataset is aligned through ``ds.space.align_with(reference,
    **kwargs)`` and the first element of that result is kept.

    Parameters
    ----------
    datasets : xr.Dataset | xr.DataArray | iterable | dict
        A single dataset, an iterable of datasets, or a mapping of names to
        datasets.
    reference : xr.Dataset
        Dataset every input is aligned with.
    **kwargs
        Forwarded to ``align_with``.

    Returns
    -------
    xr.Dataset | list | dict
        A dict with the input keys for dict input; otherwise the single
        aligned dataset when exactly one was processed, else a list.
    """
    if isinstance(datasets, (xr.Dataset, xr.DataArray)):
        datasets = [datasets]
    if isinstance(datasets, dict):
        keys = list(datasets.keys())
        # values(), not items(): items() yields (key, dataset) tuples, which
        # have no ``.space`` accessor.
        datasets = datasets.values()
    else:
        keys = None

    aligned = [ds.space.align_with(reference, **kwargs)[0] for ds in datasets]

    if keys is not None:
        return dict(zip(keys, aligned))
    if len(aligned) == 1:
        return aligned[0]
    return aligned
mxalign/align/time.py ADDED
@@ -0,0 +1,62 @@
1
+ import xarray as xr
2
+
3
+
4
def align_time(
    datasets: list[xr.Dataset] | dict[str, xr.Dataset], return_as: str = "forecast"
):
    """Temporally align a collection of forecast/observation datasets.

    A common set of valid times is built first: all forecasts are aligned
    with each other, all observations are aligned with each other, and (when
    both kinds are present) the observation times are aligned with the
    forecast times. Every input dataset is then aligned with that common
    reference through ``ds.time.align_with``.

    Parameters
    ----------
    datasets : xr.Dataset | list[xr.Dataset] | dict[str, xr.Dataset]
        Datasets to align; a single dataset is treated as a one-element list.
    return_as : str
        Output structure; only ``"forecast"`` is supported.

    Returns
    -------
    list[xr.Dataset] | dict[str, xr.Dataset]
        The aligned datasets: a dict with the input keys for dict input, a
        list otherwise.

    Raises
    ------
    NotImplementedError
        If ``return_as`` is not ``"forecast"``.
    ValueError
        If no dataset is a forecast or an observation.
    """
    if isinstance(datasets, (xr.Dataset, xr.DataArray)):
        datasets = [datasets]
    if isinstance(datasets, dict):
        keys = list(datasets.keys())
        datasets = list(datasets.values())
    else:
        keys = None
        # Materialise: the datasets are iterated twice below.
        datasets = list(datasets)

    if return_as != "forecast":
        # The exception used to be constructed without being raised.
        raise NotImplementedError(
            "Currently only temporal alignment return forecast structure is supported."
        )

    # Accumulate the valid times of all forecasts and of all observations.
    valid_times_fcst = None
    valid_times_obs = None
    for ds in datasets:
        if ds.time.is_forecast():
            _ds = ds.time.add_valid_time()["valid_time"].to_dataset(
                name="valid_times"
            )
            # Attributes carry the properties the ``time`` accessor reads.
            _ds = _ds.assign_attrs(ds.attrs)
            if valid_times_fcst is None:
                valid_times_fcst = _ds
            else:
                _, valid_times_fcst = _ds.time.align_with(valid_times_fcst)
        elif ds.time.is_observation():
            _ds = ds["valid_time"].to_dataset(name="valid_times")
            _ds = _ds.assign_attrs(ds.attrs)
            if valid_times_obs is None:
                valid_times_obs = _ds
            else:
                _, valid_times_obs = _ds.time.align_with(valid_times_obs)

    if (valid_times_obs is None) and (valid_times_fcst is None):
        raise ValueError("No observations or forecasts found")
    elif valid_times_fcst is None:
        valid_times = valid_times_obs
    elif valid_times_obs is None:
        valid_times = valid_times_fcst
    else:
        _, valid_times = valid_times_obs.time.align_with(valid_times_fcst)

    aligned = [ds.time.align_with(valid_times)[0] for ds in datasets]
    if keys is None:
        return aligned
    return dict(zip(keys, aligned))