mxalign 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. mxalign/__init__.py +36 -0
  2. mxalign/accessors/__init__.py +7 -0
  3. mxalign/accessors/space.py +205 -0
  4. mxalign/accessors/time.py +180 -0
  5. mxalign/align/__init__.py +7 -0
  6. mxalign/align/nans.py +72 -0
  7. mxalign/align/space.py +21 -0
  8. mxalign/align/time.py +62 -0
  9. mxalign/cli.py +157 -0
  10. mxalign/interpolations/__init__.py +9 -0
  11. mxalign/interpolations/base.py +29 -0
  12. mxalign/interpolations/delaunay.py +218 -0
  13. mxalign/interpolations/interpolate.py +29 -0
  14. mxalign/interpolations/registry.py +17 -0
  15. mxalign/interpolations/xarray.py +63 -0
  16. mxalign/loaders/__init__.py +11 -0
  17. mxalign/loaders/anemoi_datasets.py +92 -0
  18. mxalign/loaders/anemoi_inference.py +103 -0
  19. mxalign/loaders/base.py +103 -0
  20. mxalign/loaders/harp_obstable.py +81 -0
  21. mxalign/loaders/loader.py +8 -0
  22. mxalign/loaders/registry.py +17 -0
  23. mxalign/properties/__init__.py +0 -0
  24. mxalign/properties/properties.py +25 -0
  25. mxalign/properties/specs.py +54 -0
  26. mxalign/properties/utils.py +43 -0
  27. mxalign/properties/validation.py +48 -0
  28. mxalign/runner.py +167 -0
  29. mxalign/transformations/__init__.py +7 -0
  30. mxalign/transformations/base.py +38 -0
  31. mxalign/transformations/external.py +34 -0
  32. mxalign/transformations/registry.py +20 -0
  33. mxalign/transformations/transform.py +28 -0
  34. mxalign/utils/config.py +55 -0
  35. mxalign/utils/dates.py +76 -0
  36. mxalign/utils/projections.py +104 -0
  37. mxalign/utils/save.py +62 -0
  38. mxalign/verification.py +57 -0
  39. mxalign-0.1.0.dist-info/METADATA +136 -0
  40. mxalign-0.1.0.dist-info/RECORD +43 -0
  41. mxalign-0.1.0.dist-info/WHEEL +4 -0
  42. mxalign-0.1.0.dist-info/entry_points.txt +2 -0
  43. mxalign-0.1.0.dist-info/licenses/LICENSE +21 -0
mxalign/cli.py ADDED
@@ -0,0 +1,157 @@
1
+ import argparse
2
+ import sys
3
+ import logging
4
+
5
# Record layout and timestamp format shared by every logging.basicConfig
# call in this module.
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"

# Module-level logger; handlers come from logging.basicConfig calls elsewhere
# in this module.
LOG = logging.getLogger(__name__)
9
+
10
+
11
def run_local(args):
    """Run the verification pipeline on a local dask cluster.

    Args:
        args: Parsed CLI namespace providing ``CONFIG`` (path to the YAML
            configuration file), ``n_workers`` and ``threads_per_worker``.

    The dask client and cluster are closed when the pipeline finishes, whether
    it succeeds or fails. On failure the error is logged with its traceback and
    the process exits with status 1.
    """
    # Only import the necessary modules if function is called
    # to avoid unnecessary slow imports at the top level
    from dask.distributed import Client, LocalCluster
    from .runner import Runner

    cluster = LocalCluster(
        n_workers=args.n_workers,
        threads_per_worker=args.threads_per_worker,
        processes=True,
    )
    client = None
    try:
        client = Client(cluster)
        # Runner construction is inside the try so that configuration errors
        # also shut the cluster down and produce exit status 1.
        runner = Runner(args.CONFIG)
        runner.run()
    except Exception:
        LOG.error("Error during verification closing down dask cluster", exc_info=True)
        sys.exit(1)
    finally:
        # Release the client and cluster on every exit path, including the
        # sys.exit above (SystemExit still runs this finally clause).
        if client is not None:
            client.close()
        cluster.close()
