mxalign 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxalign/__init__.py +36 -0
- mxalign/accessors/__init__.py +7 -0
- mxalign/accessors/space.py +205 -0
- mxalign/accessors/time.py +180 -0
- mxalign/align/__init__.py +7 -0
- mxalign/align/nans.py +72 -0
- mxalign/align/space.py +21 -0
- mxalign/align/time.py +62 -0
- mxalign/cli.py +157 -0
- mxalign/interpolations/__init__.py +9 -0
- mxalign/interpolations/base.py +29 -0
- mxalign/interpolations/delaunay.py +218 -0
- mxalign/interpolations/interpolate.py +29 -0
- mxalign/interpolations/registry.py +17 -0
- mxalign/interpolations/xarray.py +63 -0
- mxalign/loaders/__init__.py +11 -0
- mxalign/loaders/anemoi_datasets.py +92 -0
- mxalign/loaders/anemoi_inference.py +103 -0
- mxalign/loaders/base.py +103 -0
- mxalign/loaders/harp_obstable.py +81 -0
- mxalign/loaders/loader.py +8 -0
- mxalign/loaders/registry.py +17 -0
- mxalign/properties/__init__.py +0 -0
- mxalign/properties/properties.py +25 -0
- mxalign/properties/specs.py +54 -0
- mxalign/properties/utils.py +43 -0
- mxalign/properties/validation.py +48 -0
- mxalign/runner.py +167 -0
- mxalign/transformations/__init__.py +7 -0
- mxalign/transformations/base.py +38 -0
- mxalign/transformations/external.py +34 -0
- mxalign/transformations/registry.py +20 -0
- mxalign/transformations/transform.py +28 -0
- mxalign/utils/config.py +55 -0
- mxalign/utils/dates.py +76 -0
- mxalign/utils/projections.py +104 -0
- mxalign/utils/save.py +62 -0
- mxalign/verification.py +57 -0
- mxalign-0.1.0.dist-info/METADATA +136 -0
- mxalign-0.1.0.dist-info/RECORD +43 -0
- mxalign-0.1.0.dist-info/WHEEL +4 -0
- mxalign-0.1.0.dist-info/entry_points.txt +2 -0
- mxalign-0.1.0.dist-info/licenses/LICENSE +21 -0
mxalign/cli.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import sys
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
# Record layout and timestamp format used by every logging.basicConfig call in this module
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

