mxalign 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxalign/__init__.py +36 -0
- mxalign/accessors/__init__.py +7 -0
- mxalign/accessors/space.py +205 -0
- mxalign/accessors/time.py +180 -0
- mxalign/align/__init__.py +7 -0
- mxalign/align/nans.py +72 -0
- mxalign/align/space.py +21 -0
- mxalign/align/time.py +62 -0
- mxalign/cli.py +157 -0
- mxalign/interpolations/__init__.py +9 -0
- mxalign/interpolations/base.py +29 -0
- mxalign/interpolations/delaunay.py +218 -0
- mxalign/interpolations/interpolate.py +29 -0
- mxalign/interpolations/registry.py +17 -0
- mxalign/interpolations/xarray.py +63 -0
- mxalign/loaders/__init__.py +11 -0
- mxalign/loaders/anemoi_datasets.py +92 -0
- mxalign/loaders/anemoi_inference.py +103 -0
- mxalign/loaders/base.py +103 -0
- mxalign/loaders/harp_obstable.py +81 -0
- mxalign/loaders/loader.py +8 -0
- mxalign/loaders/registry.py +17 -0
- mxalign/properties/__init__.py +0 -0
- mxalign/properties/properties.py +25 -0
- mxalign/properties/specs.py +54 -0
- mxalign/properties/utils.py +43 -0
- mxalign/properties/validation.py +48 -0
- mxalign/runner.py +167 -0
- mxalign/transformations/__init__.py +7 -0
- mxalign/transformations/base.py +38 -0
- mxalign/transformations/external.py +34 -0
- mxalign/transformations/registry.py +20 -0
- mxalign/transformations/transform.py +28 -0
- mxalign/utils/config.py +55 -0
- mxalign/utils/dates.py +76 -0
- mxalign/utils/projections.py +104 -0
- mxalign/utils/save.py +62 -0
- mxalign/verification.py +57 -0
- mxalign-0.1.0.dist-info/METADATA +136 -0
- mxalign-0.1.0.dist-info/RECORD +43 -0
- mxalign-0.1.0.dist-info/WHEEL +4 -0
- mxalign-0.1.0.dist-info/entry_points.txt +2 -0
- mxalign-0.1.0.dist-info/licenses/LICENSE +21 -0
mxalign/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from .properties.properties import Properties, Time, Space, Uncertainty
|
|
2
|
+
from .loaders.loader import load
|
|
3
|
+
from .loaders.registry import available_loaders, register_loader
|
|
4
|
+
from .transformations.transform import transform
|
|
5
|
+
from .transformations.registry import available_transformations, register_transformation
|
|
6
|
+
from .interpolations.interpolate import interpolate
|
|
7
|
+
from .interpolations.registry import available_interpolations, register_interpolator
|
|
8
|
+
from .align.time import align_time
|
|
9
|
+
from .align.space import align_space
|
|
10
|
+
|
|
11
|
+
from . import accessors
|
|
12
|
+
from . import loaders
|
|
13
|
+
from . import transformations
|
|
14
|
+
from . import interpolations
|
|
15
|
+
|
|
16
|
+
# Names exported by ``from mxalign import *``; every entry is imported at the
# top of this module (classes/functions first, then the subpackages).
__all__ = [
    "Properties",
    "Time",
    "Space",
    "Uncertainty",
    "load",
    "available_loaders",
    "register_loader",
    "transform",
    "available_transformations",
    "register_transformation",
    "interpolate",
    "available_interpolations",
    "register_interpolator",
    "align_time",
    "align_space",
    "accessors",
    "loaders",
    "transformations",
    "interpolations",
]
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
import cartopy.crs as ccrs
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from ..properties.properties import Space
|
|
6
|
+
from ..properties.utils import properties_from_attrs
|
|
7
|
+
|
|
8
|
+
from ..utils.projections import create_cartopy_crs, BUILTIN
|
|
9
|
+
|
|
10
|
+
# Tolerance in degrees that the coordinates of two grids can differ while still being interpreted as the same grid.
|
|
11
|
+
# 0.0001 degrees ~ 10m at 45 deg latitude
|
|
12
|
+
COORD_TOLERANCE = 0.0001
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@xr.register_dataset_accessor("space")
class SpaceAccessor:
    """Spatial helpers available on datasets as ``ds.space``.

    The spatial kind of the dataset (``Space.GRID`` or ``Space.POINT``) is
    read once, at construction, from the dataset attributes through
    ``properties_from_attrs``.
    """

    # Dimension pairs (in stacking order) that mark an unstacked 2-D grid.
    _UNSTACKED_DIMS = (("yc", "xc"), ("latitude", "longitude"), ("lat", "lon"))

    def __init__(self, ds):
        self._space = properties_from_attrs(ds).space
        self._ds = ds

    def is_grid(self):
        """Return True when the dataset is tagged as ``Space.GRID``."""
        return self._space == Space.GRID

    def is_point(self):
        """Return True when the dataset is tagged as ``Space.POINT``."""
        return self._space == Space.POINT

    def add_crs(self, crs):
        """Return a copy of the dataset with ``attrs["crs"]`` set.

        Parameters
        ----------
        crs : str | dict | object
            A string is looked up (lower-cased) in ``BUILTIN``. A dict
            (given directly or obtained from that lookup) is turned into a
            CRS with ``create_cartopy_crs`` using its ``"projection"``,
            ``"kws_projection"`` and optional ``"kws_globe"`` entries. Any
            other object is stored unchanged.

        Raises
        ------
        ValueError
            If the dataset is a point dataset or the name is not in
            ``BUILTIN``.
        """
        if self.is_point():
            raise ValueError("Cannot add CRS to a point dataset")
        if isinstance(crs, str):
            try:
                crs = BUILTIN[crs.lower()]
            except KeyError as err:
                # The message must be an f-string, otherwise the offending
                # name is never shown to the user.
                raise ValueError(
                    f"crs: {crs} not found in supported projections"
                ) from err
        if isinstance(crs, dict):
            crs = create_cartopy_crs(
                projection=crs["projection"],
                kws_projection=crs["kws_projection"],
                kws_globe=crs.get("kws_globe", None),
            )
        return self._ds.assign_attrs({"crs": crs})

    def add_grid_mapping(self, grid_mapping: str | dict):
        """Return a copy of the dataset with ``attrs["grid_mapping"]`` set.

        A string is resolved to ``BUILTIN[name.lower()]["kws_grid"]``; a dict
        is stored as given.

        Raises
        ------
        ValueError
            If the dataset is a point dataset or the lookup fails.
        """
        if self.is_point():
            raise ValueError("Cannot add grid mapping to a point dataset")
        if isinstance(grid_mapping, str):
            try:
                grid_mapping = BUILTIN[grid_mapping.lower()]["kws_grid"]
            except KeyError as err:
                raise ValueError(
                    f"grid mapping: {grid_mapping} not found in supported mappings"
                ) from err
        return self._ds.assign_attrs({"grid_mapping": grid_mapping})

    def add_xy(self, crs=None):
        """Return the dataset with projected ``xc``/``yc`` coordinates.

        When ``crs`` is given it is first attached with ``add_crs`` (the
        accessor then keeps working on that new dataset). The
        ``longitude``/``latitude`` values are transformed from PlateCarree
        into the dataset CRS and attached along ``grid_index`` (grid) or
        ``point_index`` (point). A dataset that already has ``xc`` and ``yc``
        coordinates is returned unchanged.

        Raises
        ------
        ValueError
            If no CRS is available, if the dataset has ``longitude`` and
            ``latitude`` as dimensions, or if it is neither grid nor point.
        """
        if crs is not None:
            self._ds = self.add_crs(crs)

        crs = self._ds.attrs.get("crs", None)

        if crs is None:
            raise ValueError("No CRS provided and no CRS found in dataset attributes")

        if {"longitude", "latitude"}.issubset(self._ds.dims):
            raise ValueError(
                "Cannot add x/y coordinates to a GRID dataset that has longitude/latitude dimensions"
            )
        if {"xc", "yc"}.issubset(self._ds.coords):
            return self._ds

        # assumes longitude/latitude are 1-D here (xyz is indexed as (n, 3))
        xyz = crs.transform_points(
            x=self._ds["longitude"].values,
            y=self._ds["latitude"].values,
            src_crs=ccrs.PlateCarree(),
        )

        if self.is_grid():
            index_dim = "grid_index"
        elif self.is_point():
            index_dim = "point_index"
        else:
            raise ValueError("Dataset does not have expected spatial properties")

        return self._ds.assign_coords(
            xc=(index_dim, xyz[:, 0]), yc=(index_dim, xyz[:, 1])
        )

    def is_stacked(self):
        """Return True when the grid is stored along a single ``grid_index``.

        Returns False when the dataset has one of the 2-D dimension pairs
        ``xc``/``yc``, ``longitude``/``latitude`` or ``lat``/``lon``.

        Raises
        ------
        ValueError
            If none of the known grid dimensions is present.
        """
        dims = self._ds.dims
        if {"xc", "yc"}.issubset(dims) or {"longitude", "latitude"}.issubset(dims):
            return False
        if "grid_index" in dims:
            return True
        # ``stack`` knows how to handle lat/lon grids, so they have to be
        # reported as unstacked rather than rejected here.
        if {"lat", "lon"}.issubset(dims):
            return False
        raise ValueError("Dataset does not have expected dimensions for GRID")

    def stack(self):
        """Return the dataset with its 2-D grid dims flattened to ``grid_index``.

        The stacked multi-index is reset so the former dimension coordinates
        become plain coordinates along ``grid_index``. An already stacked
        dataset is returned unchanged.

        Raises
        ------
        ValueError
            For point datasets, or when no known pair of grid dimensions is
            found.
        """
        if self.is_point():
            raise ValueError("POINT datasets cannot be stacked")
        if self.is_stacked():
            return self._ds

        dims = self._ds.dims
        for pair in self._UNSTACKED_DIMS:
            if set(pair).issubset(dims):
                return self._ds.stack({"grid_index": list(pair)}).reset_index(
                    "grid_index"
                )
        raise ValueError("Could not find correct dimensions to stack")

    def unstack(self, crs=None, **kwargs):
        """Return the dataset reshaped from ``grid_index`` to ``yc`` x ``xc``.

        Parameters
        ----------
        crs : optional
            CRS to attach (see ``add_crs``) before building the projected
            axes. Without it ``attrs["crs"]`` must already be set.
        **kwargs
            Any of ``nx``, ``ny``, ``lon_ll``, ``lat_ll``, ``dx``, ``dy``.
            Values that are missing (or ``None``) are read from
            ``attrs["grid_mapping"]``.

        Raises
        ------
        ValueError
            For point datasets, or when ``nx * ny`` does not match the size
            of ``grid_index``.
        KeyError
            When a grid parameter is neither passed nor found in the
            attributes.
        """
        if self.is_point():
            raise ValueError("POINT datasets cannot be unstacked")
        if not self.is_stacked():
            return self._ds

        if crs:
            # add_crs returns a new dataset; keep it, otherwise the CRS that
            # _create_multiindex reads from the attributes is never set.
            self._ds = self.add_crs(crs)

        kws_mindex = {}
        for key in ("nx", "ny", "lon_ll", "lat_ll", "dx", "dy"):
            value = kwargs.get(key, None)
            if value is None:
                try:
                    value = self._ds.attrs["grid_mapping"][key]
                except KeyError:
                    raise KeyError(
                        f"Did not find a value for {key} in the dataset attributes, please provide it as an argument"
                    )
            kws_mindex[key] = value

        mindex = self._create_multiindex(**kws_mindex)
        mcoords = xr.Coordinates.from_pandas_multiindex(mindex, "grid_index")
        ds_mindex = self._ds.assign_coords(mcoords)
        ds_mindex.attrs["grid_mapping"] = kws_mindex
        return ds_mindex.unstack()

    def _create_multiindex(self, nx, ny, lon_ll, lat_ll, dx, dy, **kwargs):
        """Build the (yc, xc) pandas MultiIndex of a regular projected grid.

        The lower-left corner (``lon_ll``, ``lat_ll``) is projected from
        PlateCarree into ``attrs["crs"]``; the axes are then ``nx``/``ny``
        steps of ``dx``/``dy`` from that corner.
        """
        from pandas import MultiIndex

        if self._ds.sizes["grid_index"] != nx * ny:
            raise ValueError(
                f"Size of grid_index ({self._ds.sizes['grid_index']}) does not match product of nx and ny ({nx * ny})"
            )

        crs = self._ds.attrs["crs"]
        x_ll, y_ll = crs.transform_point(x=lon_ll, y=lat_ll, src_crs=ccrs.PlateCarree())

        xc = x_ll + np.arange(nx) * dx
        yc = y_ll + np.arange(ny) * dy

        return MultiIndex.from_product([yc, xc], names=["yc", "xc"])

    def align_with(self, ds, **kwargs):
        """Spatially align this dataset with ``ds``.

        Dispatches on the grid/point kind of both datasets and returns
        whatever the selected helper returns.

        Raises
        ------
        ValueError
            When either dataset is neither grid nor point.
        """
        if self.is_grid():
            if ds.space.is_grid():
                return _align_grid_grid(self._ds, ds, **kwargs)
            if ds.space.is_point():
                return _align_grid_point(self._ds, ds, **kwargs)
        elif self.is_point():
            if ds.space.is_point():
                return _align_point_point(self._ds, ds, **kwargs)
            if ds.space.is_grid():
                return _align_point_grid(self._ds, ds, **kwargs)
        # Reached for every unsupported combination, including the inner
        # cases that previously fell through and returned None.
        raise ValueError("Datasets do not have compatible spatial properties")
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _align_grid_grid(ds1, ds2, **kwargs):
|
|
174
|
+
if np.array_equal(
|
|
175
|
+
ds1["longitude"].values, ds2["longitude"].values
|
|
176
|
+
) and np.array_equal(ds1["latitude"].values, ds2["latitude"].values):
|
|
177
|
+
return ds1, ds2
|
|
178
|
+
elif np.allclose(
|
|
179
|
+
ds1["longitude"].values, ds2["longitude"].values, atol=COORD_TOLERANCE
|
|
180
|
+
) and np.allclose(
|
|
181
|
+
ds1["latitude"].values, ds2["latitude"].values, atol=COORD_TOLERANCE
|
|
182
|
+
):
|
|
183
|
+
print(
|
|
184
|
+
f"Some lat-lon coordinates differ. But the difference is smaller than {COORD_TOLERANCE} degrees, considering both grids as equal"
|
|
185
|
+
)
|
|
186
|
+
return ds1, ds2
|
|
187
|
+
else:
|
|
188
|
+
raise NotImplementedError("Regridding not implemented")
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _align_grid_point(ds1, ds2, **kwargs):
    """Interpolate the grid dataset ``ds1`` onto the point dataset ``ds2``.

    The interpolation method is taken from the ``method`` keyword
    (``"xarray"`` when absent); all remaining keywords are forwarded to
    ``interpolate``.

    Returns
    -------
    tuple
        The interpolated ``ds1`` and the untouched ``ds2``.
    """
    # Imported here rather than at module level, as in the original code.
    from ..interpolations.interpolate import interpolate

    options = dict(kwargs)
    chosen_method = options.pop("method", "xarray")
    interpolated = interpolate(ds1, ds2, chosen_method, **options)
    return interpolated, ds2
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _align_point_point(ds1, ds2, **kwargs):
|
|
201
|
+
raise NotImplementedError("Point selection not implemented")
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _align_point_grid(ds1, ds2, **kwargs):
|
|
205
|
+
raise NotImplementedError("Gridding of Point datanot implemented")
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from ..properties.properties import Time
|
|
5
|
+
from ..properties.utils import properties_from_attrs, update_time_property
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@xr.register_dataset_accessor("time")
class TimeAccessor:
    """Temporal helpers available on datasets as ``ds.time``.

    The temporal kind of the dataset (``Time.FORECAST`` or
    ``Time.OBSERVATION``) is read once, at construction, from the dataset
    attributes through ``properties_from_attrs``.
    """

    def __init__(self, ds):
        self._time = properties_from_attrs(ds).time
        self._ds = ds

    def is_forecast(self):
        """Return True when the dataset is tagged as ``Time.FORECAST``."""
        return self._time == Time.FORECAST

    def is_observation(self):
        """Return True when the dataset is tagged as ``Time.OBSERVATION``."""
        return self._time == Time.OBSERVATION

    def add_valid_time(self):
        """Return the dataset with a 2-D ``valid_time`` coordinate.

        For forecasts ``valid_time[i, j] = reference_time[i] + lead_time[j]``
        is attached on the ``(reference_time, lead_time)`` dimensions. Any
        other dataset is returned unchanged.
        """
        if not self.is_forecast():
            return self._ds

        # Column vector + row vector broadcasts to (reference_time, lead_time).
        valid_time = (
            self._ds["reference_time"].values[:, np.newaxis]
            + self._ds["lead_time"].values
        )
        return self._ds.assign_coords(
            {"valid_time": (["reference_time", "lead_time"], valid_time)}
        )

    def align_with(self, ds, **kwargs):
        """Temporally align this dataset with ``ds``.

        Dispatches on the forecast/observation kind of both datasets and
        returns whatever the selected helper returns.

        Raises
        ------
        ValueError
            When either dataset is neither forecast nor observation.
        """
        if self.is_forecast():
            if ds.time.is_forecast():
                return _align_forecast_forecast(self._ds, ds, **kwargs)
            if ds.time.is_observation():
                return _align_forecast_observation(self._ds, ds, **kwargs)
        elif self.is_observation():
            if ds.time.is_observation():
                return _align_observation_observation(self._ds, ds, **kwargs)
            if ds.time.is_forecast():
                return _align_observation_forecast(self._ds, ds, **kwargs)
        # Reached for every unsupported combination, including the inner
        # cases that previously fell through and returned None.
        raise ValueError("Datasets do not have compatible temporal properties")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _align_forecast_forecast(ds1, ds2, only_common=False):
    """Align two forecast datasets on reference time and lead time.

    Both datasets are first restricted to the reference times they share.
    Lead times are then either restricted to the shared ones
    (``only_common=True``) or outer-joined, with every other dimension left
    out of the join. A ``valid_time`` coordinate is added to both results.

    Returns
    -------
    tuple
        The aligned ``(ds1, ds2)``.
    """
    shared_refs = ds1.indexes["reference_time"].intersection(
        ds2.indexes["reference_time"]
    )
    left = ds1.sel(reference_time=shared_refs)
    right = ds2.sel(reference_time=shared_refs)

    if only_common:
        shared_leads = left.indexes["lead_time"].intersection(
            right.indexes["lead_time"]
        )
        left = left.sel(lead_time=shared_leads)
        right = right.sel(lead_time=shared_leads)
    else:
        # Only lead_time takes part in the outer join.
        untouched_dims = set(ds1.dims).union(ds2.dims) - {"lead_time"}
        left, right = xr.align(left, right, join="outer", exclude=untouched_dims)

    left = left.time.add_valid_time()
    right = right.time.add_valid_time()
    return left, right
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _align_forecast_observation(
    ds_forecast, ds_observation, only_common=False, lead_time="start-min"
):
    """Flatten a forecast to a ``valid_time`` series and align it with observations.

    Parameters
    ----------
    ds_forecast
        Forecast dataset with ``reference_time`` and ``lead_time``.
    ds_observation
        Observation dataset indexed by ``valid_time``.
    only_common : bool
        Inner join on ``valid_time`` when True, outer join otherwise.
    lead_time : {"start-min", "start-max"}
        How to thin the forecast before flattening. ``"start-min"`` keeps
        only lead times shorter than the spacing between reference times.
        ``"start-max"`` keeps reference times spaced by the longest lead
        time. With a single reference time nothing is thinned.

    Returns
    -------
    tuple
        ``(forecast, observation)`` aligned on ``valid_time``; the forecast
        is re-tagged as ``Time.OBSERVATION``.

    Raises
    ------
    NotImplementedError
        When the reference times are not evenly spaced.
    ValueError
        When ``lead_time`` is not one of the supported values.
    """
    ds_forecast = ds_forecast.time.add_valid_time()

    # Check if reference_times are continuous
    reference_time_diff = ds_forecast.reference_time.diff("reference_time").values
    # A single reference time has no spacing to check (the diff is empty).
    single_reference = reference_time_diff.size == 0
    if not single_reference and not (
        reference_time_diff[0] == reference_time_diff
    ).all():
        raise NotImplementedError(
            "Aligning a forecast with non-continuous reference times with an observation is not implemented."
        )

    if lead_time not in ("start-min", "start-max"):
        raise ValueError(
            "Invalid value for lead_time. Expected 'start-min' or 'start-max'."
        )

    if single_reference:
        ds_forecast_reduced = ds_forecast
    elif lead_time == "start-min":
        min_diff = reference_time_diff[0]
        ds_forecast_reduced = ds_forecast.where(
            ds_forecast.lead_time < min_diff, drop=True
        )
    else:
        max_diff = ds_forecast.lead_time.max().values
        # NOTE(review): np.arange excludes the stop value, so the last
        # reference time is never selected - confirm this is intended.
        reference_times = np.arange(
            ds_forecast.reference_time.min().values,
            ds_forecast.reference_time.max().values,
            max_diff,
            dtype="datetime64[ns]",
        )
        ds_forecast_reduced = ds_forecast.sel(reference_time=reference_times)

    ds_forecast_stacked = (
        ds_forecast_reduced.stack(time=["reference_time", "lead_time"])
        .reset_index("time")
        .swap_dims({"time": "valid_time"})
        .transpose("valid_time", ...)
    )

    # Align on valid_time only. The parentheses matter: "-" binds tighter
    # than "|", so without them valid_time stayed in the exclude set and the
    # alignment did nothing.
    exclude = (set(ds_forecast_stacked.coords) | set(ds_observation.coords)) - {
        "valid_time"
    }
    ds_forecast_aligned, ds_observation_aligned = xr.align(
        ds_forecast_stacked,
        ds_observation,
        join="inner" if only_common else "outer",
        exclude=exclude,
    )
    ds_forecast_aligned = update_time_property(ds_forecast_aligned, Time.OBSERVATION)
    return ds_forecast_aligned, ds_observation_aligned
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _align_observation_observation(ds1, ds2, only_common=False):
    """Align two observation datasets along ``valid_time``.

    Every dimension other than ``valid_time`` is left out of the join. The
    join is ``"inner"`` when ``only_common`` is true and ``"outer"``
    otherwise.

    Returns
    -------
    tuple
        The aligned ``(ds1, ds2)``.
    """
    untouched_dims = set(ds1.dims).union(ds2.dims).difference({"valid_time"})
    how = "inner" if only_common else "outer"
    first, second = xr.align(ds1, ds2, join=how, exclude=untouched_dims)
    return first, second
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _align_observation_forecast(ds_observation, ds_forecast, only_common=False):
    """Reshape observations onto the (reference_time, lead_time) layout of a forecast.

    Steps, in order:

    1. Add ``valid_time`` to the forecast.
    2. Drop forecast reference times earlier than the first observation
       ``valid_time``.
    3. If forecast valid times run past the last observation, drop the
       trailing reference times whose final lead time ends after it.
    4. Select the observations at the (2-D) forecast ``valid_time`` values,
       put ``reference_time`` and ``lead_time`` first, and re-tag the result
       as ``Time.FORECAST``.

    Returns
    -------
    tuple
        ``(observation, forecast)``. With ``only_common`` the forecast is the
        trimmed one from steps 2-3. Otherwise both are outer-joined on
        ``reference_time``/``lead_time`` against the untrimmed forecast and
        the observation takes over the forecast's ``valid_time``.
    """
    ds_forecast_cut = ds_forecast.time.add_valid_time()
    # Trim the start: no forecast may begin before the first observation.
    if (
        ds_forecast_cut.reference_time.min().values
        < ds_observation.valid_time.min().values
    ):
        ds_forecast_cut = ds_forecast_cut.sel(
            reference_time=slice(ds_observation.valid_time.min().values, None)
        )
    # Trim the end: only needed when some forecast valid time exceeds the
    # last observation.
    if ds_forecast_cut.valid_time.max().values > ds_observation.valid_time.max().values:
        # The forecast time-step/lead times might not always align with the maximum observation time
        # Offset of each forecast's last lead time from the last observation.
        valid_diff = (
            ds_forecast_cut["valid_time"] - (ds_observation["valid_time"].max())
        ).isel(lead_time=-1)
        # Among forecasts ending at or before the last observation, take the
        # one ending closest to it.
        # NOTE(review): the index is positional in the filtered array;
        # presumably it matches ds_forecast_cut because reference times are
        # sorted ascending - confirm.
        last_valid_index = (
            np.abs(valid_diff.where(valid_diff <= 0, drop=True)).argmin().values
        )
        max_reference_time = ds_forecast_cut.isel(reference_time=last_valid_index)[
            "reference_time"
        ].values

        # max_reference_time = ds_observation.valid_time.max().values - (ds_forecast_cut.lead_time.max().values - shift)
        ds_forecast_cut = ds_forecast_cut.sel(
            reference_time=slice(None, max_reference_time)
        )

    # Indexing with the 2-D valid_time coordinate gives the observations the
    # forecast's reference_time/lead_time dimensions.
    ds_observation_aligned = ds_observation.sel(valid_time=ds_forecast_cut.valid_time)
    ds_observation_aligned = ds_observation_aligned.transpose(
        "reference_time", "lead_time", ...
    )
    ds_observation_aligned = update_time_property(ds_observation_aligned, Time.FORECAST)
    if only_common:
        return ds_observation_aligned, ds_forecast_cut
    else:
        # Outer join against the full (untrimmed) forecast; only
        # reference_time and lead_time take part in the alignment.
        ds_observation_aligned, ds_forecast_aligned = xr.align(
            ds_observation_aligned,
            ds_forecast.time.add_valid_time(),
            join="outer",
            exclude=(set(ds_observation_aligned.coords) | set(ds_forecast_cut.coords))
            - set(["reference_time", "lead_time"]),
        )
        # Take valid_time from the forecast so it also covers the entries
        # added by the outer join.
        ds_observation_aligned["valid_time"] = ds_forecast_aligned["valid_time"]
        return ds_observation_aligned, ds_forecast_aligned