32
+
33
+
34
def run_slurm(args):
    """Run the verification pipeline on a dask cluster backed by SLURM jobs.

    Args:
        args: Parsed CLI namespace providing ``CONFIG`` (path to the YAML
            configuration file), ``queue``, ``account``, ``cores``, ``memory``
            and ``interface``.

    The cluster is scaled to three SLURM jobs. The dask client and cluster are
    closed when the pipeline finishes, whether it succeeds or fails. On failure
    the error is logged with its traceback and the process exits with status 1.
    """
    # Only import the necessary modules if function is called
    # to avoid unnecessary slow imports at the top level
    from dask.distributed import Client
    from dask_jobqueue import SLURMCluster
    from .runner import Runner

    # Configure console logging first so everything logged below is shown.
    # basicConfig does nothing if the root logger already has handlers.
    logging.basicConfig(
        level=logging.INFO,
        format=LOG_FORMAT,
        datefmt=DATE_FORMAT,
        handlers=[logging.StreamHandler()],
    )

    cluster = SLURMCluster(
        queue=args.queue,
        account=args.account,
        cores=args.cores,
        memory=args.memory,
        interface=args.interface,
    )
    client = None
    try:
        cluster.scale(jobs=3)
        client = Client(cluster)
        # Runner construction is inside the try so that configuration errors
        # also cancel the SLURM jobs and produce exit status 1.
        runner = Runner(args.CONFIG)
        runner.run()
    except Exception:
        LOG.error("Error during verification closing down dask cluster", exc_info=True)
        sys.exit(1)
    finally:
        # Release the client and cluster on every exit path, including the
        # sys.exit above (SystemExit still runs this finally clause).
        if client is not None:
            client.close()
        cluster.close()
70
+
71
+
72
def main():
    """Entry point of the ``mxalign`` command-line interface.

    Sub-commands:
        ``local``: run the pipeline on a local dask cluster (``run_local``).
        ``slurm``: run the pipeline on a SLURM-backed dask cluster
        (``run_slurm``).

    Each sub-command takes the path to the YAML configuration file as its
    positional ``CONFIG`` argument, e.g. ``mxalign local --n_workers 8 cfg.yaml``.
    """
    # Configure console logging here too, so it is active when main() is
    # called directly rather than through the module's __main__ guard.
    # basicConfig does nothing if the root logger already has handlers.
    logging.basicConfig(
        level=logging.INFO,
        format=LOG_FORMAT,
        datefmt=DATE_FORMAT,
        handlers=[logging.StreamHandler()],
    )

    parser = argparse.ArgumentParser(description="mxalign CLI")
    subparsers = parser.add_subparsers(
        dest="command", required=True, help="Available commands"
    )

    local_parser = subparsers.add_parser(
        "local",
        help="Run the verification pipeline based on a config-file on a local dask cluster",
    )
    local_parser.add_argument(
        "--n_workers", default=4, type=int, help="Number of dask workers"
    )
    local_parser.add_argument(
        "--threads_per_worker",
        default=1,
        type=int,
        help="Number of threads per dask worker",
    )

    slurm_parser = subparsers.add_parser(
        "slurm",
        help="Run the verification pipeline based on a config-file on a slurm cluster",
    )
    slurm_parser.add_argument(
        "--queue", type=str, help="Destination queue for the worker jobs"
    )
    slurm_parser.add_argument(
        "--account", type=str, help="Account to charge the jobs to"
    )
    slurm_parser.add_argument(
        "--cores",
        type=int,
        default=8,
        help="Total number of CPU cores on which all worker threads inside a job will run",
    )
    slurm_parser.add_argument(
        "--memory",
        type=str,
        default="64GB",
        help="Total amount of memory to be used by all workers inside a job",
    )
    slurm_parser.add_argument(
        "--interface",
        type=str,
        default="hsn0",
        help="Network interface to use for the dask workers",
    )

    # CONFIG is declared on each sub-command rather than on the top-level
    # parser. A parent positional declared after the sub-parsers is only
    # matched when it is the final token, and it is missing from `<cmd> -h`.
    for sub_parser in (local_parser, slurm_parser):
        sub_parser.add_argument(
            "CONFIG", type=str, help="Path to the YAML configuration file"
        )

    args = parser.parse_args()

    # required=True guarantees args.command is one of these keys; argparse
    # itself reports unknown commands.
    commands = {"local": run_local, "slurm": run_slurm}
    commands[args.command](args)