# Module logger used for CLI status and error messages
LOG = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def run_local(args):
    """Run the verification pipeline on a local dask cluster.

    Args:
        args: Parsed CLI namespace; ``n_workers``, ``threads_per_worker`` and
            ``CONFIG`` (path to the YAML configuration file) are read.

    The dask client and cluster are used as context managers, so they are
    shut down on every exit path: normal completion, a failure while building
    the runner, or a failure during the run. A failing run is logged with its
    traceback and the process exits with status 1.
    """
    # Only import the necessary modules if function is called
    # to avoid unnecessary slow imports at the top level
    from dask.distributed import Client, LocalCluster

    from .runner import Runner

    with LocalCluster(
        n_workers=args.n_workers,
        threads_per_worker=args.threads_per_worker,
        processes=True,
    ) as cluster, Client(cluster):
        # The Client only needs to exist: it registers itself as the default
        # scheduler for dask computations triggered by the runner.
        runner = Runner(args.CONFIG)
        try:
            runner.run()
        except Exception:
            LOG.exception("Error during verification closing down dask cluster")
            # SystemExit unwinds through the with-statement, which closes the
            # client first and then the cluster.
            sys.exit(1)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def run_slurm(args):
    """Run the verification pipeline on a dask cluster backed by SLURM jobs.

    Args:
        args: Parsed CLI namespace; ``queue``, ``account``, ``cores``,
            ``memory``, ``interface`` and ``CONFIG`` (path to the YAML
            configuration file) are read.

    Three worker jobs are requested (``cluster.scale(jobs=3)``). The client
    and cluster are closed on every exit path. A failing run is logged with
    its traceback and the process exits with status 1.
    """
    # Only import the necessary modules if function is called
    # to avoid unnecessary slow imports at the top level
    from dask.distributed import Client
    from dask_jobqueue import SLURMCluster

    from .runner import Runner

    # Configure logging before the cluster is started so start-up messages are
    # formatted as well. basicConfig does nothing if the root logger already
    # has handlers (e.g. when configured by the __main__ guard).
    logging.basicConfig(
        level=logging.INFO,
        format=LOG_FORMAT,
        datefmt=DATE_FORMAT,
        handlers=[logging.StreamHandler()],  # Log to console
    )

    with SLURMCluster(
        queue=args.queue,
        account=args.account,
        cores=args.cores,
        memory=args.memory,
        interface=args.interface,
    ) as cluster:
        cluster.scale(jobs=3)
        # The Client registers itself as the default scheduler for the runner.
        with Client(cluster):
            runner = Runner(args.CONFIG)
            try:
                runner.run()
            except Exception:
                LOG.exception("Error during verification closing down dask cluster")
                # SystemExit unwinds through both with-statements, closing the
                # client and then the cluster.
                sys.exit(1)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def main(argv=None):
    """Entry point of the ``mxalign`` command-line interface.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            parse ``sys.argv[1:]``.

    Both sub-commands take the YAML configuration file as their positional
    ``CONFIG`` argument, which may be given before or after the options::

        mxalign local --n_workers 8 config.yaml
        mxalign slurm config.yaml --account my_account
    """
    # The __main__ guard of this module only runs on direct execution; an
    # installed console script (presumably declared in entry_points.txt)
    # calls main() directly, so logging is configured here too. basicConfig
    # is a no-op when the root logger already has handlers.
    logging.basicConfig(
        level=logging.INFO,
        format=LOG_FORMAT,
        datefmt=DATE_FORMAT,
        handlers=[logging.StreamHandler()],  # Log to console
    )

    parser = argparse.ArgumentParser(description="mxalign CLI")
    subparsers = parser.add_subparsers(
        dest="command", required=True, help="Available commands"
    )

    local_parser = subparsers.add_parser(
        "local",
        help="Run the verification pipeline based on a config-file on a local dask cluster",
    )
    local_parser.add_argument(
        "--n_workers", default=4, type=int, help="Number of dask workers"
    )
    local_parser.add_argument(
        "--threads_per_worker",
        default=1,
        type=int,
        help="Number of threads per dask worker",
    )

    slurm_parser = subparsers.add_parser(
        "slurm",
        help="Run the verification pipeline based on a config-file on a slurm cluster",
    )
    slurm_parser.add_argument(
        "--queue", type=str, help="Destination queue for the worker jobs"
    )
    slurm_parser.add_argument(
        "--account", type=str, help="Account to charge the jobs to"
    )
    slurm_parser.add_argument(
        "--cores",
        type=int,
        default=8,
        help="Total number of CPU cores on which all worker threads inside a job will run",
    )
    slurm_parser.add_argument(
        "--memory",
        type=str,
        default="64GB",
        help="Total amount of memory to be used by all workers inside a job",
    )
    slurm_parser.add_argument(
        "--interface",
        type=str,
        default="hsn0",
        help="Network interface to use for the dask workers",
    )

    # CONFIG is declared on each sub-command rather than on the top-level
    # parser. It is then parsed wherever it appears among the sub-command's
    # options and is listed in `mxalign <command> -h`.
    for sub_parser in (local_parser, slurm_parser):
        sub_parser.add_argument(
            "CONFIG", type=str, help="Path to the YAML configuration file"
        )

    args = parser.parse_args(argv)

    # argparse already rejects a missing command (required=True) and unknown
    # commands (fixed sub-parser choices), so a plain lookup suffices.
    commands = {"local": run_local, "slurm": run_slurm}
    commands[args.command](args)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
