anemoi-datasets 0.5.6__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registries. It is provided for informational purposes only.
Files changed (124)
  1. anemoi/datasets/__init__.py +11 -3
  2. anemoi/datasets/__main__.py +2 -3
  3. anemoi/datasets/_version.py +2 -2
  4. anemoi/datasets/commands/__init__.py +2 -3
  5. anemoi/datasets/commands/cleanup.py +9 -0
  6. anemoi/datasets/commands/compare.py +3 -3
  7. anemoi/datasets/commands/copy.py +38 -68
  8. anemoi/datasets/commands/create.py +20 -5
  9. anemoi/datasets/commands/finalise-additions.py +9 -0
  10. anemoi/datasets/commands/finalise.py +9 -0
  11. anemoi/datasets/commands/init-additions.py +9 -0
  12. anemoi/datasets/commands/init.py +9 -0
  13. anemoi/datasets/commands/inspect.py +7 -1
  14. anemoi/datasets/commands/load-additions.py +9 -0
  15. anemoi/datasets/commands/load.py +9 -0
  16. anemoi/datasets/commands/patch.py +9 -0
  17. anemoi/datasets/commands/publish.py +9 -0
  18. anemoi/datasets/commands/scan.py +9 -0
  19. anemoi/datasets/compute/__init__.py +8 -0
  20. anemoi/datasets/compute/recentre.py +3 -2
  21. anemoi/datasets/create/__init__.py +64 -48
  22. anemoi/datasets/create/check.py +4 -3
  23. anemoi/datasets/create/chunks.py +3 -2
  24. anemoi/datasets/create/config.py +5 -5
  25. anemoi/datasets/create/functions/__init__.py +22 -7
  26. anemoi/datasets/create/functions/filters/__init__.py +2 -1
  27. anemoi/datasets/create/functions/filters/empty.py +3 -2
  28. anemoi/datasets/create/functions/filters/noop.py +2 -2
  29. anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +3 -2
  30. anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +3 -2
  31. anemoi/datasets/create/functions/filters/rename.py +16 -10
  32. anemoi/datasets/create/functions/filters/rotate_winds.py +3 -2
  33. anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +3 -2
  34. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +3 -2
  35. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +2 -2
  36. anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +2 -2
  37. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +3 -2
  38. anemoi/datasets/create/functions/filters/unrotate_winds.py +3 -2
  39. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +3 -2
  40. anemoi/datasets/create/functions/sources/__init__.py +2 -2
  41. anemoi/datasets/create/functions/sources/accumulations.py +10 -4
  42. anemoi/datasets/create/functions/sources/constants.py +3 -2
  43. anemoi/datasets/create/functions/sources/empty.py +3 -2
  44. anemoi/datasets/create/functions/sources/forcings.py +3 -2
  45. anemoi/datasets/create/functions/sources/grib.py +2 -2
  46. anemoi/datasets/create/functions/sources/hindcasts.py +3 -2
  47. anemoi/datasets/create/functions/sources/mars.py +97 -17
  48. anemoi/datasets/create/functions/sources/netcdf.py +3 -2
  49. anemoi/datasets/create/functions/sources/opendap.py +2 -2
  50. anemoi/datasets/create/functions/sources/recentre.py +3 -2
  51. anemoi/datasets/create/functions/sources/source.py +3 -2
  52. anemoi/datasets/create/functions/sources/tendencies.py +3 -2
  53. anemoi/datasets/create/functions/sources/xarray/__init__.py +8 -2
  54. anemoi/datasets/create/functions/sources/xarray/coordinates.py +5 -2
  55. anemoi/datasets/create/functions/sources/xarray/field.py +3 -2
  56. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +12 -2
  57. anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -16
  58. anemoi/datasets/create/functions/sources/xarray/grid.py +3 -2
  59. anemoi/datasets/create/functions/sources/xarray/metadata.py +3 -2
  60. anemoi/datasets/create/functions/sources/xarray/time.py +39 -4
  61. anemoi/datasets/create/functions/sources/xarray/variable.py +6 -6
  62. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +2 -2
  63. anemoi/datasets/create/functions/sources/xarray_zarr.py +2 -2
  64. anemoi/datasets/create/functions/sources/zenodo.py +2 -2
  65. anemoi/datasets/create/input/__init__.py +3 -17
  66. anemoi/datasets/create/input/action.py +3 -2
  67. anemoi/datasets/create/input/concat.py +3 -2
  68. anemoi/datasets/create/input/context.py +3 -2
  69. anemoi/datasets/create/input/data_sources.py +3 -2
  70. anemoi/datasets/create/input/empty.py +3 -2
  71. anemoi/datasets/create/input/filter.py +3 -2
  72. anemoi/datasets/create/input/function.py +3 -2
  73. anemoi/datasets/create/input/join.py +3 -2
  74. anemoi/datasets/create/input/misc.py +3 -2
  75. anemoi/datasets/create/input/pipe.py +3 -2
  76. anemoi/datasets/create/input/repeated_dates.py +3 -2
  77. anemoi/datasets/create/input/result.py +187 -3
  78. anemoi/datasets/create/input/step.py +4 -2
  79. anemoi/datasets/create/input/template.py +3 -2
  80. anemoi/datasets/create/input/trace.py +3 -2
  81. anemoi/datasets/create/patch.py +9 -1
  82. anemoi/datasets/create/persistent.py +7 -3
  83. anemoi/datasets/create/size.py +3 -2
  84. anemoi/datasets/create/statistics/__init__.py +7 -3
  85. anemoi/datasets/create/statistics/summary.py +3 -2
  86. anemoi/datasets/create/utils.py +15 -2
  87. anemoi/datasets/create/writer.py +3 -2
  88. anemoi/datasets/create/zarr.py +8 -3
  89. anemoi/datasets/data/__init__.py +27 -1
  90. anemoi/datasets/data/concat.py +5 -1
  91. anemoi/datasets/data/dataset.py +216 -37
  92. anemoi/datasets/data/debug.py +4 -1
  93. anemoi/datasets/data/ensemble.py +4 -1
  94. anemoi/datasets/data/fill_missing.py +165 -0
  95. anemoi/datasets/data/forwards.py +27 -2
  96. anemoi/datasets/data/grids.py +236 -58
  97. anemoi/datasets/data/indexing.py +4 -1
  98. anemoi/datasets/data/interpolate.py +4 -1
  99. anemoi/datasets/data/join.py +17 -1
  100. anemoi/datasets/data/masked.py +36 -10
  101. anemoi/datasets/data/merge.py +180 -0
  102. anemoi/datasets/data/misc.py +18 -3
  103. anemoi/datasets/data/missing.py +4 -1
  104. anemoi/datasets/data/rescale.py +4 -1
  105. anemoi/datasets/data/select.py +15 -1
  106. anemoi/datasets/data/statistics.py +4 -1
  107. anemoi/datasets/data/stores.py +70 -3
  108. anemoi/datasets/data/subset.py +6 -1
  109. anemoi/datasets/data/unchecked.py +9 -1
  110. anemoi/datasets/data/xy.py +20 -5
  111. anemoi/datasets/dates/__init__.py +9 -7
  112. anemoi/datasets/dates/groups.py +3 -1
  113. anemoi/datasets/fields.py +3 -1
  114. anemoi/datasets/grids.py +86 -2
  115. anemoi/datasets/testing.py +60 -0
  116. anemoi/datasets/utils/__init__.py +8 -0
  117. anemoi/datasets/utils/fields.py +2 -2
  118. {anemoi_datasets-0.5.6.dist-info → anemoi_datasets-0.5.10.dist-info}/METADATA +11 -29
  119. anemoi_datasets-0.5.10.dist-info/RECORD +124 -0
  120. {anemoi_datasets-0.5.6.dist-info → anemoi_datasets-0.5.10.dist-info}/WHEEL +1 -1
  121. anemoi_datasets-0.5.6.dist-info/RECORD +0 -121
  122. {anemoi_datasets-0.5.6.dist-info → anemoi_datasets-0.5.10.dist-info}/LICENSE +0 -0
  123. {anemoi_datasets-0.5.6.dist-info → anemoi_datasets-0.5.10.dist-info}/entry_points.txt +0 -0
  124. {anemoi_datasets-0.5.6.dist-info → anemoi_datasets-0.5.10.dist-info}/top_level.txt +0 -0
anemoi/datasets/data/grids.py
@@ -1,14 +1,18 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+
 import logging
 from functools import cached_property
 
 import numpy as np
+from scipy.spatial import cKDTree
 
 from .debug import Node
 from .debug import debug_indexing
@@ -105,6 +109,17 @@ class GridsBase(GivenAxis):
         # We don't check the resolution, because we want to be able to combine
         pass
 
+    def metadata_specific(self):
+        return super().metadata_specific(
+            multi_grids=True,
+        )
+
+    def collect_input_sources(self, collected):
+        # We assume that, because they have different grids, they have different input sources
+        for d in self.datasets:
+            collected.append(d)
+            d.collect_input_sources(collected)
+
 
 class Grids(GridsBase):
     # TODO: select the statistics of the most global grid?
@@ -128,85 +143,248 @@ class Grids(GridsBase):
 
 
 class Cutout(GridsBase):
-    def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, neighbours=5, plot=False):
-        from anemoi.datasets.grids import cutout_mask
-
+    def __init__(self, datasets, axis=3, cropping_distance=2.0, neighbours=5, min_distance_km=None, plot=None):
+        """Initializes a Cutout object for hierarchical management of Limited Area
+        Models (LAMs) and a global dataset, handling overlapping regions.
+
+        Args:
+            datasets (list): List of LAM and global datasets.
+            axis (int): Concatenation axis, must be set to 3.
+            cropping_distance (float): Distance threshold in degrees for
+                cropping cutouts.
+            neighbours (int): Number of neighboring points to consider when
+                constructing masks.
+            min_distance_km (float, optional): Minimum distance threshold in km
+                between grid points.
+            plot (bool, optional): Flag to enable or disable visualization
+                plots.
+        """
         super().__init__(datasets, axis)
-        assert len(datasets) == 2, "CutoutGrids requires two datasets"
+        assert len(datasets) >= 2, "CutoutGrids requires at least two datasets"
         assert axis == 3, "CutoutGrids requires axis=3"
+        assert cropping_distance >= 0, "cropping_distance must be a non-negative number"
+        if min_distance_km is not None:
+            assert min_distance_km >= 0, "min_distance_km must be a non-negative number"
+
+        self.lams = datasets[:-1]  # Assume the last dataset is the global one
+        self.globe = datasets[-1]
+        self.axis = axis
+        self.cropping_distance = cropping_distance
+        self.neighbours = neighbours
+        self.min_distance_km = min_distance_km
+        self.plot = plot
+        self.masks = []  # To store the masks for each LAM dataset
+        self.global_mask = np.ones(self.globe.shape[-1], dtype=bool)
+
+        # Initialize cumulative masks
+        self._initialize_masks()
+
+    def _initialize_masks(self):
+        """Generates hierarchical masks for each LAM dataset by excluding
+        overlapping regions with previous LAMs and creating a global mask for
+        the global dataset.
+
+        Raises:
+            ValueError: If the global mask dimension does not match the global
+                dataset grid points.
+        """
+        from anemoi.datasets.grids import cutout_mask
 
-        # We assume that the LAM is the first dataset, and the global is the second
-        # Note: the second fields does not really need to be global
-
-        self.lam, self.globe = datasets
-        self.mask = cutout_mask(
-            self.lam.latitudes,
-            self.lam.longitudes,
-            self.globe.latitudes,
-            self.globe.longitudes,
-            plot=plot,
-            min_distance_km=min_distance_km,
-            cropping_distance=cropping_distance,
-            neighbours=neighbours,
-        )
-        assert len(self.mask) == self.globe.shape[3], (
-            len(self.mask),
-            self.globe.shape[3],
-        )
+        for i, lam in enumerate(self.lams):
+            assert len(lam.shape) == len(
+                self.globe.shape
+            ), "LAMs and global dataset must have the same number of dimensions"
+            lam_lats = lam.latitudes
+            lam_lons = lam.longitudes
+            # Create a mask for the global dataset excluding all LAM points
+            global_overlap_mask = cutout_mask(
+                lam.latitudes,
+                lam.longitudes,
+                self.globe.latitudes,
+                self.globe.longitudes,
+                plot=False,
+                min_distance_km=self.min_distance_km,
+                cropping_distance=self.cropping_distance,
+                neighbours=self.neighbours,
+            )
+
+            # Ensure the mask dimensions match the global grid points
+            if global_overlap_mask.shape[0] != self.globe.shape[-1]:
+                raise ValueError("Global mask dimension does not match global dataset grid " "points.")
+            self.global_mask[~global_overlap_mask] = False
+
+            # Create a mask for the LAM datasets hierarchically, excluding
+            # points from previous LAMs
+            lam_current_mask = np.ones(lam.shape[-1], dtype=bool)
+            if i > 0:
+                for j in range(i):
+                    prev_lam = self.lams[j]
+                    prev_lam_lats = prev_lam.latitudes
+                    prev_lam_lons = prev_lam.longitudes
+                    # Check for overlap by computing distances
+                    if self.has_overlap(prev_lam_lats, prev_lam_lons, lam_lats, lam_lons):
+                        lam_overlap_mask = cutout_mask(
+                            prev_lam_lats,
+                            prev_lam_lons,
+                            lam_lats,
+                            lam_lons,
+                            plot=False,
+                            min_distance_km=self.min_distance_km,
+                            cropping_distance=self.cropping_distance,
+                            neighbours=self.neighbours,
+                        )
+                        lam_current_mask[~lam_overlap_mask] = False
+            self.masks.append(lam_current_mask)
+
+    def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0):
+        """Checks for overlapping points between two sets of latitudes and
+        longitudes within a specified distance threshold.
+
+        Args:
+            lats1, lons1 (np.ndarray): Latitude and longitude arrays for the
+                first dataset.
+            lats2, lons2 (np.ndarray): Latitude and longitude arrays for the
+                second dataset.
+            distance_threshold (float): Distance in degrees to consider as
+                overlapping.
+
+        Returns:
+            bool: True if any points overlap within the distance threshold,
+                otherwise False.
+        """
+        # Create KDTree for the first set of points
+        tree = cKDTree(np.vstack((lats1, lons1)).T)
+
+        # Query the second set of points against the first tree
+        distances, _ = tree.query(np.vstack((lats2, lons2)).T, k=1)
+
+        # Check if any distance is less than the specified threshold
+        return np.any(distances < distance_threshold)
+
+    def __getitem__(self, index):
+        """Retrieves data from the masked LAMs and global dataset based on the
+        given index.
+
+        Args:
+            index (int or slice or tuple): Index specifying the data to
+                retrieve.
+
+        Returns:
+            np.ndarray: Data array from the masked datasets based on the index.
+        """
+        if isinstance(index, (int, slice)):
+            index = (index, slice(None), slice(None), slice(None))
+        return self._get_tuple(index)
+
+    def _get_tuple(self, index):
+        """Helper method that applies masks and retrieves data from each dataset
+        according to the specified index.
+
+        Args:
+            index (tuple): Index specifying slices to retrieve data.
+
+        Returns:
+            np.ndarray: Concatenated data array from all datasets based on the
+                index.
+        """
+        index, changes = index_to_slices(index, self.shape)
+        # Select data from each LAM
+        lam_data = [lam[index] for lam in self.lams]
+
+        # First apply spatial indexing on `self.globe` and then apply the mask
+        globe_data_sliced = self.globe[index[:3]]
+        globe_data = globe_data_sliced[..., self.global_mask]
+
+        # Concatenate LAM data with global data
+        result = np.concatenate(lam_data + [globe_data], axis=self.axis)
+        return apply_index_to_slices_changes(result, changes)
+
+    def collect_supporting_arrays(self, collected, *path):
+        """Collects supporting arrays, including masks for each LAM and the global
+        dataset.
+
+        Args:
+            collected (list): List to which the supporting arrays are appended.
+            *path: Variable length argument list specifying the paths for the masks.
+        """
+        # Append masks for each LAM
+        for i, (lam, mask) in enumerate(zip(self.lams, self.masks)):
+            collected.append((path + (f"lam_{i}",), "cutout_mask", mask))
+
+        # Append the global mask
+        collected.append((path + ("global",), "cutout_mask", self.global_mask))
 
     @cached_property
     def shape(self):
-        shape = self.lam.shape
-        # Number of non-zero masked values in the globe dataset
-        nb_globe = np.count_nonzero(self.mask)
-        return shape[:-1] + (shape[-1] + nb_globe,)
+        """Returns the shape of the Cutout, accounting for retained grid points
+        across all LAMs and the global dataset.
+
+        Returns:
+            tuple: Shape of the concatenated masked datasets.
+        """
+        shapes = [np.sum(mask) for mask in self.masks]
+        global_shape = np.sum(self.global_mask)
+        return tuple(self.lams[0].shape[:-1] + (sum(shapes) + global_shape,))
 
     def check_same_resolution(self, d1, d2):
         # Turned off because we are combining different resolutions
         pass
 
     @property
-    def latitudes(self):
-        return np.concatenate([self.lam.latitudes, self.globe.latitudes[self.mask]])
+    def grids(self):
+        """Returns the number of grid points for each LAM and the global dataset
+        after applying masks.
 
-    @property
-    def longitudes(self):
-        return np.concatenate([self.lam.longitudes, self.globe.longitudes[self.mask]])
+        Returns:
+            tuple: Count of retained grid points for each dataset.
+        """
+        grids = [np.sum(mask) for mask in self.masks]
+        grids.append(np.sum(self.global_mask))
+        return tuple(grids)
 
-    def __getitem__(self, index):
-        if isinstance(index, (int, slice)):
-            index = (index, slice(None), slice(None), slice(None))
-        return self._get_tuple(index)
+    @property
+    def latitudes(self):
+        """Returns the concatenated latitudes of each LAM and the global dataset
+        after applying masks.
 
-    @debug_indexing
-    @expand_list_indexing
-    def _get_tuple(self, index):
-        assert self.axis >= len(index) or index[self.axis] == slice(
-            None
-        ), f"No support for selecting a subset of the 1D values {index} ({self.tree()})"
-        index, changes = index_to_slices(index, self.shape)
+        Returns:
+            np.ndarray: Concatenated latitude array for the masked datasets.
+        """
+        lam_latitudes = np.concatenate([lam.latitudes[mask] for lam, mask in zip(self.lams, self.masks)])
 
-        # In case index_to_slices has changed the last slice
-        index, _ = update_tuple(index, self.axis, slice(None))
+        assert (
+            len(lam_latitudes) + len(self.globe.latitudes[self.global_mask]) == self.shape[-1]
+        ), "Mismatch in number of latitudes"
 
-        lam_data = self.lam[index]
-        globe_data = self.globe[index]
+        latitudes = np.concatenate([lam_latitudes, self.globe.latitudes[self.global_mask]])
+        return latitudes
 
-        globe_data = globe_data[:, :, :, self.mask]
+    @property
+    def longitudes(self):
+        """Returns the concatenated longitudes of each LAM and the global dataset
+        after applying masks.
 
-        result = np.concatenate([lam_data, globe_data], axis=self.axis)
+        Returns:
+            np.ndarray: Concatenated longitude array for the masked datasets.
+        """
+        lam_longitudes = np.concatenate([lam.longitudes[mask] for lam, mask in zip(self.lams, self.masks)])
 
-        return apply_index_to_slices_changes(result, changes)
+        assert (
+            len(lam_longitudes) + len(self.globe.longitudes[self.global_mask]) == self.shape[-1]
+        ), "Mismatch in number of longitudes"
 
-    @property
-    def grids(self):
-        for d in self.datasets:
-            if len(d.grids) > 1:
-                raise NotImplementedError("CutoutGrids does not support multi-grids datasets as inputs")
-        shape = self.lam.shape
-        return (shape[-1], self.shape[-1] - shape[-1])
+        longitudes = np.concatenate([lam_longitudes, self.globe.longitudes[self.global_mask]])
+        return longitudes
 
     def tree(self):
+        """Generates a hierarchical tree structure for the `Cutout` instance and
+        its associated datasets.
+
+        Returns:
+            Node: A `Node` object representing the `Cutout` instance as the root
+                node, with each dataset in `self.datasets` represented as a child
+                node.
+        """
         return Node(self, [d.tree() for d in self.datasets])
 
 
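Note on the overlap test: `has_overlap` builds a KD-tree on raw (latitude, longitude) pairs, so the threshold is a Euclidean distance in degree space, not a great-circle distance. A minimal, self-contained sketch of the same technique, using made-up points:

    import numpy as np
    from scipy.spatial import cKDTree

    # Two synthetic point sets, in degrees
    lats1, lons1 = np.array([50.0, 51.0]), np.array([10.0, 11.0])
    lats2, lons2 = np.array([50.5, 70.0]), np.array([10.5, 40.0])

    # Index the first set, then find each second-set point's nearest neighbour
    tree = cKDTree(np.vstack((lats1, lons1)).T)
    distances, _ = tree.query(np.vstack((lats2, lons2)).T, k=1)

    # Overlap if any point lies within the 1-degree default threshold
    print(np.any(distances < 1.0))  # True: (50.5, 10.5) is ~0.7 degrees from (50.0, 10.0)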
@@ -238,7 +416,7 @@ def cutout_factory(args, kwargs):
     neighbours = kwargs.pop("neighbours", 5)
 
     assert len(args) == 0
-    assert isinstance(cutout, (list, tuple))
+    assert isinstance(cutout, (list, tuple)), "cutout must be a list or tuple"
 
     datasets = [_open(e) for e in cutout]
     datasets, kwargs = _auto_adjust(datasets, kwargs)
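For orientation, a usage sketch of the extended cutout (not part of the diff; the zarr paths are hypothetical, and the keyword set is assumed from `cutout_factory` above). The last dataset in the list is treated as the global one; earlier entries are LAMs whose masks are built hierarchically so that each grid point is owned by exactly one dataset:

    from anemoi.datasets import open_dataset

    ds = open_dataset(
        cutout=["lam-2km.zarr", "lam-10km.zarr", "global-n320.zarr"],
        neighbours=5,  # neighbours used when building the cutout masks
    )

    print(ds.grids)  # retained grid points per dataset, e.g. (n_lam0, n_lam1, n_globe)
    print(ds.shape)  # last axis equals sum(ds.grids)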
anemoi/datasets/data/indexing.py
@@ -1,10 +1,13 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+
 from functools import wraps
 
 import numpy as np
anemoi/datasets/data/interpolate.py
@@ -1,10 +1,13 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+
 import logging
 from functools import cached_property
 
anemoi/datasets/data/join.py
@@ -1,10 +1,13 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+
 import logging
 from functools import cached_property
 
@@ -111,6 +114,19 @@ class Join(Combined):
 
         return result
 
+    @property
+    def variables_metadata(self):
+        result = {}
+        variables = [v for v in self.variables if not (v.startswith("(") and v.endswith(")"))]
+        for d in self.datasets:
+            md = d.variables_metadata
+            for v in variables:
+                if v in md:
+                    result[v] = md[v]
+
+        assert len(result) == len(variables), (result, variables)
+        return result
+
     @cached_property
     def name_to_index(self):
         return {k: i for i, k in enumerate(self.variables)}
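A usage sketch (hypothetical dataset names): when joined datasets share a variable, the shadowed copy is kept under a parenthesised name, which is why `variables_metadata` skips names of the form `(...)`:

    from anemoi.datasets import open_dataset

    ds = open_dataset(join=["dataset-a.zarr", "dataset-b.zarr"])

    # If both inputs provide "2t", the overridden copy appears as "(2t)"
    print(ds.variables)           # e.g. ["(2t)", "10u", "10v", "2t"]

    # One metadata entry per effective (non-shadowed) variable
    print(ds.variables_metadata)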
anemoi/datasets/data/masked.py
@@ -1,10 +1,13 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+
 import logging
 from functools import cached_property
 
@@ -30,6 +33,8 @@ class Masked(Forwards):
         self.mask = mask
         self.axis = 3
 
+        self.mask_name = f"{self.__class__.__name__.lower()}_mask"
+
     @cached_property
     def shape(self):
         return self.forward.shape[:-1] + (np.count_nonzero(self.mask),)
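The new `mask_name` is derived from the subclass name, so each masked view labels its supporting array distinctly. A one-line sketch of the naming rule:

    class Thinning:  # stand-in for the real subclass
        pass

    print(f"{Thinning.__name__.lower()}_mask")  # -> "thinning_mask"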
@@ -64,26 +69,46 @@ class Masked(Forwards):
         result = apply_index_to_slices_changes(result, changes)
         return result
 
+    def collect_supporting_arrays(self, collected, *path):
+        super().collect_supporting_arrays(collected, *path)
+        collected.append((path, self.mask_name, self.mask))
+
 
 class Thinning(Masked):
+
     def __init__(self, forward, thinning, method):
         self.thinning = thinning
         self.method = method
 
-        shape = forward.field_shape
-        if len(shape) != 2:
-            raise ValueError("Thinning only works latitude/longitude fields")
+        if thinning is not None:
+
+            shape = forward.field_shape
+            if len(shape) != 2:
+                raise ValueError("Thinning only works latitude/longitude fields")
 
-        latitudes = forward.latitudes.reshape(shape)
-        longitudes = forward.longitudes.reshape(shape)
-        latitudes = latitudes[::thinning, ::thinning].flatten()
-        longitudes = longitudes[::thinning, ::thinning].flatten()
+            # Make a copy, so we read the data only once from zarr
+            forward_latitudes = forward.latitudes.copy()
+            forward_longitudes = forward.longitudes.copy()
 
-        mask = [lat in latitudes and lon in longitudes for lat, lon in zip(forward.latitudes, forward.longitudes)]
-        mask = np.array(mask, dtype=bool)
+            latitudes = forward_latitudes.reshape(shape)
+            longitudes = forward_longitudes.reshape(shape)
+            latitudes = latitudes[::thinning, ::thinning].flatten()
+            longitudes = longitudes[::thinning, ::thinning].flatten()
+
+            # TODO: This is not very efficient
+
+            mask = [lat in latitudes and lon in longitudes for lat, lon in zip(forward_latitudes, forward_longitudes)]
+            mask = np.array(mask, dtype=bool)
+        else:
+            mask = None
 
         super().__init__(forward, mask)
 
+    def mutate(self) -> Dataset:
+        if self.thinning is None:
+            return self.forward.mutate()
+        return super().mutate()
+
     def tree(self):
         return Node(self, [self.forward.tree()], thinning=self.thinning, method=self.method)
 
@@ -92,6 +117,7 @@ class Thinning(Masked):
 
 
 class Cropping(Masked):
+
     def __init__(self, forward, area):
         from ..data import open_dataset
 
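A usage sketch for thinning (hypothetical path). With the change above, `thinning=None` skips mask construction entirely, and the wrapper removes itself via `mutate()`:

    from anemoi.datasets import open_dataset

    # Keep every 4th point along each direction of a lat/lon field
    thinned = open_dataset("dataset.zarr", thinning=4)

    # Now effectively a no-op: returns the underlying dataset unchanged
    unthinned = open_dataset("dataset.zarr", thinning=None)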
anemoi/datasets/data/merge.py (new file)
@@ -0,0 +1,180 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+from functools import cached_property
+
+import numpy as np
+
+from . import MissingDateError
+from .debug import Node
+from .debug import debug_indexing
+from .forwards import Combined
+from .indexing import apply_index_to_slices_changes
+from .indexing import expand_list_indexing
+from .indexing import index_to_slices
+from .indexing import update_tuple
+from .misc import _auto_adjust
+from .misc import _open
+
+LOG = logging.getLogger(__name__)
+
+
+class Merge(Combined):
+
+    # d0 d2 d4 d6 ...
+    # d1 d3 d5 d7 ...
+
+    # gives
+    # d0 d1 d2 d3 ...
+
+    def __init__(self, datasets, allow_gaps_in_dates=False):
+        super().__init__(datasets)
+
+        self.allow_gaps_in_dates = allow_gaps_in_dates
+
+        dates = dict()  # date -> (dataset_index, date_index)
+
+        for i, d in enumerate(datasets):
+            for j, date in enumerate(d.dates):
+                date = date.astype(object)
+                if date in dates:
+
+                    d1 = datasets[dates[date][0]]  # Selected
+                    d2 = datasets[i]  # The new one
+
+                    if j in d2.missing:
+                        # LOG.warning(f"Duplicate date {date} found in datasets {d1} and {d2}, but {date} is missing in {d}, ignoring")
+                        continue
+
+                    k = dates[date][1]
+                    if k in d1.missing:
+                        # LOG.warning(f"Duplicate date {date} found in datasets {d1} and {d2}, but {date} is missing in {d}, ignoring")
+                        dates[date] = (i, j)  # Replace the missing date with the new one
+                        continue
+
+                    raise ValueError(f"Duplicate date {date} found in datasets {d1} and {d2}")
+                else:
+                    dates[date] = (i, j)
+
+        all_dates = sorted(dates)
+        start = all_dates[0]
+        end = all_dates[-1]
+
+        frequency = min(d2 - d1 for d1, d2 in zip(all_dates[:-1], all_dates[1:]))
+
+        date = start
+        indices = []
+        _dates = []
+
+        self._missing_index = len(datasets)
+
+        while date <= end:
+            if date not in dates:
+                if self.allow_gaps_in_dates:
+                    dates[date] = (self._missing_index, -1)
+                else:
+                    raise ValueError(
+                        f"merge: date {date} not covered by dataset. Start={start}, end={end}, frequency={frequency}"
+                    )
+
+            indices.append(dates[date])
+            _dates.append(date)
+            date += frequency
+
+        self._dates = np.array(_dates, dtype="datetime64[s]")
+        self._indices = np.array(indices)
+        self._frequency = frequency  # .astype(object)
+
+    def __len__(self):
+        return len(self._dates)
+
+    @property
+    def dates(self):
+        return self._dates
+
+    @property
+    def frequency(self):
+        return self._frequency
+
+    @cached_property
+    def missing(self):
+        # TODO: optimize
+        result = set()
+
+        for i, (dataset, row) in enumerate(self._indices):
+            if dataset == self._missing_index:
+                result.add(i)
+                continue
+
+            if row in self.datasets[dataset].missing:
+                result.add(i)
+
+        return result
+
+    def check_same_lengths(self, d1, d2):
+        # Turned off because we are concatenating along the first axis
+        pass
+
+    def check_same_dates(self, d1, d2):
+        # Turned off because we are concatenating along the dates axis
+        pass
+
+    def check_compatibility(self, d1, d2):
+        super().check_compatibility(d1, d2)
+        self.check_same_sub_shapes(d1, d2, drop_axis=0)
+
+    def tree(self):
+        return Node(self, [d.tree() for d in self.datasets], allow_gaps_in_dates=self.allow_gaps_in_dates)
+
+    @debug_indexing
+    def __getitem__(self, n):
+        if isinstance(n, tuple):
+            return self._get_tuple(n)
+
+        if isinstance(n, slice):
+            return self._get_slice(n)
+
+        dataset, row = self._indices[n]
+
+        if dataset == self._missing_index:
+            raise MissingDateError(f"Date {self.dates[n]} is missing (index={n})")
+
+        return self.datasets[dataset][int(row)]
+
+    @debug_indexing
+    @expand_list_indexing
+    def _get_tuple(self, index):
+        index, changes = index_to_slices(index, self.shape)
+        index, previous = update_tuple(index, 0, slice(None))
+        result = self._get_slice(previous)
+        return apply_index_to_slices_changes(result[index], changes)
+
+    def _get_slice(self, s):
+        return np.stack([self[i] for i in range(*s.indices(self._len))])
+
+
+def merge_factory(args, kwargs):
+
+    datasets = kwargs.pop("merge")
+
+    assert isinstance(datasets, (list, tuple))
+    assert len(args) == 0
+
+    datasets = [_open(e) for e in datasets]
+
+    if len(datasets) == 1:
+        return datasets[0]._subset(**kwargs)
+
+    datasets, kwargs = _auto_adjust(datasets, kwargs)
+
+    allow_gaps_in_dates = kwargs.pop("allow_gaps_in_dates", False)
+
+    return Merge(datasets, allow_gaps_in_dates=allow_gaps_in_dates)._subset(**kwargs)
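A usage sketch for the new `merge` (hypothetical paths): interleave datasets that cover complementary dates, e.g. 00Z/12Z fields in one and 06Z/18Z fields in the other. The merged frequency is the smallest gap between consecutive dates across all inputs, and a date present in two inputs (and missing in neither) raises ValueError:

    from anemoi.datasets import open_dataset

    ds = open_dataset(merge=["dataset-00-12.zarr", "dataset-06-18.zarr"])

    # With gaps allowed, uncovered dates become missing entries that raise
    # MissingDateError only when accessed
    ds = open_dataset(
        merge=["dataset-00-12.zarr", "dataset-06-18.zarr"],
        allow_gaps_in_dates=True,
    )
    print(sorted(ds.missing))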