143
+
144
+
145
if __name__ == "__main__":
    # Executed as a script: send INFO-and-above records to the console using
    # the module's shared formats, then hand control to the CLI.
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        format=LOG_FORMAT,
        datefmt=DATE_FORMAT,
        level=logging.INFO,
        handlers=[console_handler],
    )
    LOG.info("Starting mxalign CLI")
    main()
mxalign/interpolations/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from . import base
2
+ from . import xarray
3
+ from . import delaunay
4
+
5
# Submodules re-exported as the public API of this package.
__all__ = ["base", "xarray", "delaunay"]
mxalign/interpolations/base.py ADDED
@@ -0,0 +1,29 @@
1
+ import xarray as xr
2
+ from ..properties.properties import Space
3
+ from ..properties.utils import update_space_property
4
+
5
+
6
class BaseInterpolator:
    """Base class for all interpolators.

    Subclasses set ``name`` (the key under which ``register_interpolator``
    stores them), declare the ``source_space`` and ``target_space`` they map
    between, and implement :meth:`_interpolate`.
    """

    name: str = "base"
    source_space: Space | None = None
    target_space: Space | None = None

    def __init__(self, target_dataset, **options):
        """Store the interpolation target and options.

        Args:
            target_dataset: Dataset describing the interpolation target
                (subclasses read its coordinates).
            **options: Interpolator-specific options, kept in ``self.options``.
        """
        self.target_dataset = target_dataset
        self.options = options
        # TODO: Check the properties

    def interpolate(
        self, source_dataset: xr.Dataset | xr.DataArray
    ) -> xr.Dataset | xr.DataArray:
        """Interpolate ``source_dataset`` and tag the result with ``target_space``.

        Args:
            source_dataset: Data to interpolate.

        Returns:
            The output of :meth:`_interpolate`, passed through
            ``update_space_property`` with ``self.target_space``.
        """
        ds_out = self._interpolate(source_dataset)
        return update_space_property(ds_out, self.target_space)

    def _interpolate(
        self, source_dataset: xr.Dataset | xr.DataArray
    ) -> xr.Dataset | xr.DataArray:
        """Perform the actual interpolation; subclasses must override this.

        Raises:
            NotImplementedError: Always, in the base class. This replaces
                silently returning ``None``, which would otherwise be handed
                to ``update_space_property``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _interpolate()"
        )
mxalign/interpolations/delaunay.py ADDED
@@ -0,0 +1,218 @@
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ import dask.array as dda
5
+ import xarray as xr
6
+
7
+ from scipy.spatial import Delaunay
8
+ from scipy.sparse import csr_matrix
9
+
10
+ from .base import BaseInterpolator
11
+ from .registry import register_interpolator
12
+ from ..properties.properties import Space
13
+
14
+
15
@register_interpolator
class DelaunayInterpolator(BaseInterpolator):
    """Linear interpolation from a stacked grid to points via Delaunay triangulation.

    The triangulation of the source points and the resulting sparse weight
    matrix are cached per instance. Interpolating several datasets that share
    the same source and target coordinates therefore reuses the weights.
    """

    name = "delaunay"
    source_space = Space.GRID
    target_space = Space.POINT

    def __init__(self, target_dataset, **options):
        """Create the interpolator.

        Raises:
            ValueError: If ``options["method"]`` is given and is not ``"linear"``.
        """
        super().__init__(target_dataset, **options)
        method = self.options.get("method", "linear")
        self._W_cache = {}  # keyed by digests of source and target coordinates
        if method != "linear":
            raise ValueError(
                f"Method: {method}. Delaunay interpolation only supports linear interpolation"
            )

    @staticmethod
    def _fingerprint(points):
        """Return a hashable key that identifies the exact contents of ``points``.

        The previous key (shape plus two corner values) could collide for
        different grids of the same size and silently reuse wrong weights. A
        content digest avoids that and is cheap compared with a triangulation.
        """
        import hashlib

        points = np.ascontiguousarray(points)
        digest = hashlib.blake2b(points.tobytes(), digest_size=16).hexdigest()
        return points.shape, points.dtype.str, digest

    def _get_weights(self, source_points, target_points):
        """Return the cached weight matrix for these points, building it on first use.

        Args:
            source_points: ``(n_source, 2)`` array of (lat, lon) source nodes.
            target_points: ``(n_target, 2)`` array of (lat, lon) targets.

        Returns:
            Sparse ``(n_target, n_source)`` weight matrix from ``_build_weight_matrix``.
        """
        key = (self._fingerprint(source_points), self._fingerprint(target_points))
        if key not in self._W_cache:
            triangulation = Delaunay(source_points)
            self._W_cache[key] = _build_weight_matrix(
                triangulation, source_points, target_points
            )
        return self._W_cache[key]

    def _interpolate(self, source_dataset):
        """Interpolate every variable whose last dimension is ``grid_index``.

        Args:
            source_dataset: Stacked gridded dataset with a ``grid_index``
                dimension, ``latitude``/``longitude`` coordinates and a
                ``properties`` attribute.

        Returns:
            Dataset of interpolated variables (``grid_index`` replaced by
            ``point_index``), with the target ``latitude``/``longitude`` as
            coordinates and the source ``properties`` attribute copied over.
            Variables not ending in ``grid_index`` are skipped.

        Raises:
            NotImplementedError: If ``source_dataset`` has no ``grid_index`` dimension.
        """
        if "grid_index" not in source_dataset.dims:
            raise NotImplementedError(
                "Delaunay interpolation currently only supports stacked grids"
            )

        if "latitude" in source_dataset.dims:
            # Separate lat/lon axes: enumerate every (lat, lon) pair, lat-major.
            lon_grid, lat_grid = np.meshgrid(
                source_dataset["longitude"].values, source_dataset["latitude"].values
            )
            source_points = np.column_stack((lat_grid.ravel(), lon_grid.ravel()))
        else:
            # Presumably latitude/longitude vary along grid_index here — TODO confirm.
            source_points = np.column_stack(
                (source_dataset["latitude"].values, source_dataset["longitude"].values)
            )

        target_points = np.column_stack(
            (
                self.target_dataset["latitude"].values,
                self.target_dataset["longitude"].values,
            )
        )

        # Compute triangulation and sparse weight matrix ONCE, shared across all variables
        W = self._get_weights(source_points, target_points)

        arrays_out = {}
        for var in source_dataset.data_vars:
            da = source_dataset[var]
            if da.dims[-1] != "grid_index":
                print(
                    f"Skipping variable '{var}' - doesn't end with spatial dimension grid_index"
                )
                continue
            arrays_out[var] = interpolate_da(da, W, target_points)

        ds_out = xr.Dataset(arrays_out).assign_coords(
            latitude=self.target_dataset["latitude"],
            longitude=self.target_dataset["longitude"],
        )
        ds_out.attrs["properties"] = source_dataset.attrs["properties"]
        return ds_out
86
+
87
+
88
def _build_weight_matrix(
    triangulation: Delaunay,
    source_points: np.ndarray,
    target_points: np.ndarray,
) -> csr_matrix:
    """Assemble the sparse linear-interpolation operator for a triangulation.

    Row ``i`` of the returned ``(n_target, n_source)`` matrix holds the
    barycentric coordinates of ``target_points[i]`` with respect to the
    vertices of the simplex containing it. ``W @ values`` therefore
    interpolates a ``(n_source,)`` field onto the targets. Targets outside the
    convex hull of ``source_points`` get NaN weights, and so NaN results.

    Args:
        triangulation: Delaunay triangulation of ``source_points``.
        source_points: ``(n_source, ndim)`` coordinates of the source nodes.
        target_points: ``(n_target, ndim)`` coordinates to interpolate to.

    Returns:
        CSR weight matrix of shape ``(n_target, n_source)``.
    """
    print("Calculating interpolation-weight matrix")

    n_target, n_source = len(target_points), len(source_points)
    ndim = source_points.shape[1]
    n_vertices = ndim + 1

    # find_simplex yields -1 for targets outside the convex hull. Point those
    # at simplex 0 so the lookups below stay in range; they are blanked later.
    containing = triangulation.find_simplex(target_points)
    is_outside = containing < 0
    lookup = np.where(is_outside, 0, containing)

    # transform[s, :ndim, :] is the inverse edge matrix of simplex s and
    # transform[s, ndim, :] its reference vertex.
    inv_edges = triangulation.transform[lookup, :ndim, :]
    reference = triangulation.transform[lookup, ndim, :]

    partial_bary = np.einsum("nij,nj->ni", inv_edges, target_points - reference)
    closing = 1.0 - partial_bary.sum(axis=1, keepdims=True)
    weights = np.concatenate([partial_bary, closing], axis=1)
    weights[is_outside] = np.nan

    row_idx = np.repeat(np.arange(n_target), n_vertices)
    col_idx = triangulation.simplices[lookup].ravel()
    W = csr_matrix(
        (weights.ravel(), (row_idx, col_idx)), shape=(n_target, n_source)
    )

    print("Done")

    return W
143
+
144
+
145
def interpolate_da(
    da: xr.DataArray, W: csr_matrix, target_points: np.ndarray
) -> xr.DataArray:
    """Interpolate a DataArray whose last dimension is ``grid_index`` onto target points.

    The interpolation is a sparse product with the precomputed weights ``W``,
    applied lazily block by block through ``xarray.DataArray.map_blocks``.
    ``grid_index`` is replaced by ``point_index`` in the result.

    Args:
        da: Source data; its last dimension must be ``grid_index``. When
            dask-backed, ``grid_index`` must be a single chunk.
        W: Sparse weights of shape ``(n_target, n_source)``.
        target_points: Target coordinates; only their count is used here.

    Returns:
        Lazily evaluated DataArray with dims ``(*leading_dims, "point_index")``.
        Its dtype is that of ``da`` for float32/float64 inputs, and float64 for
        integer inputs.

    Raises:
        ValueError: If ``grid_index`` is split across several dask chunks.
    """
    n_target = len(target_points)
    leading_dims = da.dims[:-1]

    if isinstance(da.data, dda.Array):
        dim_to_chunks = dict(zip(da.dims, da.chunks))
        # Validate that grid_index is not chunked
        grid_chunks = dim_to_chunks.get("grid_index")
        if grid_chunks is not None and len(grid_chunks) > 1:
            raise ValueError(
                f"grid_index must not be chunked for Delaunay interpolation "
                f"(found {len(grid_chunks)} chunks). Rechunk with da.chunk({{'grid_index': -1}}) "
                f"or enforce this on the loading side."
            )
    else:
        dim_to_chunks = {dim: (da.sizes[dim],) for dim in da.dims}

    # Template: same chunking as `da` on the leading dims, a single chunk on
    # point_index, and the dtype interpolate_block actually returns (the
    # declared dtype previously disagreed with the float64 blocks produced by
    # the sparse product).
    shape_tmp = tuple(da.sizes[d] for d in leading_dims) + (n_target,)
    chunks_tmp = tuple(dim_to_chunks[d] for d in leading_dims) + ((n_target,),)
    tmp = dda.empty(
        shape=shape_tmp, chunks=chunks_tmp, dtype=_interpolated_dtype(da.dtype)
    )
    tmp = xr.DataArray(
        tmp,
        dims=leading_dims + ("point_index",),
        coords={d: da.coords[d].load() for d in leading_dims},
    )

    # Drop coords tied to grid_index to avoid dimension mismatch in map_blocks
    spatial_coords = [c for c in da.coords if "grid_index" in da[c].dims]
    da_clean = da.drop_vars(spatial_coords)

    return da_clean.map_blocks(
        partial(interpolate_block, W=W, target_points=target_points), template=tmp
    )


def interpolate_block(
    block: xr.DataArray,
    W: csr_matrix,
    target_points: np.ndarray,
) -> xr.DataArray:
    """Interpolate one in-memory block: ``(..., n_source)`` -> ``(..., n_target)``.

    Args:
        block: Block whose last dimension is the full ``grid_index`` axis.
        W: Sparse weights of shape ``(n_target, n_source)``.
        target_points: Target coordinates; only their count is used.

    Returns:
        DataArray with ``grid_index`` replaced by ``point_index``, the leading
        coordinates of ``block``, and dtype ``_interpolated_dtype(block.dtype)``.
    """
    data = block.values  # shape = (.., npoints)
    leading_shape = data.shape[:-1]
    # shape = (ndim1 * ndim2 * ... , npoints)
    data_flat = data.reshape(-1, data.shape[-1])

    # NaNs at the vertices of a target's simplex propagate to that target.
    if np.isnan(data_flat).any():
        print(f"Warning, interpolating NaNs for variable {block.name}")

    # Single sparse matrix multiply:
    # (nleading, n_source) @ (n_source, n_target) -> (nleading, n_target)
    interpolated_flat = np.asarray(data_flat @ W.T)
    interpolated = interpolated_flat.reshape(
        *leading_shape, target_points.shape[0]
    ).astype(_interpolated_dtype(data.dtype), copy=False)

    new_dims = block.dims[:-1] + ("point_index",)
    new_coords = {dim: block.coords[dim] for dim in block.dims[:-1]}
    return xr.DataArray(interpolated, dims=new_dims, coords=new_coords)


def _interpolated_dtype(dtype) -> np.dtype:
    """Return the output dtype for interpolating data of ``dtype``.

    Float inputs keep their precision; integer inputs are promoted to a float
    type that can hold NaN for targets outside the convex hull.
    """
    return np.result_type(dtype, np.float32)
mxalign/interpolations/interpolate.py ADDED
@@ -0,0 +1,29 @@
1
+ from .registry import get_interpolation
2
+
3
+
4
def interpolate(source_datasets, target_dataset, method, **kwargs):
    """Interpolate one or several source datasets onto ``target_dataset``.

    Args:
        source_datasets: A single dataset, a list of datasets, or a dict
            mapping names to datasets.
        target_dataset: Interpolation target, passed to the interpolator's
            constructor.
        method: Name of a registered interpolator (see ``get_interpolation``).
        **kwargs: Extra options forwarded to the interpolator's constructor.

    Returns:
        Interpolated data in the same container shape as the input:
        - a dict with the same keys,
        - a list in the same order,
        - or a single dataset.

    Raises:
        ValueError: If ``method`` is not a registered interpolator.
    """
    interp_cls = get_interpolation(method)
    interpolator = interp_cls(target_dataset, **kwargs)

    # Each input is passed as a shallow copy (ds.copy()) so the caller's
    # dataset objects are not modified in place.
    # Previously, a list input left `datasets` unbound and an empty dict
    # returned [].
    if isinstance(source_datasets, dict):
        return {
            key: interpolator.interpolate(ds.copy())
            for key, ds in source_datasets.items()
        }
    if isinstance(source_datasets, list):
        return [interpolator.interpolate(ds.copy()) for ds in source_datasets]
    return interpolator.interpolate(source_datasets.copy())
mxalign/interpolations/registry.py ADDED
@@ -0,0 +1,17 @@
1
# Registered interpolator classes, keyed by their ``name`` attribute.
_INTERPOLATORS = {}


def register_interpolator(cls):
    """Class decorator that registers ``cls`` under ``cls.name``.

    Registering another class with the same name replaces the earlier entry.

    Returns:
        ``cls`` unchanged, so this can be used as a decorator.
    """
    _INTERPOLATORS[cls.name] = cls
    return cls


def available_interpolations():
    """Return the names of all registered interpolators, in registration order."""
    return list(_INTERPOLATORS)


def get_interpolation(name):
    """Look up a registered interpolator class by name.

    Args:
        name: The ``name`` the class was registered under.

    Returns:
        The registered interpolator class.

    Raises:
        ValueError: If nothing is registered under ``name``. The message lists
            the available names, and the internal KeyError is suppressed.
    """
    try:
        return _INTERPOLATORS[name]
    except KeyError:
        available = ", ".join(sorted(map(str, _INTERPOLATORS))) or "none"
        raise ValueError(
            f"Unknown interpolation: {name}. Available interpolations: {available}"
        ) from None
mxalign/interpolations/xarray.py ADDED
@@ -0,0 +1,63 @@
1
+ from .base import BaseInterpolator
2
+ from .registry import register_interpolator
3
+ from ..properties.properties import Space
4
+
5
+ import xarray as xr
6
+
7
+
8
@register_interpolator
class XarrayInterpolator(BaseInterpolator):
    """Interpolate gridded data to points with xarray's ``interp``.

    Datasets with ``latitude`` and ``longitude`` dimensions are interpolated
    directly in those coordinates. Otherwise, the dataset is unstacked if
    needed and interpolated in its projected ``xc``/``yc`` coordinates, using
    the CRS stored in ``source_dataset.attrs["crs"]``. Constructor options are
    forwarded as keyword arguments to ``interp``.
    """

    name = "xarray"
    source_space = Space.GRID
    target_space = Space.POINT

    def _interpolate(self, source_dataset):
        """Dispatch to lat/lon or projected interpolation.

        Raises:
            ValueError: If a stacked dataset cannot be unstacked.
        """
        if "latitude" in source_dataset.dims and "longitude" in source_dataset.dims:
            return self._interpolate_from_latlon(source_dataset)

        if source_dataset.space.is_stacked():
            try:
                source_dataset = source_dataset.space.unstack()
            except ValueError as err:
                raise ValueError(
                    "Cannot unstack dataset, dataset must be unstacked to use xarray interpolation"
                ) from err
        return self._interpolate_from_xcyc(source_dataset)

    def _interpolate_from_xcyc(self, source_dataset):
        """Interpolate in projected ``xc``/``yc`` coordinates.

        The target longitudes/latitudes are projected into the source CRS,
        from PlateCarree, with the CRS object's ``transform_points``.

        Raises:
            KeyError: If ``source_dataset`` has no ``crs`` attribute.
        """
        import cartopy.crs as ccrs

        try:
            crs = source_dataset.attrs["crs"]
        except KeyError:
            raise KeyError("Source dataset does not have a crs-attribute") from None

        xyz = crs.transform_points(
            x=self.target_dataset["longitude"].values,
            y=self.target_dataset["latitude"].values,
            src_crs=ccrs.PlateCarree(),
        )

        x = xr.DataArray(xyz[:, 0], dims="point_index")
        y = xr.DataArray(xyz[:, 1], dims="point_index")

        # NOTE(review): unlike _interpolate_from_latlon, the target latitude /
        # longitude are not attached as coordinates on this path. Confirm
        # whether they should be assigned here.
        return source_dataset.interp(xc=x, yc=y, **self.options)

    def _interpolate_from_latlon(self, source_dataset):
        """Interpolate directly at the target longitudes/latitudes."""
        longitude = self.target_dataset["longitude"]
        latitude = self.target_dataset["latitude"]
        return source_dataset.interp(
            longitude=longitude, latitude=latitude, **self.options
        )
mxalign/loaders/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from . import anemoi_datasets
2
+ from . import anemoi_inference
3
+ from . import harp_obstable
4
+ from . import base
5
+
6
# Loader submodules re-exported as the public API of this package.
__all__ = ["anemoi_datasets", "anemoi_inference", "harp_obstable", "base"]
+ ]
mxalign/loaders/anemoi_datasets.py ADDED
@@ -0,0 +1,92 @@
1
+ import numpy as np
2
+ import xarray as xr
3
+
4
+ from .registry import register_loader
5
+ from ..properties.properties import Space, Time, Uncertainty
6
+ from .base import BaseLoader
7
+
8
# Entries of the anemoi ``variable`` axis that _postprocess removes when
# present. These are coordinates and what appear to be derived forcing fields.
DROP_VARS = [
    "latitude", "longitude", "time",
    "cos_julian_day", "cos_latitude", "cos_local_time", "cos_longitude",
    "insolation",
    "sin_julian_day", "sin_latitude", "sin_local_time", "sin_longitude",
]

# Output coordinate name -> name of the variable it is read from in the store.
COORDS = {"longitude": "longitudes", "latitude": "latitudes", "valid_time": "dates"}

# Default loader options; presumably consumed outside this module (not
# referenced by the visible code here) — TODO confirm.
DEFAULTS = {"chunks": "auto"}
26
+
27
+
28
@register_loader
class AnemoiDatasetsLoader(BaseLoader):
    """Load anemoi-datasets zarr stores as a gridded, deterministic ``xr.Dataset``.

    ``self.files`` is either a single store or a list/tuple of stores; several
    stores are concatenated along ``valid_time``. Each entry of the anemoi
    ``variable`` axis becomes its own data variable, optionally restricted to
    ``self.variables``.
    """

    name = "anemoi-datasets"

    space = Space.GRID
    time = Time.OBSERVATION
    uncertainty = Uncertainty.DETERMINISTIC

    def _load(self):
        """Open, post-process and convert the store(s) to an ``xr.Dataset``."""
        if isinstance(self.files, (list, tuple)):
            dss_postproc = [
                _postprocess(xr.open_zarr(file, consolidated=False))
                for file in self.files
            ]
            ds_postproc = xr.concat(dss_postproc, dim="valid_time")
        else:
            ds_postproc = _postprocess(xr.open_zarr(self.files, consolidated=False))

        ds_selected = (
            ds_postproc.sel(variable=self.variables) if self.variables else ds_postproc
        )

        # Report the number of variables actually being converted (the
        # selection), not the total in the store.
        n_variables = len(ds_selected["variable"])
        if n_variables > 10:
            print(
                f"Transforming anemoi-datasets xr.DataArray with {n_variables} variables to xr.Dataset, this might take some time. Consider selecting the relevant variables during loading"
            )
        return ds_selected.to_dataset(dim="variable")
55
+
56
+
57
def _postprocess(dataset: xr.Dataset) -> xr.DataArray:
    """Attach coordinates to a raw anemoi-datasets store and prune it.

    Args:
        dataset: Dataset opened from an anemoi-datasets zarr store. It must
            contain:
            - the variables named in ``COORDS``,
            - a ``data`` variable with ``time``, ``variable``, ``ensemble``
              and ``cell`` dimensions,
            - a ``variables`` attribute labelling the ``variable`` axis.

    Returns:
        The ``data`` array (a DataArray, not a Dataset), with:
        - its first ensemble member selected,
        - any ``DROP_VARS`` entries removed from ``variable``,
        - ``valid_time`` as the time dimension,
        - ``cell`` renamed to ``grid_index``.
    """
    # Load the coordinate variables into memory so they can act as indexes.
    coords = {key: dataset[value].load() for key, value in COORDS.items()}
    coords["valid_time"] = coords["valid_time"].astype("datetime64[ns]")
    for key in ("latitude", "longitude"):
        coords[key] = coords[key].astype(np.float32)
    coords["variable"] = dataset.attrs["variables"]
    ds_coords = dataset.assign_coords(coords)

    # drop_sel raises on labels that are absent, so only drop those present.
    drop_vars = [var for var in DROP_VARS if var in coords["variable"]]

    return (
        ds_coords["data"]
        .isel(ensemble=0)
        .drop_sel(variable=drop_vars)
        .swap_dims({"time": "valid_time"})
        .rename({"cell": "grid_index"})
    )