if __name__ == "__main__":
    # Direct execution: send formatted INFO-and-above records to the console.
    # To also log to a file, add logging.FileHandler("app.log") to the handlers.
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        handlers=[console_handler],
        datefmt=DATE_FORMAT,
        format=LOG_FORMAT,
        level=logging.INFO,
    )

    LOG.info("Starting mxalign CLI")
    main()
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
from ..properties.properties import Space
|
|
3
|
+
from ..properties.utils import update_space_property
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseInterpolator:
    """Base class for all interpolators.

    Subclasses set ``name`` (the registry key), ``source_space`` and
    ``target_space``, and implement :meth:`_interpolate`.

    Attributes:
        target_dataset: Dataset whose locations the source data is
            interpolated to.
        options: Extra keyword options given at construction; their meaning
            is defined by each subclass.
    """

    name: str = "base"
    source_space: Space | None = None
    target_space: Space | None = None

    def __init__(self, target_dataset, **options):
        self.target_dataset = target_dataset
        self.options = options
        # TODO: Check the properties

    def interpolate(
        self, source_dataset: xr.Dataset | xr.DataArray
    ) -> xr.Dataset | xr.DataArray:
        """Interpolate ``source_dataset`` onto the target locations.

        The subclass result is passed through ``update_space_property`` with
        this class's ``target_space``. Presumably this records the new spatial
        layout in the dataset's properties; confirm against
        ``properties.utils``.
        """
        ds_out = self._interpolate(source_dataset)
        return update_space_property(ds_out, self.target_space)

    def _interpolate(
        self, source_dataset: xr.Dataset | xr.DataArray
    ) -> xr.Dataset | xr.DataArray:
        """Perform the actual interpolation; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class. Returning ``None``
                here would only fail later, inside ``update_space_property``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement _interpolate"
        )
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import dask.array as dda
|
|
5
|
+
import xarray as xr
|
|
6
|
+
|
|
7
|
+
from scipy.spatial import Delaunay
|
|
8
|
+
from scipy.sparse import csr_matrix
|
|
9
|
+
|
|
10
|
+
from .base import BaseInterpolator
|
|
11
|
+
from .registry import register_interpolator
|
|
12
|
+
from ..properties.properties import Space
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@register_interpolator
class DelaunayInterpolator(BaseInterpolator):
    """Linear grid-to-point interpolation via a Delaunay triangulation.

    The source (latitude, longitude) points are triangulated with
    ``scipy.spatial.Delaunay``. Each target point is expressed as a barycentric
    combination of the vertices of the triangle containing it, stored as one
    sparse weight matrix that is applied to every variable. Target points
    outside the convex hull of the source points get NaN.

    NOTE(review): the triangulation is done directly in (lat, lon) degree
    space, so longitude wrap-around is presumably not handled; confirm for
    global grids.

    Options:
        method: only ``"linear"`` (the default) is accepted; anything else
            raises ``ValueError``.
    """

    name = "delaunay"
    source_space = Space.GRID
    target_space = Space.POINT

    def __init__(self, target_dataset, **options):
        super().__init__(target_dataset, **options)
        method = self.options.get("method", "linear")
        # Sparse weight matrices, keyed by a fingerprint of the full source grid
        self._W_cache = {}
        if method != "linear":
            raise ValueError(
                f"Method: {method}. Delaunay interpolation only supports linear interpolation"
            )

    def _get_weights(self, source_points, target_points):
        """Return the (n_target, n_source) weight matrix for ``source_points``.

        The triangulation and weights are built once per distinct source grid
        and reused on later calls. The cache key contains a SHA-256 digest of
        the full coordinate array, so different grids that only share their
        shape and corner values cannot collide. Hashing costs one pass over
        the coordinates, which is negligible next to a triangulation.

        ``target_points`` is not part of the key: it is derived from
        ``self.target_dataset``, which is fixed for this instance.
        """
        import hashlib

        points = np.ascontiguousarray(source_points)
        key = (
            points.shape,
            points.dtype.str,
            hashlib.sha256(points.tobytes()).hexdigest(),
        )
        weights = self._W_cache.get(key)
        if weights is None:
            triangulation = Delaunay(source_points)
            weights = _build_weight_matrix(triangulation, source_points, target_points)
            self._W_cache[key] = weights
        return weights

    def _interpolate(self, source_dataset):
        """Interpolate every variable whose last dimension is ``grid_index``.

        Variables that do not end with ``grid_index`` (including scalars) are
        skipped and do not appear in the output. The output carries the
        target latitude/longitude as coordinates and the source's
        ``attrs["properties"]``.

        Raises:
            NotImplementedError: if ``source_dataset`` has no ``grid_index``
                dimension (only stacked grids are supported).
        """
        if "grid_index" not in source_dataset.dims:
            raise NotImplementedError(
                "Delaunay interpolation currently only supports stacked grids"
            )

        if "latitude" in source_dataset.dims:
            # Separate latitude/longitude axes: expand to all (lat, lon) pairs
            lon_grid, lat_grid = np.meshgrid(
                source_dataset["longitude"].values, source_dataset["latitude"].values
            )
            source_points = np.column_stack((lat_grid.ravel(), lon_grid.ravel()))
        else:
            # Presumably latitude/longitude are per-point coordinates on grid_index
            source_points = np.column_stack(
                (source_dataset["latitude"].values, source_dataset["longitude"].values)
            )

        target_points = np.column_stack(
            (
                self.target_dataset["latitude"].values,
                self.target_dataset["longitude"].values,
            )
        )

        # Compute triangulation and sparse weight matrix ONCE, shared across all variables
        W = self._get_weights(source_points, target_points)

        arrays_out = {}
        for var in source_dataset.data_vars:
            da = source_dataset[var]
            # `not da.dims` guards 0-d variables, where dims[-1] would raise
            if not da.dims or da.dims[-1] != "grid_index":
                print(
                    f"Skipping variable '{var}' - doesn't end with spatial dimension grid_index"
                )
                continue
            arrays_out[var] = interpolate_da(da, W, target_points)

        ds_out = xr.Dataset(arrays_out).assign_coords(
            latitude=self.target_dataset["latitude"],
            longitude=self.target_dataset["longitude"],
        )
        ds_out.attrs["properties"] = source_dataset.attrs["properties"]
        return ds_out
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _build_weight_matrix(
|
|
89
|
+
triangulation: Delaunay,
|
|
90
|
+
source_points: np.ndarray,
|
|
91
|
+
target_points: np.ndarray,
|
|
92
|
+
) -> csr_matrix:
|
|
93
|
+
"""
|
|
94
|
+
Precompute a sparse (n_target, n_source) weight matrix from the triangulation.
|
|
95
|
+
|
|
96
|
+
Applying W to a (n_source,) value vector gives (n_target,) interpolated values
|
|
97
|
+
via a simple sparse matrix multiply. Target points outside the convex hull
|
|
98
|
+
receive NaN weights.
|
|
99
|
+
"""
|
|
100
|
+
|
|
101
|
+
print("Calculating interpolation-weight matrix")
|
|
102
|
+
|
|
103
|
+
n_target = len(target_points)
|
|
104
|
+
n_source = len(source_points)
|
|
105
|
+
ndim = source_points.shape[1] # 2 for lat/lon
|
|
106
|
+
|
|
107
|
+
# Find which simplex each target point falls in; -1 means outside convex hull
|
|
108
|
+
simplex_indices = triangulation.find_simplex(target_points) # (n_target,)
|
|
109
|
+
|
|
110
|
+
# Map outside points to simplex 0 temporarily to avoid index errors —
|
|
111
|
+
# their weights will be NaN'd out below
|
|
112
|
+
safe_indices = np.where(simplex_indices >= 0, simplex_indices, 0)
|
|
113
|
+
|
|
114
|
+
# Vertices of each target point's simplex: (n_target, ndim+1)
|
|
115
|
+
simplex_vertices = triangulation.simplices[safe_indices]
|
|
116
|
+
|
|
117
|
+
# Recover barycentric coordinates using the affine transforms stored in
|
|
118
|
+
# triangulation.transform: shape (nsimplex, ndim+1, ndim)
|
|
119
|
+
# transform[s, :ndim, :] — inverse of the edge matrix for simplex s
|
|
120
|
+
# transform[s, ndim, :] — the ndim-th vertex (origin) of simplex s
|
|
121
|
+
Tinv = triangulation.transform[safe_indices, :ndim, :] # (n_target, ndim, ndim)
|
|
122
|
+
origin = triangulation.transform[safe_indices, ndim, :] # (n_target, ndim)
|
|
123
|
+
|
|
124
|
+
r = target_points - origin # (n_target, ndim)
|
|
125
|
+
bary_partial = np.einsum("nij,nj->ni", Tinv, r) # (n_target, ndim)
|
|
126
|
+
last = 1.0 - bary_partial.sum(axis=1, keepdims=True)
|
|
127
|
+
bary = np.concatenate([bary_partial, last], axis=1) # (n_target, ndim+1)
|
|
128
|
+
|
|
129
|
+
# Flatten into coordinate format (COO) for sparse matrix construction
|
|
130
|
+
rows = np.repeat(np.arange(n_target), ndim + 1)
|
|
131
|
+
cols = simplex_vertices.ravel()
|
|
132
|
+
vals = bary.ravel()
|
|
133
|
+
|
|
134
|
+
# NaN out weights for points outside the convex hull
|
|
135
|
+
outside = simplex_indices == -1
|
|
136
|
+
vals[np.repeat(outside, ndim + 1)] = np.nan
|
|
137
|
+
|
|
138
|
+
W = csr_matrix((vals, (rows, cols)), shape=(n_target, n_source))
|
|
139
|
+
|
|
140
|
+
print("Done")
|
|
141
|
+
|
|
142
|
+
return W
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def interpolate_da(
    da: xr.DataArray, W: csr_matrix, target_points: np.ndarray
) -> xr.DataArray:
    """Apply the sparse weights ``W`` along the trailing ``grid_index`` dimension.

    Args:
        da: Array whose last dimension is ``grid_index``; numpy- or dask-backed.
        W: Sparse (n_target, n_source) weight matrix.
        target_points: Target coordinates; only their number is used here.

    Returns:
        DataArray with ``grid_index`` replaced by ``point_index``, built via
        ``map_blocks`` (leading-dimension chunking is preserved, so dask input
        stays lazy).

    Raises:
        ValueError: if ``grid_index`` is split over more than one dask chunk;
            every block needs all source points for the sparse product.
    """
    n_target = len(target_points)
    leading_dims = da.dims[:-1]

    if isinstance(da.data, dda.Array):
        dim_to_chunks = dict(zip(da.dims, da.chunks))
        # Validate that grid_index is not chunked
        grid_chunks = dim_to_chunks.get("grid_index")
        if grid_chunks is not None and len(grid_chunks) > 1:
            raise ValueError(
                "grid_index must not be chunked for Delaunay interpolation "
                f"(found {len(grid_chunks)} chunks). Rechunk with da.chunk({{'grid_index': -1}}) "
                "or enforce this on the loading side."
            )
    else:
        # numpy-backed: the whole array is a single block
        dim_to_chunks = {dim: (da.sizes[dim],) for dim in da.dims}

    # Template for map_blocks: leading dims keep their chunking, grid_index is
    # replaced by a single-chunk point_index of length n_target. The dtype is
    # the one the sparse product actually yields, i.e. the common type of the
    # data and the weights (e.g. float32 data with float64 weights gives
    # float64). Declaring da.dtype would make the lazy dtype disagree with the
    # computed blocks.
    out_dtype = np.result_type(da.dtype, W.dtype)
    shape_tmp = tuple(da.sizes[d] for d in leading_dims) + (n_target,)
    chunks_tmp = tuple(dim_to_chunks[d] for d in leading_dims) + ((n_target,),)
    tmp = xr.DataArray(
        dda.empty(shape=shape_tmp, chunks=chunks_tmp, dtype=out_dtype),
        dims=leading_dims + ("point_index",),
        coords={d: da.coords[d].load() for d in leading_dims},
    )

    # Drop coords tied to grid_index to avoid dimension mismatch in map_blocks
    spatial_coords = [c for c in da.coords if "grid_index" in da[c].dims]
    da_clean = da.drop_vars(spatial_coords)

    return da_clean.map_blocks(
        partial(interpolate_block, W=W, target_points=target_points), template=tmp
    )
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def interpolate_block(
    block: xr.DataArray,
    W: csr_matrix,
    target_points: np.ndarray,
) -> xr.DataArray:
    """Interpolate one in-memory block onto the target points (``map_blocks`` worker).

    All leading dimensions are folded into one, so a single sparse product
    ``rows @ W.T`` interpolates every row at once. The result is then unfolded
    with the trailing dimension renamed to ``point_index``.

    Args:
        block: Block whose trailing dimension holds all source grid points.
        W: Sparse (n_target, n_source) weight matrix.
        target_points: Target coordinates; only their number of rows is used.

    Returns:
        DataArray of shape ``block.shape[:-1] + (n_target,)`` carrying the
        block's leading dimension coordinates.
    """
    values = block.values
    lead_shape = values.shape[:-1]
    rows = values.reshape(-1, values.shape[-1])  # (n_rows, n_source)

    # A NaN source value propagates to every target point whose simplex uses it
    if np.isnan(rows).any():
        print(f"Warning, interpolating NaNs for variable {block.name}")

    # (n_rows, n_source) @ (n_source, n_target) -> (n_rows, n_target)
    result = (rows @ W.T).reshape(lead_shape + (target_points.shape[0],))

    lead_dims = block.dims[:-1]
    return xr.DataArray(
        result,
        dims=lead_dims + ("point_index",),
        coords={dim: block.coords[dim] for dim in lead_dims},
    )
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from .registry import get_interpolation
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def interpolate(source_datasets, target_dataset, method, **kwargs):
    """Interpolate one or several datasets onto the locations of ``target_dataset``.

    Args:
        source_datasets: A single dataset, a list/tuple of datasets, or a dict
            mapping names to datasets.
        target_dataset: Dataset providing the target locations.
        method: Name of a registered interpolator (e.g. ``"delaunay"`` or
            ``"xarray"``).
        **kwargs: Options forwarded to the interpolator's constructor.

    Returns:
        The results in the same container shape as the input:
        - a dict with the same keys for a dict input;
        - a list in the same order for a list/tuple input;
        - a single result for a single dataset.
        An empty dict gives an empty dict.

    Raises:
        ValueError: if ``method`` is not a registered interpolator.

    One interpolator instance serves all inputs, so per-instance caches (such
    as the Delaunay weight cache) are reused across datasets on the same grid.
    """
    interp_cls = get_interpolation(method)
    interpolator = interp_cls(target_dataset, **kwargs)

    def _interpolate_one(ds):
        # Each input is handed over as ds.copy(), not the caller's object
        return interpolator.interpolate(ds.copy())

    if isinstance(source_datasets, dict):
        return {key: _interpolate_one(ds) for key, ds in source_datasets.items()}
    if isinstance(source_datasets, (list, tuple)):
        return [_interpolate_one(ds) for ds in source_datasets]
    return _interpolate_one(source_datasets)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Registered interpolator classes, keyed by their ``name`` class attribute
_INTERPOLATORS = {}


def register_interpolator(cls):
    """Class decorator that registers ``cls`` under ``cls.name``.

    A later registration with the same name replaces the earlier one.
    ``cls`` is returned unchanged, so this can be used as
    ``@register_interpolator``.
    """
    _INTERPOLATORS[cls.name] = cls
    return cls


def available_interpolations():
    """Return the registered interpolator names (dict insertion order)."""
    return list(_INTERPOLATORS)


def get_interpolation(name):
    """Return the interpolator class registered under ``name``.

    Raises:
        ValueError: if nothing is registered under ``name``. The message
            lists the available names. The internal KeyError is suppressed
            so the traceback shows only the relevant error.
    """
    try:
        return _INTERPOLATORS[name]
    except KeyError:
        available = ", ".join(map(str, _INTERPOLATORS)) or "none"
        raise ValueError(
            f"Unknown interpolation: {name}. Available interpolations: {available}"
        ) from None
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from .base import BaseInterpolator
|
|
2
|
+
from .registry import register_interpolator
|
|
3
|
+
from ..properties.properties import Space
|
|
4
|
+
|
|
5
|
+
import xarray as xr
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@register_interpolator
class XarrayInterpolator(BaseInterpolator):
    """Grid-to-point interpolation using ``xarray.Dataset.interp``.

    Two source layouts are handled:

    * ``latitude`` and ``longitude`` are both dimensions: the data is
      interpolated directly at the target latitude/longitude;
    * otherwise the (unstacked) data is interpolated along ``xc``/``yc``.
      The target points are first transformed from ``PlateCarree`` into the
      CRS stored in ``attrs["crs"]`` (presumably a cartopy CRS).

    Construction options are forwarded to ``interp`` as keyword arguments.
    """

    name = "xarray"
    source_space = Space.GRID
    target_space = Space.POINT

    def _interpolate(self, source_dataset):
        if "latitude" in source_dataset.dims and "longitude" in source_dataset.dims:
            return self._interpolate_from_latlon(source_dataset)

        # interp needs xc/yc as separate dimensions, so a stacked dataset is
        # unstacked first
        if source_dataset.space.is_stacked():
            try:
                source_dataset = source_dataset.space.unstack()
            except ValueError as err:
                raise ValueError(
                    "Cannot unstack dataset, dataset must be unstacked to use xarray interpolation"
                ) from err
        return self._interpolate_from_xcyc(source_dataset)

    def _interpolate_from_xcyc(self, source_dataset):
        """Interpolate along projected ``xc``/``yc`` axes at the target points.

        Raises:
            KeyError: if ``source_dataset`` has no ``crs`` attribute.

        NOTE(review): unlike the latitude/longitude path, the target latitude
        and longitude are not attached to the output here; the result carries
        the interpolated ``xc``/``yc`` positions. Confirm downstream code does
        not need them.
        """
        import cartopy.crs as ccrs

        try:
            crs = source_dataset.attrs["crs"]
        except KeyError:
            raise KeyError("Source dataset does not have a crs-attribute") from None

        # Project the target (longitude, latitude) points into the source CRS;
        # columns 0 and 1 of the result are the projected x and y
        xyz = crs.transform_points(
            x=self.target_dataset["longitude"].values,
            y=self.target_dataset["latitude"].values,
            src_crs=ccrs.PlateCarree(),
        )

        # Sharing the point_index dimension makes interp pick one value per
        # point instead of interpolating on the outer product of x and y
        x = xr.DataArray(xyz[:, 0], dims="point_index")
        y = xr.DataArray(xyz[:, 1], dims="point_index")

        return source_dataset.interp(xc=x, yc=y, **self.options)

    def _interpolate_from_latlon(self, source_dataset):
        """Interpolate a latitude/longitude grid directly at the target points."""
        longitude = self.target_dataset["longitude"]
        latitude = self.target_dataset["latitude"]
        return source_dataset.interp(
            longitude=longitude, latitude=latitude, **self.options
        )
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import xarray as xr
|
|
3
|
+
|
|
4
|
+
from .registry import register_loader
|
|
5
|
+
from ..properties.properties import Space, Time, Uncertainty
|
|
6
|
+
from .base import BaseLoader
|
|
7
|
+
|
|
8
|
+
# Labels on the anemoi "variable" axis that _postprocess drops (when present):
# coordinates and computed features such as sin/cos encodings and insolation
DROP_VARS = [
    "latitude",
    "longitude",
    "time",
    "cos_julian_day",
    "cos_latitude",
    "cos_local_time",
    "cos_longitude",
    "insolation",
    "sin_julian_day",
    "sin_latitude",
    "sin_local_time",
    "sin_longitude",
]

# Output coordinate name -> name of the store variable it is read from
COORDS = {"longitude": "longitudes", "latitude": "latitudes", "valid_time": "dates"}

# Default loader options; not referenced in this module (presumably consumed
# by the loader machinery — confirm)
DEFAULTS = {"chunks": "auto"}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@register_loader
class AnemoiDatasetsLoader(BaseLoader):
    """Loader for anemoi-datasets zarr stores.

    ``self.files`` may be a single path or a list of paths; several stores
    are concatenated along ``valid_time``. ``self.files`` and
    ``self.variables`` are presumably set by ``BaseLoader`` — confirm.
    Each store's ``data`` array is reshaped by ``_postprocess``. The result
    is returned as an ``xr.Dataset`` with one data variable per selected
    entry of the anemoi ``variable`` axis.
    """

    name = "anemoi-datasets"

    space = Space.GRID
    time = Time.OBSERVATION
    uncertainty = Uncertainty.DETERMINISTIC

    def _load(self):
        if isinstance(self.files, list):
            dss = [xr.open_zarr(file, consolidated=False) for file in self.files]
            dss_postproc = [_postprocess(ds) for ds in dss]
            ds_postproc = xr.concat(dss_postproc, dim="valid_time")
        else:
            ds = xr.open_zarr(self.files, consolidated=False)
            ds_postproc = _postprocess(ds)

        variables = self.variables
        # A bare string would make .sel() drop the "variable" dimension that
        # to_dataset(dim="variable") below relies on, so wrap it in a list
        if isinstance(variables, str):
            variables = [variables]
        ds_selected = ds_postproc.sel(variable=variables) if variables else ds_postproc

        # Report the number of variables actually being converted
        n_variables = len(ds_selected["variable"])
        if n_variables > 10:
            print(
                f"Transforming anemoi-datasets xr.DataArray with {n_variables} variables to xr.Dataset, this might take some time. Consider selecting the relevant variables during loading"
            )
        return ds_selected.to_dataset(dim="variable")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _postprocess(dataset: xr.Dataset) -> xr.DataArray:
    """Turn a raw anemoi-datasets store into a labelled ``data`` array.

    Steps:
        * attach coordinates: ``longitude``/``latitude`` (as float32, from
          the ``longitudes``/``latitudes`` variables), ``valid_time`` (as
          ``datetime64[ns]``, from ``dates``) and ``variable`` (from the
          store's ``variables`` attribute);
        * keep only the first ensemble member, dropping the ``ensemble`` dim;
        * drop the ``DROP_VARS`` labels present on the ``variable`` axis;
        * make ``valid_time`` the time dimension and rename ``cell`` to
          ``grid_index``.

    Args:
        dataset (xr.Dataset): Dataset opened from an anemoi-datasets zarr store.

    Returns:
        xr.DataArray: The processed ``data`` array. This is not a Dataset;
        the loader converts it with ``to_dataset(dim="variable")``.
    """
    # Add coordinates (loaded eagerly; they are small compared to "data")
    coords = {key: dataset[value].load() for key, value in COORDS.items()}
    for key in ("latitude", "longitude"):
        coords[key] = coords[key].astype(np.float32)
    coords["valid_time"] = coords["valid_time"].astype("datetime64[ns]")
    coords["variable"] = dataset.attrs["variables"]
    ds_coords = dataset.assign_coords(coords)

    # Only drop labels that exist: drop_sel raises for missing labels by default
    drop_vars = [var for var in DROP_VARS if var in coords["variable"]]

    return (
        ds_coords["data"]
        .isel(ensemble=0)
        .drop_sel(variable=drop_vars)
        .swap_dims({"time": "valid_time"})
        .rename({"cell": "grid_index"})
    )
|