|
mxalign/align/nans.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
import itertools
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def broadcast_nans(datasets: dict | list) -> "dict | list | xr.Dataset":
    """
    Propagate NaN values between datasets.

    If a value is NaN in one dataset at a given coordinate, it becomes NaN in
    every dataset of the collection at that coordinate. The inputs are deep
    copied first; the copies are modified and returned.

    Parameters
    ----------
    datasets : list[xr.Dataset] | dict[str, xr.Dataset]
        Datasets to process. They should share some common coordinates and
        variables. A single ``xr.Dataset`` is returned unchanged.

    Returns
    -------
    list[xr.Dataset] | dict[str, xr.Dataset]
        The masked copies, in the same container type (and with the same
        keys) as the input.

    Notes
    -----
    - The function operates on pairs of datasets, comparing each dataset with
      every other dataset in the collection.
    - Only coordinate values that exist in both datasets of a pair are
      considered.
    - Only variables that exist in both datasets of a pair are processed.

    Examples
    --------
    >>> ds1 = xr.Dataset(...)
    >>> ds2 = xr.Dataset(...)
    >>> ds3 = xr.Dataset(...)
    >>> ds1, ds2, ds3 = broadcast_nans([ds1, ds2, ds3])
    """
    if isinstance(datasets, xr.Dataset):
        return datasets

    if isinstance(datasets, dict):
        keys = list(datasets)
        working = [ds.copy(deep=True) for ds in datasets.values()]
    else:
        keys = None
        working = [ds.copy(deep=True) for ds in datasets]

    # Iterate over all pairs of datasets
    for dsA, dsB in itertools.combinations(working, 2):
        # Coordinate values present in both datasets, per dimension of dsA
        common_coords = {
            dim: sorted(set(dsA[dim].values) & set(dsB[dim].values)) for dim in dsA.dims
        }

        for var in dsA.data_vars:
            if var not in dsB:  # only variables present in both datasets
                continue

            selA = dsA[var].sel(**common_coords)
            selB = dsB[var].sel(**common_coords)

            # NaN in either dataset masks the value in both
            nan_mask = selA.isnull() | selB.isnull()

            dsA[var].loc[common_coords] = selA.where(~nan_mask)
            dsB[var].loc[common_coords] = selB.where(~nan_mask)

    # Test against None, not truthiness: an empty dict must stay a dict.
    return working if keys is None else dict(zip(keys, working))
