anemoi-datasets 0.5.6__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/__init__.py +11 -3
- anemoi/datasets/__main__.py +2 -3
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/__init__.py +2 -3
- anemoi/datasets/commands/cleanup.py +9 -0
- anemoi/datasets/commands/compare.py +3 -3
- anemoi/datasets/commands/copy.py +38 -68
- anemoi/datasets/commands/create.py +20 -5
- anemoi/datasets/commands/finalise-additions.py +9 -0
- anemoi/datasets/commands/finalise.py +9 -0
- anemoi/datasets/commands/init-additions.py +9 -0
- anemoi/datasets/commands/init.py +9 -0
- anemoi/datasets/commands/inspect.py +7 -1
- anemoi/datasets/commands/load-additions.py +9 -0
- anemoi/datasets/commands/load.py +9 -0
- anemoi/datasets/commands/patch.py +9 -0
- anemoi/datasets/commands/publish.py +9 -0
- anemoi/datasets/commands/scan.py +9 -0
- anemoi/datasets/compute/__init__.py +8 -0
- anemoi/datasets/compute/recentre.py +3 -2
- anemoi/datasets/create/__init__.py +64 -48
- anemoi/datasets/create/check.py +4 -3
- anemoi/datasets/create/chunks.py +3 -2
- anemoi/datasets/create/config.py +5 -5
- anemoi/datasets/create/functions/__init__.py +22 -7
- anemoi/datasets/create/functions/filters/__init__.py +2 -1
- anemoi/datasets/create/functions/filters/empty.py +3 -2
- anemoi/datasets/create/functions/filters/noop.py +2 -2
- anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +3 -2
- anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +3 -2
- anemoi/datasets/create/functions/filters/rename.py +16 -10
- anemoi/datasets/create/functions/filters/rotate_winds.py +3 -2
- anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +3 -2
- anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +3 -2
- anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +2 -2
- anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +2 -2
- anemoi/datasets/create/functions/filters/speeddir_to_uv.py +3 -2
- anemoi/datasets/create/functions/filters/unrotate_winds.py +3 -2
- anemoi/datasets/create/functions/filters/uv_to_speeddir.py +3 -2
- anemoi/datasets/create/functions/sources/__init__.py +2 -2
- anemoi/datasets/create/functions/sources/accumulations.py +10 -4
- anemoi/datasets/create/functions/sources/constants.py +3 -2
- anemoi/datasets/create/functions/sources/empty.py +3 -2
- anemoi/datasets/create/functions/sources/forcings.py +3 -2
- anemoi/datasets/create/functions/sources/grib.py +2 -2
- anemoi/datasets/create/functions/sources/hindcasts.py +3 -2
- anemoi/datasets/create/functions/sources/mars.py +97 -17
- anemoi/datasets/create/functions/sources/netcdf.py +3 -2
- anemoi/datasets/create/functions/sources/opendap.py +2 -2
- anemoi/datasets/create/functions/sources/recentre.py +3 -2
- anemoi/datasets/create/functions/sources/source.py +3 -2
- anemoi/datasets/create/functions/sources/tendencies.py +3 -2
- anemoi/datasets/create/functions/sources/xarray/__init__.py +8 -2
- anemoi/datasets/create/functions/sources/xarray/coordinates.py +5 -2
- anemoi/datasets/create/functions/sources/xarray/field.py +3 -2
- anemoi/datasets/create/functions/sources/xarray/fieldlist.py +12 -2
- anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -16
- anemoi/datasets/create/functions/sources/xarray/grid.py +3 -2
- anemoi/datasets/create/functions/sources/xarray/metadata.py +3 -2
- anemoi/datasets/create/functions/sources/xarray/time.py +39 -4
- anemoi/datasets/create/functions/sources/xarray/variable.py +6 -6
- anemoi/datasets/create/functions/sources/xarray_kerchunk.py +2 -2
- anemoi/datasets/create/functions/sources/xarray_zarr.py +2 -2
- anemoi/datasets/create/functions/sources/zenodo.py +2 -2
- anemoi/datasets/create/input/__init__.py +3 -17
- anemoi/datasets/create/input/action.py +3 -2
- anemoi/datasets/create/input/concat.py +3 -2
- anemoi/datasets/create/input/context.py +3 -2
- anemoi/datasets/create/input/data_sources.py +3 -2
- anemoi/datasets/create/input/empty.py +3 -2
- anemoi/datasets/create/input/filter.py +3 -2
- anemoi/datasets/create/input/function.py +3 -2
- anemoi/datasets/create/input/join.py +3 -2
- anemoi/datasets/create/input/misc.py +3 -2
- anemoi/datasets/create/input/pipe.py +3 -2
- anemoi/datasets/create/input/repeated_dates.py +3 -2
- anemoi/datasets/create/input/result.py +187 -3
- anemoi/datasets/create/input/step.py +4 -2
- anemoi/datasets/create/input/template.py +3 -2
- anemoi/datasets/create/input/trace.py +3 -2
- anemoi/datasets/create/patch.py +9 -1
- anemoi/datasets/create/persistent.py +7 -3
- anemoi/datasets/create/size.py +3 -2
- anemoi/datasets/create/statistics/__init__.py +7 -3
- anemoi/datasets/create/statistics/summary.py +3 -2
- anemoi/datasets/create/utils.py +15 -2
- anemoi/datasets/create/writer.py +3 -2
- anemoi/datasets/create/zarr.py +8 -3
- anemoi/datasets/data/__init__.py +27 -1
- anemoi/datasets/data/concat.py +5 -1
- anemoi/datasets/data/dataset.py +216 -37
- anemoi/datasets/data/debug.py +4 -1
- anemoi/datasets/data/ensemble.py +4 -1
- anemoi/datasets/data/fill_missing.py +165 -0
- anemoi/datasets/data/forwards.py +27 -2
- anemoi/datasets/data/grids.py +236 -58
- anemoi/datasets/data/indexing.py +4 -1
- anemoi/datasets/data/interpolate.py +4 -1
- anemoi/datasets/data/join.py +17 -1
- anemoi/datasets/data/masked.py +36 -10
- anemoi/datasets/data/merge.py +180 -0
- anemoi/datasets/data/misc.py +18 -3
- anemoi/datasets/data/missing.py +4 -1
- anemoi/datasets/data/rescale.py +4 -1
- anemoi/datasets/data/select.py +15 -1
- anemoi/datasets/data/statistics.py +4 -1
- anemoi/datasets/data/stores.py +70 -3
- anemoi/datasets/data/subset.py +6 -1
- anemoi/datasets/data/unchecked.py +9 -1
- anemoi/datasets/data/xy.py +20 -5
- anemoi/datasets/dates/__init__.py +9 -7
- anemoi/datasets/dates/groups.py +3 -1
- anemoi/datasets/fields.py +3 -1
- anemoi/datasets/grids.py +86 -2
- anemoi/datasets/testing.py +60 -0
- anemoi/datasets/utils/__init__.py +8 -0
- anemoi/datasets/utils/fields.py +2 -2
- {anemoi_datasets-0.5.6.dist-info → anemoi_datasets-0.5.10.dist-info}/METADATA +11 -29
- anemoi_datasets-0.5.10.dist-info/RECORD +124 -0
- {anemoi_datasets-0.5.6.dist-info → anemoi_datasets-0.5.10.dist-info}/WHEEL +1 -1
- anemoi_datasets-0.5.6.dist-info/RECORD +0 -121
- {anemoi_datasets-0.5.6.dist-info → anemoi_datasets-0.5.10.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.5.6.dist-info → anemoi_datasets-0.5.10.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.6.dist-info → anemoi_datasets-0.5.10.dist-info}/top_level.txt +0 -0
anemoi/datasets/data/grids.py
CHANGED

```diff
@@ -1,14 +1,18 @@
-# (C) Copyright 2024
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+
 import logging
 from functools import cached_property
 
 import numpy as np
+from scipy.spatial import cKDTree
 
 from .debug import Node
 from .debug import debug_indexing
@@ -105,6 +109,17 @@ class GridsBase(GivenAxis):
         # We don't check the resolution, because we want to be able to combine
         pass
 
+    def metadata_specific(self):
+        return super().metadata_specific(
+            multi_grids=True,
+        )
+
+    def collect_input_sources(self, collected):
+        # We assume that,because they have different grids, they have different input sources
+        for d in self.datasets:
+            collected.append(d)
+            d.collect_input_sources(collected)
+
 
 class Grids(GridsBase):
     # TODO: select the statistics of the most global grid?
@@ -128,85 +143,248 @@ class Grids(GridsBase):
 
 
 class Cutout(GridsBase):
-    def __init__(self, datasets, axis
-
-
+    def __init__(self, datasets, axis=3, cropping_distance=2.0, neighbours=5, min_distance_km=None, plot=None):
+        """Initializes a Cutout object for hierarchical management of Limited Area
+        Models (LAMs) and a global dataset, handling overlapping regions.
+
+        Args:
+            datasets (list): List of LAM and global datasets.
+            axis (int): Concatenation axis, must be set to 3.
+            cropping_distance (float): Distance threshold in degrees for
+                cropping cutouts.
+            neighbours (int): Number of neighboring points to consider when
+                constructing masks.
+            min_distance_km (float, optional): Minimum distance threshold in km
+                between grid points.
+            plot (bool, optional): Flag to enable or disable visualization
+                plots.
+        """
         super().__init__(datasets, axis)
-        assert len(datasets)
+        assert len(datasets) >= 2, "CutoutGrids requires at least two datasets"
         assert axis == 3, "CutoutGrids requires axis=3"
+        assert cropping_distance >= 0, "cropping_distance must be a non-negative number"
+        if min_distance_km is not None:
+            assert min_distance_km >= 0, "min_distance_km must be a non-negative number"
+
+        self.lams = datasets[:-1]  # Assume the last dataset is the global one
+        self.globe = datasets[-1]
+        self.axis = axis
+        self.cropping_distance = cropping_distance
+        self.neighbours = neighbours
+        self.min_distance_km = min_distance_km
+        self.plot = plot
+        self.masks = []  # To store the masks for each LAM dataset
+        self.global_mask = np.ones(self.globe.shape[-1], dtype=bool)
+
+        # Initialize cumulative masks
+        self._initialize_masks()
+
+    def _initialize_masks(self):
+        """Generates hierarchical masks for each LAM dataset by excluding
+        overlapping regions with previous LAMs and creating a global mask for
+        the global dataset.
+
+        Raises:
+            ValueError: If the global mask dimension does not match the global
+                dataset grid points.
+        """
+        from anemoi.datasets.grids import cutout_mask
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        for i, lam in enumerate(self.lams):
+            assert len(lam.shape) == len(
+                self.globe.shape
+            ), "LAMs and global dataset must have the same number of dimensions"
+            lam_lats = lam.latitudes
+            lam_lons = lam.longitudes
+            # Create a mask for the global dataset excluding all LAM points
+            global_overlap_mask = cutout_mask(
+                lam.latitudes,
+                lam.longitudes,
+                self.globe.latitudes,
+                self.globe.longitudes,
+                plot=False,
+                min_distance_km=self.min_distance_km,
+                cropping_distance=self.cropping_distance,
+                neighbours=self.neighbours,
+            )
+
+            # Ensure the mask dimensions match the global grid points
+            if global_overlap_mask.shape[0] != self.globe.shape[-1]:
+                raise ValueError("Global mask dimension does not match global dataset grid " "points.")
+            self.global_mask[~global_overlap_mask] = False
+
+            # Create a mask for the LAM datasets hierarchically, excluding
+            # points from previous LAMs
+            lam_current_mask = np.ones(lam.shape[-1], dtype=bool)
+            if i > 0:
+                for j in range(i):
+                    prev_lam = self.lams[j]
+                    prev_lam_lats = prev_lam.latitudes
+                    prev_lam_lons = prev_lam.longitudes
+                    # Check for overlap by computing distances
+                    if self.has_overlap(prev_lam_lats, prev_lam_lons, lam_lats, lam_lons):
+                        lam_overlap_mask = cutout_mask(
+                            prev_lam_lats,
+                            prev_lam_lons,
+                            lam_lats,
+                            lam_lons,
+                            plot=False,
+                            min_distance_km=self.min_distance_km,
+                            cropping_distance=self.cropping_distance,
+                            neighbours=self.neighbours,
+                        )
+                        lam_current_mask[~lam_overlap_mask] = False
+            self.masks.append(lam_current_mask)
+
+    def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0):
+        """Checks for overlapping points between two sets of latitudes and
+        longitudes within a specified distance threshold.
+
+        Args:
+            lats1, lons1 (np.ndarray): Latitude and longitude arrays for the
+                first dataset.
+            lats2, lons2 (np.ndarray): Latitude and longitude arrays for the
+                second dataset.
+            distance_threshold (float): Distance in degrees to consider as
+                overlapping.
+
+        Returns:
+            bool: True if any points overlap within the distance threshold,
+                otherwise False.
+        """
+        # Create KDTree for the first set of points
+        tree = cKDTree(np.vstack((lats1, lons1)).T)
+
+        # Query the second set of points against the first tree
+        distances, _ = tree.query(np.vstack((lats2, lons2)).T, k=1)
+
+        # Check if any distance is less than the specified threshold
+        return np.any(distances < distance_threshold)
+
+    def __getitem__(self, index):
+        """Retrieves data from the masked LAMs and global dataset based on the
+        given index.
+
+        Args:
+            index (int or slice or tuple): Index specifying the data to
+                retrieve.
+
+        Returns:
+            np.ndarray: Data array from the masked datasets based on the index.
+        """
+        if isinstance(index, (int, slice)):
+            index = (index, slice(None), slice(None), slice(None))
+        return self._get_tuple(index)
+
+    def _get_tuple(self, index):
+        """Helper method that applies masks and retrieves data from each dataset
+        according to the specified index.
+
+        Args:
+            index (tuple): Index specifying slices to retrieve data.
+
+        Returns:
+            np.ndarray: Concatenated data array from all datasets based on the
+                index.
+        """
+        index, changes = index_to_slices(index, self.shape)
+        # Select data from each LAM
+        lam_data = [lam[index] for lam in self.lams]
+
+        # First apply spatial indexing on `self.globe` and then apply the mask
+        globe_data_sliced = self.globe[index[:3]]
+        globe_data = globe_data_sliced[..., self.global_mask]
+
+        # Concatenate LAM data with global data
+        result = np.concatenate(lam_data + [globe_data], axis=self.axis)
+        return apply_index_to_slices_changes(result, changes)
+
+    def collect_supporting_arrays(self, collected, *path):
+        """Collects supporting arrays, including masks for each LAM and the global
+        dataset.
+
+        Args:
+            collected (list): List to which the supporting arrays are appended.
+            *path: Variable length argument list specifying the paths for the masks.
+        """
+        # Append masks for each LAM
+        for i, (lam, mask) in enumerate(zip(self.lams, self.masks)):
+            collected.append((path + (f"lam_{i}",), "cutout_mask", mask))
+
+        # Append the global mask
+        collected.append((path + ("global",), "cutout_mask", self.global_mask))
 
     @cached_property
     def shape(self):
-        shape
-
-
-
+        """Returns the shape of the Cutout, accounting for retained grid points
+        across all LAMs and the global dataset.
+
+        Returns:
+            tuple: Shape of the concatenated masked datasets.
+        """
+        shapes = [np.sum(mask) for mask in self.masks]
+        global_shape = np.sum(self.global_mask)
+        return tuple(self.lams[0].shape[:-1] + (sum(shapes) + global_shape,))
 
     def check_same_resolution(self, d1, d2):
         # Turned off because we are combining different resolutions
         pass
 
     @property
-    def
-
+    def grids(self):
+        """Returns the number of grid points for each LAM and the global dataset
+        after applying masks.
 
-
-
-
+        Returns:
+            tuple: Count of retained grid points for each dataset.
+        """
+        grids = [np.sum(mask) for mask in self.masks]
+        grids.append(np.sum(self.global_mask))
+        return tuple(grids)
 
-
-
-
-
+    @property
+    def latitudes(self):
+        """Returns the concatenated latitudes of each LAM and the global dataset
+        after applying masks.
 
-
-
-
-
-            None
-        ), f"No support for selecting a subset of the 1D values {index} ({self.tree()})"
-        index, changes = index_to_slices(index, self.shape)
+        Returns:
+            np.ndarray: Concatenated latitude array for the masked datasets.
+        """
+        lam_latitudes = np.concatenate([lam.latitudes[mask] for lam, mask in zip(self.lams, self.masks)])
 
-
-
+        assert (
+            len(lam_latitudes) + len(self.globe.latitudes[self.global_mask]) == self.shape[-1]
+        ), "Mismatch in number of latitudes"
 
-
-
+        latitudes = np.concatenate([lam_latitudes, self.globe.latitudes[self.global_mask]])
+        return latitudes
 
-
+    @property
+    def longitudes(self):
+        """Returns the concatenated longitudes of each LAM and the global dataset
+        after applying masks.
 
-
+        Returns:
+            np.ndarray: Concatenated longitude array for the masked datasets.
+        """
+        lam_longitudes = np.concatenate([lam.longitudes[mask] for lam, mask in zip(self.lams, self.masks)])
 
-
+        assert (
+            len(lam_longitudes) + len(self.globe.longitudes[self.global_mask]) == self.shape[-1]
+        ), "Mismatch in number of longitudes"
 
-
-
-        for d in self.datasets:
-            if len(d.grids) > 1:
-                raise NotImplementedError("CutoutGrids does not support multi-grids datasets as inputs")
-        shape = self.lam.shape
-        return (shape[-1], self.shape[-1] - shape[-1])
+        longitudes = np.concatenate([lam_longitudes, self.globe.longitudes[self.global_mask]])
+        return longitudes
 
     def tree(self):
+        """Generates a hierarchical tree structure for the `Cutout` instance and
+        its associated datasets.
+
+        Returns:
+            Node: A `Node` object representing the `Cutout` instance as the root
+            node, with each dataset in `self.datasets` represented as a child
+            node.
+        """
         return Node(self, [d.tree() for d in self.datasets])
 
 
@@ -238,7 +416,7 @@ def cutout_factory(args, kwargs):
     neighbours = kwargs.pop("neighbours", 5)
 
     assert len(args) == 0
-    assert isinstance(cutout, (list, tuple))
+    assert isinstance(cutout, (list, tuple)), "cutout must be a list or tuple"
 
     datasets = [_open(e) for e in cutout]
    datasets, kwargs = _auto_adjust(datasets, kwargs)
```
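The rewritten `Cutout` now stacks any number of LAMs over one global dataset, masking each LAM against the ones before it and the global grid against all of them. A minimal usage sketch, assuming hypothetical local zarr paths and driving it through `open_dataset` as `cutout_factory` expects; the keyword defaults mirror the `__init__` signature above:

```python
# A sketch of the multi-LAM cutout; the paths are invented for illustration.
# The last entry in `cutout` is treated as the global dataset.
from anemoi.datasets import open_dataset

ds = open_dataset(
    cutout=["lam-north.zarr", "lam-south.zarr", "global.zarr"],  # hypothetical paths
    min_distance_km=1,       # drop global points closer than 1 km to a LAM point
    cropping_distance=2.0,   # distance in degrees used when cropping cutouts
    neighbours=5,            # neighbouring points considered when building masks
)

# Grid points are ordered LAM by LAM, with the masked global points last,
# matching the `grids` property above.
print(ds.grids)   # e.g. (points_lam0, points_lam1, points_global)
print(ds.shape)   # (dates, variables, ensembles, sum of the counts above)
```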
anemoi/datasets/data/indexing.py
CHANGED

```diff
@@ -1,10 +1,13 @@
-# (C) Copyright 2024
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+
 from functools import wraps
 
 import numpy as np
```
anemoi/datasets/data/join.py
CHANGED

```diff
@@ -1,10 +1,13 @@
-# (C) Copyright 2024
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+
 import logging
 from functools import cached_property
 
@@ -111,6 +114,19 @@ class Join(Combined):
 
         return result
 
+    @property
+    def variables_metadata(self):
+        result = {}
+        variables = [v for v in self.variables if not (v.startswith("(") and v.endswith(")"))]
+        for d in self.datasets:
+            md = d.variables_metadata
+            for v in variables:
+                if v in md:
+                    result[v] = md[v]
+
+        assert len(result) == len(variables), (result, variables)
+        return result
+
     @cached_property
     def name_to_index(self):
         return {k: i for i, k in enumerate(self.variables)}
```
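The new `Join.variables_metadata` skips variables whose names are wrapped in parentheses (the convention the join appears to use for shadowed duplicates) and takes each remaining variable's metadata from whichever joined dataset provides it. A standalone sketch of that merging rule; the helper name and sample metadata are invented for illustration:

```python
# Mirror of the filtering/merging logic in the property above, written as a
# free function so it can be run without anemoi-datasets installed.
def join_variables_metadata(variables, per_dataset_metadata):
    # Parenthesised names mark overridden duplicates and are excluded
    wanted = [v for v in variables if not (v.startswith("(") and v.endswith(")"))]
    result = {}
    for md in per_dataset_metadata:
        for v in wanted:
            if v in md:
                result[v] = md[v]  # later datasets override earlier ones
    assert len(result) == len(wanted), (result, wanted)
    return result

# Hypothetical example: the second dataset overrides "2t", so the first
# copy was renamed "(2t)" and is skipped.
print(join_variables_metadata(
    ["10u", "(2t)", "2t"],
    [{"10u": {"units": "m s-1"}, "2t": {"units": "K"}}, {"2t": {"units": "K"}}],
))
```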
anemoi/datasets/data/masked.py
CHANGED

```diff
@@ -1,10 +1,13 @@
-# (C) Copyright 2024
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+
 import logging
 from functools import cached_property
 
@@ -30,6 +33,8 @@ class Masked(Forwards):
         self.mask = mask
         self.axis = 3
 
+        self.mask_name = f"{self.__class__.__name__.lower()}_mask"
+
     @cached_property
     def shape(self):
         return self.forward.shape[:-1] + (np.count_nonzero(self.mask),)
@@ -64,26 +69,46 @@ class Masked(Forwards):
         result = apply_index_to_slices_changes(result, changes)
         return result
 
+    def collect_supporting_arrays(self, collected, *path):
+        super().collect_supporting_arrays(collected, *path)
+        collected.append((path, self.mask_name, self.mask))
+
 
 class Thinning(Masked):
+
     def __init__(self, forward, thinning, method):
         self.thinning = thinning
         self.method = method
 
-
-
-
+        if thinning is not None:
+
+            shape = forward.field_shape
+            if len(shape) != 2:
+                raise ValueError("Thinning only works latitude/longitude fields")
 
-
-
-
-        longitudes = longitudes[::thinning, ::thinning].flatten()
+            # Make a copy, so we read the data only once from zarr
+            forward_latitudes = forward.latitudes.copy()
+            forward_longitudes = forward.longitudes.copy()
 
-
-
+            latitudes = forward_latitudes.reshape(shape)
+            longitudes = forward_longitudes.reshape(shape)
+            latitudes = latitudes[::thinning, ::thinning].flatten()
+            longitudes = longitudes[::thinning, ::thinning].flatten()
+
+            # TODO: This is not very efficient
+
+            mask = [lat in latitudes and lon in longitudes for lat, lon in zip(forward_latitudes, forward_longitudes)]
+            mask = np.array(mask, dtype=bool)
+        else:
+            mask = None
 
         super().__init__(forward, mask)
 
+    def mutate(self) -> Dataset:
+        if self.thinning is None:
+            return self.forward.mutate()
+        return super().mutate()
+
     def tree(self):
         return Node(self, [self.forward.tree()], thinning=self.thinning, method=self.method)
 
@@ -92,6 +117,7 @@ class Thinning(Masked):
 
 
 class Cropping(Masked):
+
     def __init__(self, forward, area):
         from ..data import open_dataset
 
```
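With this change `thinning=None` becomes an explicit no-op: no mask is built, and `Thinning.mutate()` hands back the forward dataset unchanged. A minimal sketch, assuming a hypothetical local `dataset.zarr` whose field shape is a 2D latitude/longitude grid:

```python
# Sketch of the thinning option, assuming "dataset.zarr" exists locally.
from anemoi.datasets import open_dataset

full = open_dataset("dataset.zarr", thinning=None)  # no-op, same grid as the source
thin = open_dataset("dataset.zarr", thinning=4)     # keep every 4th row and column

# The thinned dataset has fewer grid points on the last axis.
print(full.shape[-1], thin.shape[-1])
```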
anemoi/datasets/data/merge.py
ADDED

```diff
@@ -0,0 +1,180 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+from functools import cached_property
+
+import numpy as np
+
+from . import MissingDateError
+from .debug import Node
+from .debug import debug_indexing
+from .forwards import Combined
+from .indexing import apply_index_to_slices_changes
+from .indexing import expand_list_indexing
+from .indexing import index_to_slices
+from .indexing import update_tuple
+from .misc import _auto_adjust
+from .misc import _open
+
+LOG = logging.getLogger(__name__)
+
+
+class Merge(Combined):
+
+    # d0 d2 d4 d6 ...
+    # d1 d3 d5 d7 ...
+
+    # gives
+    # d0 d1 d2 d3 ...
+
+    def __init__(self, datasets, allow_gaps_in_dates=False):
+        super().__init__(datasets)
+
+        self.allow_gaps_in_dates = allow_gaps_in_dates
+
+        dates = dict()  # date -> (dataset_index, date_index)
+
+        for i, d in enumerate(datasets):
+            for j, date in enumerate(d.dates):
+                date = date.astype(object)
+                if date in dates:
+
+                    d1 = datasets[dates[date][0]]  # Selected
+                    d2 = datasets[i]  # The new one
+
+                    if j in d2.missing:
+                        # LOG.warning(f"Duplicate date {date} found in datasets {d1} and {d2}, but {date} is missing in {d}, ignoring")
+                        continue
+
+                    k = dates[date][1]
+                    if k in d1.missing:
+                        # LOG.warning(f"Duplicate date {date} found in datasets {d1} and {d2}, but {date} is missing in {d}, ignoring")
+                        dates[date] = (i, j)  # Replace the missing date with the new one
+                        continue
+
+                    raise ValueError(f"Duplicate date {date} found in datasets {d1} and {d2}")
+                else:
+                    dates[date] = (i, j)
+
+        all_dates = sorted(dates)
+        start = all_dates[0]
+        end = all_dates[-1]
+
+        frequency = min(d2 - d1 for d1, d2 in zip(all_dates[:-1], all_dates[1:]))
+
+        date = start
+        indices = []
+        _dates = []
+
+        self._missing_index = len(datasets)
+
+        while date <= end:
+            if date not in dates:
+                if self.allow_gaps_in_dates:
+                    dates[date] = (self._missing_index, -1)
+                else:
+                    raise ValueError(
+                        f"merge: date {date} not covered by dataset. Start={start}, end={end}, frequency={frequency}"
+                    )
+
+            indices.append(dates[date])
+            _dates.append(date)
+            date += frequency
+
+        self._dates = np.array(_dates, dtype="datetime64[s]")
+        self._indices = np.array(indices)
+        self._frequency = frequency  # .astype(object)
+
+    def __len__(self):
+        return len(self._dates)
+
+    @property
+    def dates(self):
+        return self._dates
+
+    @property
+    def frequency(self):
+        return self._frequency
+
+    @cached_property
+    def missing(self):
+        # TODO: optimize
+        result = set()
+
+        for i, (dataset, row) in enumerate(self._indices):
+            if dataset == self._missing_index:
+                result.add(i)
+                continue
+
+            if row in self.datasets[dataset].missing:
+                result.add(i)
+
+        return result
+
+    def check_same_lengths(self, d1, d2):
+        # Turned off because we are concatenating along the first axis
+        pass
+
+    def check_same_dates(self, d1, d2):
+        # Turned off because we are concatenating along the dates axis
+        pass
+
+    def check_compatibility(self, d1, d2):
+        super().check_compatibility(d1, d2)
+        self.check_same_sub_shapes(d1, d2, drop_axis=0)
+
+    def tree(self):
+        return Node(self, [d.tree() for d in self.datasets], allow_gaps_in_dates=self.allow_gaps_in_dates)
+
+    @debug_indexing
+    def __getitem__(self, n):
+        if isinstance(n, tuple):
+            return self._get_tuple(n)
+
+        if isinstance(n, slice):
+            return self._get_slice(n)
+
+        dataset, row = self._indices[n]
+
+        if dataset == self._missing_index:
+            raise MissingDateError(f"Date {self.dates[n]} is missing (index={n})")
+
+        return self.datasets[dataset][int(row)]
+
+    @debug_indexing
+    @expand_list_indexing
+    def _get_tuple(self, index):
+        index, changes = index_to_slices(index, self.shape)
+        index, previous = update_tuple(index, 0, slice(None))
+        result = self._get_slice(previous)
+        return apply_index_to_slices_changes(result[index], changes)
+
+    def _get_slice(self, s):
+        return np.stack([self[i] for i in range(*s.indices(self._len))])
+
+
+def merge_factory(args, kwargs):
+
+    datasets = kwargs.pop("merge")
+
+    assert isinstance(datasets, (list, tuple))
+    assert len(args) == 0
+
+    datasets = [_open(e) for e in datasets]
+
+    if len(datasets) == 1:
+        return datasets[0]._subset(**kwargs)
+
+    datasets, kwargs = _auto_adjust(datasets, kwargs)
+
+    allow_gaps_in_dates = kwargs.pop("allow_gaps_in_dates", False)
+
+    return Merge(datasets, allow_gaps_in_dates=allow_gaps_in_dates)._subset(**kwargs)
```