|
mxalign/align/space.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def align_space(datasets, reference, **kwargs):
    """Spatially align one or more datasets with a reference dataset.

    Parameters
    ----------
    datasets : xr.Dataset | list | dict
        A single dataset, a sequence of datasets, or a mapping of names to
        datasets.
    reference
        Dataset every input is aligned with through
        ``ds.space.align_with(reference, **kwargs)``.
    **kwargs
        Forwarded to ``align_with``.

    Returns
    -------
    dict | list | xr.Dataset
        A dict with the input keys for dict input. Otherwise a list of
        aligned datasets, or the aligned dataset itself when there is
        exactly one.
    """
    if isinstance(datasets, (xr.Dataset, xr.DataArray)):
        datasets = [datasets]

    if isinstance(datasets, dict):
        keys = list(datasets)
        # Iterate the datasets themselves; ``.items()`` would yield
        # (key, dataset) tuples, which have no ``.space`` accessor.
        datasets = list(datasets.values())
    else:
        keys = None

    aligned = [ds.space.align_with(reference, **kwargs)[0] for ds in datasets]

    if keys is not None:
        return dict(zip(keys, aligned))
    if len(aligned) == 1:
        return aligned[0]
    return aligned
|
mxalign/align/time.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def align_time(
    datasets: list[xr.Dataset] | dict[str, xr.Dataset], return_as: str = "forecast"
):
    """Temporally align a collection of forecast and observation datasets.

    A common set of valid times is built by aligning the forecasts with each
    other, the observations with each other, and finally the observations
    with the forecasts. Every input is then aligned with that common
    structure.

    Parameters
    ----------
    datasets : list[xr.Dataset] | dict[str, xr.Dataset]
        Datasets to align. A single dataset is treated as a one-element list.
    return_as : str
        Output structure; only ``"forecast"`` is supported.

    Returns
    -------
    list | dict
        The aligned datasets, as a dict with the input keys for dict input
        and as a list otherwise.

    Raises
    ------
    NotImplementedError
        When ``return_as`` is not ``"forecast"``.
    ValueError
        When no dataset is a forecast or an observation.
    """
    if isinstance(datasets, (xr.Dataset, xr.DataArray)):
        datasets = [datasets]
    if isinstance(datasets, dict):
        keys = list(datasets)
        datasets = list(datasets.values())
    else:
        keys = None
        # Materialise: the datasets are iterated twice below.
        datasets = list(datasets)

    if return_as != "forecast":
        # The exception used to be built but never raised.
        raise NotImplementedError(
            "Currently only temporal alignment return forecast structure is supported."
        )

    # Accumulate the valid times of all forecasts and of all observations.
    valid_times_fcst = None
    valid_times_obs = None
    for ds in datasets:
        if ds.time.is_forecast():
            candidate = (
                ds.time.add_valid_time()["valid_time"]
                .to_dataset(name="valid_times")
                .assign_attrs(ds.attrs)
            )
            if valid_times_fcst is None:
                valid_times_fcst = candidate
            else:
                _, valid_times_fcst = candidate.time.align_with(valid_times_fcst)
        elif ds.time.is_observation():
            candidate = (
                ds["valid_time"].to_dataset(name="valid_times").assign_attrs(ds.attrs)
            )
            if valid_times_obs is None:
                valid_times_obs = candidate
            else:
                _, valid_times_obs = candidate.time.align_with(valid_times_obs)

    if (valid_times_obs is None) and (valid_times_fcst is None):
        raise ValueError("No observations or forecasts found")
    elif valid_times_fcst is None:
        valid_times = valid_times_obs
    elif valid_times_obs is None:
        valid_times = valid_times_fcst
    else:
        _, valid_times = valid_times_obs.time.align_with(valid_times_fcst)

    aligned = [ds.time.align_with(valid_times)[0] for ds in datasets]
    if keys is None:
        return aligned
    return dict(zip(keys, aligned))
|