anemoi-datasets 0.5.11__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/create/__init__.py +8 -4
- anemoi/datasets/create/check.py +1 -1
- anemoi/datasets/create/functions/__init__.py +15 -1
- anemoi/datasets/create/functions/filters/orog_to_z.py +58 -0
- anemoi/datasets/create/functions/filters/sum.py +71 -0
- anemoi/datasets/create/functions/filters/wz_to_w.py +79 -0
- anemoi/datasets/create/functions/sources/accumulations.py +1 -0
- anemoi/datasets/create/functions/sources/xarray/__init__.py +3 -3
- anemoi/datasets/create/functions/sources/xarray/field.py +5 -1
- anemoi/datasets/create/functions/sources/xarray/fieldlist.py +10 -1
- anemoi/datasets/create/functions/sources/xarray/metadata.py +5 -11
- anemoi/datasets/create/functions/sources/xarray/patch.py +44 -0
- anemoi/datasets/create/functions/sources/xarray/time.py +15 -0
- anemoi/datasets/create/functions/sources/xarray/variable.py +18 -2
- anemoi/datasets/create/input/repeated_dates.py +18 -0
- anemoi/datasets/create/statistics/__init__.py +2 -2
- anemoi/datasets/create/utils.py +4 -0
- anemoi/datasets/data/complement.py +164 -0
- anemoi/datasets/data/dataset.py +74 -6
- anemoi/datasets/data/ensemble.py +55 -0
- anemoi/datasets/data/grids.py +6 -4
- anemoi/datasets/data/join.py +7 -1
- anemoi/datasets/data/merge.py +3 -0
- anemoi/datasets/data/misc.py +10 -1
- anemoi/datasets/grids.py +23 -10
- {anemoi_datasets-0.5.11.dist-info → anemoi_datasets-0.5.13.dist-info}/METADATA +27 -28
- {anemoi_datasets-0.5.11.dist-info → anemoi_datasets-0.5.13.dist-info}/RECORD +32 -27
- {anemoi_datasets-0.5.11.dist-info → anemoi_datasets-0.5.13.dist-info}/WHEEL +1 -1
- {anemoi_datasets-0.5.11.dist-info → anemoi_datasets-0.5.13.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.5.11.dist-info → anemoi_datasets-0.5.13.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.11.dist-info → anemoi_datasets-0.5.13.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
# (C) Copyright 2024 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from functools import cached_property
|
|
13
|
+
|
|
14
|
+
from ..grids import nearest_grid_points
|
|
15
|
+
from .debug import Node
|
|
16
|
+
from .forwards import Combined
|
|
17
|
+
from .indexing import apply_index_to_slices_changes
|
|
18
|
+
from .indexing import index_to_slices
|
|
19
|
+
from .indexing import update_tuple
|
|
20
|
+
from .misc import _auto_adjust
|
|
21
|
+
from .misc import _open
|
|
22
|
+
|
|
23
|
+
LOG = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Complement(Combined):
    """Expose the variables of ``source`` that are missing from ``target``.

    The resulting dataset only contains those missing variables; its shape is
    taken from ``target`` with the variable axis replaced by the number of
    missing variables. Subclasses implement ``_get_tuple`` to decide how the
    source data is mapped onto the target grid.
    """

    def __init__(self, target, source, what="variables", interpolation="nearest"):
        super().__init__([target, source])

        # We add the variables of dataset[1] (source) to dataset[0] (target),
        # interpolated on the grid of dataset[0].
        # NOTE(review): `what` and `interpolation` are accepted but not used here.

        self.target = target
        self.source = source

        target_variables = set(self.target.variables)

        # Keep the same order as the source dataset
        self._variables = [v for v in self.source.variables if v not in target_variables]

        if not self._variables:
            raise ValueError("Complement: no missing variables")

    @property
    def variables(self):
        """The source variables that are not present in the target."""
        return self._variables

    @property
    def name_to_index(self):
        """Map each complement variable name to its position in ``variables``."""
        return {v: i for i, v in enumerate(self.variables)}

    @property
    def shape(self):
        """Target shape with the variable axis (axis 1) resized to the complement variables."""
        shape = self.target.shape
        return (shape[0], len(self._variables)) + shape[2:]

    @property
    def variables_metadata(self):
        """Source variable metadata, restricted to the complement variables."""
        return {k: v for k, v in self.source.variables_metadata.items() if k in self._variables}

    def check_same_variables(self, d1, d2):
        # Target and source have different variables by construction; skip the check.
        pass

    @cached_property
    def missing(self):
        """Union of the missing date indices of the source and target datasets."""
        return set(self.source.missing) | set(self.target.missing)

    def tree(self):
        """Generates a hierarchical tree structure for the `Complement` instance and
        its associated datasets.

        Returns:
            Node: A `Node` object representing the `Complement` instance as the root
            node, with the target and the source datasets as child nodes.
        """
        return Node(self, [d.tree() for d in (self.target, self.source)])

    def __getitem__(self, index):
        # Normalise simple indices to a full 4-element tuple before delegating.
        if isinstance(index, (int, slice)):
            index = (index, slice(None), slice(None), slice(None))
        return self._get_tuple(index)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class ComplementNone(Complement):
    """Complement without interpolation: source values are returned as they are.

    NOTE(review): this presumably requires the source and target grids to be
    identical; no check is performed here.
    """

    def __init__(self, target, source):
        super().__init__(target, source)

    def _get_tuple(self, index):
        variable_index = 1
        index, changes = index_to_slices(index, self.shape)
        # The variable selection in `index` refers to the complement's own
        # variables (a subset of the source variables), so it must be mapped to
        # source positions by name rather than applied to the source directly.
        index, previous = update_tuple(index, variable_index, slice(None))
        source_index = [self.source.name_to_index[x] for x in self.variables[previous]]
        result = self.source[index[0], source_index, index[2], index[3]]
        return apply_index_to_slices_changes(result, changes)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class ComplementNearest(Complement):
    """Complement that maps source values onto the target grid by nearest neighbour."""

    def __init__(self, target, source):
        super().__init__(target, source)

        # For each target grid point, the index of the closest source grid point.
        self._nearest_grid_points = nearest_grid_points(
            self.source.latitudes,
            self.source.longitudes,
            self.target.latitudes,
            self.target.longitudes,
        )

    def check_compatibility(self, d1, d2):
        # Compatibility checks are disabled: source and target may be on different grids.
        pass

    def _get_tuple(self, index):
        variable_index = 1
        index, changes = index_to_slices(index, self.shape)
        # Variables are looked up by name in the source, so read every requested
        # complement variable through its source position.
        index, previous = update_tuple(index, variable_index, slice(None))
        source_index = [self.source.name_to_index[x] for x in self.variables[previous]]

        # Only gather the requested target points: selecting the nearest-point
        # indices first is equivalent to mapping the whole grid and slicing after.
        grid_points = self._nearest_grid_points[index[3]]

        source_data = self.source[index[0], source_index, index[2], ...]
        result = source_data[..., grid_points]

        return apply_index_to_slices_changes(result, changes)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def complement_factory(args, kwargs):
    """Build a dataset made of ``complement`` plus the missing variables of ``source``.

    Recognised keys (popped from ``kwargs``): ``source``, ``complement``,
    ``what`` (only ``"variables"`` is supported) and ``interpolation``
    (``None``/``"none"`` or ``"nearest"``). Remaining keys are used for
    ``_auto_adjust`` and passed to ``_subset``.

    Raises:
        NotImplementedError: for unsupported ``what`` or ``interpolation`` values.
    """
    from .select import Select

    assert len(args) == 0, args

    source = kwargs.pop("source")
    target = kwargs.pop("complement")
    what = kwargs.pop("what", "variables")
    interpolation = kwargs.pop("interpolation", "none")

    if what != "variables":
        raise NotImplementedError(f"Complement what={what} not implemented")

    # `None` is accepted as well, consistently with the dispatch table below.
    if interpolation is not None and interpolation not in ("none", "nearest"):
        raise NotImplementedError(f"Complement method={interpolation} not implemented")

    source = _open(source)
    target = _open(target)
    # `select` is the same as `variables`
    (source, target), kwargs = _auto_adjust([source, target], kwargs, exclude=["select"])

    Class = {
        None: ComplementNone,
        "none": ComplementNone,
        "nearest": ComplementNearest,
    }[interpolation]

    # NOTE(review): `kwargs` is applied both here and to the final dataset below;
    # confirm this is intended (e.g. for `select`).
    complement = Class(target=target, source=source)._subset(**kwargs)

    # Will join the datasets along the variables axis
    complemented = _open([target, complement])

    # Order the result like `source`. Variables that only exist in `target`
    # are appended afterwards, so that every joined column is listed (the
    # column reordering requires a complete permutation).
    available = complemented.name_to_index
    reorder = [v for v in source.variables if v in available]
    listed = set(reorder)
    reorder += [v for v in complemented.variables if v not in listed]

    ordered = (
        Select(complemented, complemented._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate()
    )
    return ordered
|
anemoi/datasets/data/dataset.py
CHANGED
|
@@ -15,6 +15,7 @@ import pprint
|
|
|
15
15
|
import warnings
|
|
16
16
|
from functools import cached_property
|
|
17
17
|
|
|
18
|
+
import numpy as np
|
|
18
19
|
from anemoi.utils.dates import frequency_to_seconds
|
|
19
20
|
from anemoi.utils.dates import frequency_to_string
|
|
20
21
|
from anemoi.utils.dates import frequency_to_timedelta
|
|
@@ -42,6 +43,9 @@ def _tidy(v):
|
|
|
42
43
|
if isinstance(v, slice):
|
|
43
44
|
return (v.start, v.stop, v.step)
|
|
44
45
|
|
|
46
|
+
if isinstance(v, np.integer):
|
|
47
|
+
return int(v)
|
|
48
|
+
|
|
45
49
|
return v
|
|
46
50
|
|
|
47
51
|
|
|
@@ -164,6 +168,16 @@ class Dataset:
|
|
|
164
168
|
bbox = kwargs.pop("area")
|
|
165
169
|
return Cropping(self, bbox)._subset(**kwargs).mutate()
|
|
166
170
|
|
|
171
|
+
if "number" in kwargs or "numbers" or "member" in kwargs or "members" in kwargs:
|
|
172
|
+
from .ensemble import Number
|
|
173
|
+
|
|
174
|
+
members = {}
|
|
175
|
+
for key in ["number", "numbers", "member", "members"]:
|
|
176
|
+
if key in kwargs:
|
|
177
|
+
members[key] = kwargs.pop(key)
|
|
178
|
+
|
|
179
|
+
return Number(self, **members)._subset(**kwargs).mutate()
|
|
180
|
+
|
|
167
181
|
if "set_missing_dates" in kwargs:
|
|
168
182
|
from .missing import MissingDates
|
|
169
183
|
|
|
@@ -241,18 +255,25 @@ class Dataset:
|
|
|
241
255
|
if not isinstance(vars, (list, tuple, set)):
|
|
242
256
|
vars = [vars]
|
|
243
257
|
|
|
244
|
-
|
|
258
|
+
if not set(vars) <= set(self.name_to_index):
|
|
259
|
+
raise ValueError(f"drop: unknown variables: {set(vars) - set(self.name_to_index)}")
|
|
245
260
|
|
|
246
261
|
return sorted([v for k, v in self.name_to_index.items() if k not in vars])
|
|
247
262
|
|
|
248
263
|
def _reorder_to_columns(self, vars):
|
|
264
|
+
if isinstance(vars, str) and vars == "sort":
|
|
265
|
+
# Sorting the variables alphabetically.
|
|
266
|
+
# This is cruical for pre-training then transfer learning in combination with
|
|
267
|
+
# cutout and adjust = 'all'
|
|
268
|
+
|
|
269
|
+
indices = [self.name_to_index[k] for k, v in sorted(self.name_to_index.items(), key=lambda x: x[0])]
|
|
270
|
+
assert set(indices) == set(range(len(self.name_to_index)))
|
|
271
|
+
return indices
|
|
272
|
+
|
|
249
273
|
if isinstance(vars, (list, tuple)):
|
|
250
274
|
vars = {k: i for i, k in enumerate(vars)}
|
|
251
275
|
|
|
252
|
-
indices = []
|
|
253
|
-
|
|
254
|
-
for k, v in sorted(vars.items(), key=lambda x: x[1]):
|
|
255
|
-
indices.append(self.name_to_index[k])
|
|
276
|
+
indices = [self.name_to_index[k] for k, v in sorted(vars.items(), key=lambda x: x[1])]
|
|
256
277
|
|
|
257
278
|
# Make sure we don't forget any variables
|
|
258
279
|
assert set(indices) == set(range(len(self.name_to_index)))
|
|
@@ -464,7 +485,7 @@ class Dataset:
|
|
|
464
485
|
sample_count = min(4, len(indices))
|
|
465
486
|
count = len(indices)
|
|
466
487
|
|
|
467
|
-
p = slice(0, count, count // (sample_count - 1))
|
|
488
|
+
p = slice(0, count, count // max(1, sample_count - 1))
|
|
468
489
|
samples = list(range(*p.indices(count)))
|
|
469
490
|
|
|
470
491
|
samples.append(count - 1) # Add last
|
|
@@ -497,3 +518,50 @@ class Dataset:
|
|
|
497
518
|
result.append(v)
|
|
498
519
|
|
|
499
520
|
return result
|
|
521
|
+
|
|
522
|
+
def plot(self, date, variable, member=0, **kwargs):
    """For debugging purposes, plot a field.

    Parameters
    ----------
    date : int or numpy.integer or datetime.datetime or numpy.datetime64 or str
        The date to plot, either as an index into the dataset or as a date.
    variable : int or numpy.integer or str
        The variable to plot, either as an index or as a name.
    member : int, optional
        The ensemble member to plot.

    **kwargs:
        Additional arguments passed to ``anemoi.utils.devtools.plot_values``.

    Returns
    -------
    The value returned by ``anemoi.utils.devtools.plot_values``.

    Raises
    ------
    ValueError
        If the date or the variable is not found in the dataset.
    """

    from anemoi.utils.devtools import plot_values
    from earthkit.data.utils.dates import to_datetime

    # Integer-like values (including NumPy integers, e.g. from `np.where`) are indices.
    if isinstance(date, (int, np.integer)):
        date_index = date
    else:
        date = np.datetime64(to_datetime(date)).astype(self.dates[0].dtype)
        index = np.where(self.dates == date)[0]
        if len(index) == 0:
            raise ValueError(
                f"Date {date} not found in the dataset {self.dates[0]} to {self.dates[-1]} by {self.frequency}"
            )
        date_index = index[0]

    if isinstance(variable, (int, np.integer)):
        variable_index = variable
    else:
        if variable not in self.variables:
            raise ValueError(f"Unknown variable {variable} (available: {self.variables})")

        variable_index = self.name_to_index[variable]

    values = self[date_index, variable_index, member]

    return plot_values(values, self.latitudes, self.longitudes, **kwargs)
|
anemoi/datasets/data/ensemble.py
CHANGED
|
@@ -10,13 +10,68 @@
|
|
|
10
10
|
|
|
11
11
|
import logging
|
|
12
12
|
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
13
15
|
from .debug import Node
|
|
16
|
+
from .forwards import Forwards
|
|
14
17
|
from .forwards import GivenAxis
|
|
18
|
+
from .indexing import apply_index_to_slices_changes
|
|
19
|
+
from .indexing import index_to_slices
|
|
20
|
+
from .indexing import update_tuple
|
|
15
21
|
from .misc import _auto_adjust
|
|
16
22
|
from .misc import _open
|
|
17
23
|
|
|
18
24
|
LOG = logging.getLogger(__name__)
|
|
19
25
|
|
|
26
|
+
OFFSETS = dict(number=1, numbers=1, member=0, members=0)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Number(Forwards):
    """Select a subset of the ensemble members (axis 2) of the forwarded dataset.

    Keyword arguments are ``number``/``numbers`` (one-based) or
    ``member``/``members`` (zero-based), each given as a single value or a
    list/tuple of values; the offsets come from ``OFFSETS``.

    Raises:
        ValueError: if a selected member is outside the ensemble range.
    """

    def __init__(self, forward, **kwargs):
        super().__init__(forward)

        n_members = forward.shape[2]

        members = []
        for key, values in kwargs.items():
            if not isinstance(values, (list, tuple)):
                values = [values]
            # Convert to zero-based member indices
            members.extend(int(v) - OFFSETS[key] for v in values)

        self.members = sorted(set(members))
        for n in self.members:
            if not (0 <= n < n_members):
                raise ValueError(f"Member {n} is out of range. `number(s)` is one-based, `member(s)` is zero-based.")

        # Boolean mask over the forwarded dataset's ensemble axis
        selected = set(self.members)
        self.mask = np.array([n in selected for n in range(n_members)], dtype=bool)
        self._shape, _ = update_tuple(forward.shape, 2, len(self.members))

    @property
    def shape(self):
        return self._shape

    def __getitem__(self, index):
        if isinstance(index, (int, np.integer)):
            # A single date drops the date axis, so members are on axis 1
            result = self.forward[index]
            result = result[:, self.mask, :]
            return result

        if isinstance(index, slice):
            result = self.forward[index]
            result = result[:, :, self.mask, :]
            return result

        index, changes = index_to_slices(index, self.shape)
        # The member selection in `index` is expressed in this dataset's reduced
        # member space: read all members, apply the mask, then apply the selection.
        index, previous = update_tuple(index, 2, slice(None))
        result = self.forward[index]
        result = result[:, :, self.mask, :]
        result = result[:, :, previous]
        return apply_index_to_slices_changes(result, changes)

    def tree(self):
        return Node(self, [self.forward.tree()], numbers=[n + 1 for n in self.members])

    def metadata_specific(self):
        # Reported as one-based member numbers
        return {
            "numbers": [n + 1 for n in self.members],
        }
|
|
74
|
+
|
|
20
75
|
|
|
21
76
|
class Ensemble(GivenAxis):
|
|
22
77
|
def tree(self):
|
anemoi/datasets/data/grids.py
CHANGED
|
@@ -289,14 +289,15 @@ class Cutout(GridsBase):
|
|
|
289
289
|
"""
|
|
290
290
|
index, changes = index_to_slices(index, self.shape)
|
|
291
291
|
# Select data from each LAM
|
|
292
|
-
lam_data = [lam[index] for lam in self.lams]
|
|
292
|
+
lam_data = [lam[index[:3]] for lam in self.lams]
|
|
293
293
|
|
|
294
294
|
# First apply spatial indexing on `self.globe` and then apply the mask
|
|
295
295
|
globe_data_sliced = self.globe[index[:3]]
|
|
296
296
|
globe_data = globe_data_sliced[..., self.global_mask]
|
|
297
297
|
|
|
298
|
-
# Concatenate LAM data with global data
|
|
299
|
-
result = np.concatenate(lam_data + [globe_data], axis=self.axis)
|
|
298
|
+
# Concatenate LAM data with global data, apply the grid slicing
|
|
299
|
+
result = np.concatenate(lam_data + [globe_data], axis=self.axis)[..., index[3]]
|
|
300
|
+
|
|
300
301
|
return apply_index_to_slices_changes(result, changes)
|
|
301
302
|
|
|
302
303
|
def collect_supporting_arrays(self, collected, *path):
|
|
@@ -324,7 +325,8 @@ class Cutout(GridsBase):
|
|
|
324
325
|
"""
|
|
325
326
|
shapes = [np.sum(mask) for mask in self.masks]
|
|
326
327
|
global_shape = np.sum(self.global_mask)
|
|
327
|
-
|
|
328
|
+
total_shape = sum(shapes) + global_shape
|
|
329
|
+
return tuple(self.lams[0].shape[:-1] + (int(total_shape),))
|
|
328
330
|
|
|
329
331
|
def check_same_resolution(self, d1, d2):
|
|
330
332
|
# Turned off because we are combining different resolutions
|
anemoi/datasets/data/join.py
CHANGED
|
@@ -118,13 +118,19 @@ class Join(Combined):
|
|
|
118
118
|
def variables_metadata(self):
    """Gather per-variable metadata from every joined dataset.

    Names wrapped in parentheses are ignored. If several datasets describe the
    same variable, the entry from the later dataset is kept. Variables without
    any metadata are reported with ``LOG.error``.
    """
    wanted = [name for name in self.variables if not (name.startswith("(") and name.endswith(")"))]

    collected = {}
    for dataset in self.datasets:
        metadata = dataset.variables_metadata
        collected.update({name: metadata[name] for name in wanted if name in metadata})

    if len(collected) != len(wanted):
        LOG.error("Some variables are missing metadata.")
        for name in wanted:
            if name not in collected:
                LOG.error("Missing metadata for %r.", name)

    return collected
|
|
129
135
|
|
|
130
136
|
@cached_property
|
anemoi/datasets/data/merge.py
CHANGED
|
@@ -134,6 +134,9 @@ class Merge(Combined):
|
|
|
134
134
|
def tree(self):
|
|
135
135
|
return Node(self, [d.tree() for d in self.datasets], allow_gaps_in_dates=self.allow_gaps_in_dates)
|
|
136
136
|
|
|
137
|
+
def metadata_specific(self):
    """Return the merge-specific metadata (whether gaps in dates are allowed)."""
    allow_gaps = self.allow_gaps_in_dates
    return dict(allow_gaps_in_dates=allow_gaps)
|
|
139
|
+
|
|
137
140
|
@debug_indexing
|
|
138
141
|
def __getitem__(self, n):
|
|
139
142
|
if isinstance(n, tuple):
|
anemoi/datasets/data/misc.py
CHANGED
|
@@ -194,7 +194,7 @@ def _open(a):
|
|
|
194
194
|
raise NotImplementedError(f"Unsupported argument: {type(a)}")
|
|
195
195
|
|
|
196
196
|
|
|
197
|
-
def _auto_adjust(datasets, kwargs):
|
|
197
|
+
def _auto_adjust(datasets, kwargs, exclude=None):
|
|
198
198
|
|
|
199
199
|
if "adjust" not in kwargs:
|
|
200
200
|
return datasets, kwargs
|
|
@@ -214,6 +214,9 @@ def _auto_adjust(datasets, kwargs):
|
|
|
214
214
|
for a in adjust_list:
|
|
215
215
|
adjust_set.update(ALIASES.get(a, [a]))
|
|
216
216
|
|
|
217
|
+
if exclude is not None:
|
|
218
|
+
adjust_set -= set(exclude)
|
|
219
|
+
|
|
217
220
|
extra = set(adjust_set) - set(ALIASES["all"])
|
|
218
221
|
if extra:
|
|
219
222
|
raise ValueError(f"Invalid adjust keys: {extra}")
|
|
@@ -335,6 +338,12 @@ def _open_dataset(*args, **kwargs):
|
|
|
335
338
|
assert not sets, sets
|
|
336
339
|
return cutout_factory(args, kwargs).mutate()
|
|
337
340
|
|
|
341
|
+
if "complement" in kwargs:
|
|
342
|
+
from .complement import complement_factory
|
|
343
|
+
|
|
344
|
+
assert not sets, sets
|
|
345
|
+
return complement_factory(args, kwargs).mutate()
|
|
346
|
+
|
|
338
347
|
for name in ("datasets", "dataset"):
|
|
339
348
|
if name in kwargs:
|
|
340
349
|
datasets = kwargs.pop(name)
|
anemoi/datasets/grids.py
CHANGED
|
@@ -152,7 +152,7 @@ def cutout_mask(
|
|
|
152
152
|
plot=None,
|
|
153
153
|
):
|
|
154
154
|
"""Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]"""
|
|
155
|
-
from scipy.spatial import
|
|
155
|
+
from scipy.spatial import cKDTree
|
|
156
156
|
|
|
157
157
|
# TODO: transform min_distance from lat/lon to xyz
|
|
158
158
|
|
|
@@ -195,13 +195,13 @@ def cutout_mask(
|
|
|
195
195
|
min_distance = min_distance_km / 6371.0
|
|
196
196
|
else:
|
|
197
197
|
points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km]
|
|
198
|
-
distances, _ =
|
|
198
|
+
distances, _ = cKDTree(points).query(points, k=2)
|
|
199
199
|
min_distance = np.min(distances[:, 1])
|
|
200
200
|
|
|
201
201
|
LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km")
|
|
202
202
|
|
|
203
|
-
# Use a
|
|
204
|
-
distances, indices =
|
|
203
|
+
# Use a cKDTree to find the nearest points
|
|
204
|
+
distances, indices = cKDTree(lam_points).query(global_points, k=neighbours)
|
|
205
205
|
|
|
206
206
|
# Centre of the Earth
|
|
207
207
|
zero = np.array([0.0, 0.0, 0.0])
|
|
@@ -255,7 +255,7 @@ def thinning_mask(
|
|
|
255
255
|
cropping_distance=2.0,
|
|
256
256
|
):
|
|
257
257
|
"""Return the list of points in [lats, lons] closest to [global_lats, global_lons]"""
|
|
258
|
-
from scipy.spatial import
|
|
258
|
+
from scipy.spatial import cKDTree
|
|
259
259
|
|
|
260
260
|
assert global_lats.ndim == 1
|
|
261
261
|
assert global_lons.ndim == 1
|
|
@@ -291,20 +291,20 @@ def thinning_mask(
|
|
|
291
291
|
xyx = latlon_to_xyz(lats, lons)
|
|
292
292
|
points = np.array(xyx).transpose()
|
|
293
293
|
|
|
294
|
-
# Use a
|
|
295
|
-
_, indices =
|
|
294
|
+
# Use a cKDTree to find the nearest points
|
|
295
|
+
_, indices = cKDTree(points).query(global_points, k=1)
|
|
296
296
|
|
|
297
297
|
return np.array([i for i in indices])
|
|
298
298
|
|
|
299
299
|
|
|
300
300
|
def outline(lats, lons, neighbours=5):
|
|
301
|
-
from scipy.spatial import
|
|
301
|
+
from scipy.spatial import cKDTree
|
|
302
302
|
|
|
303
303
|
xyx = latlon_to_xyz(lats, lons)
|
|
304
304
|
grid_points = np.array(xyx).transpose()
|
|
305
305
|
|
|
306
|
-
# Use a
|
|
307
|
-
_, indices =
|
|
306
|
+
# Use a cKDTree to find the nearest points
|
|
307
|
+
_, indices = cKDTree(grid_points).query(grid_points, k=neighbours)
|
|
308
308
|
|
|
309
309
|
# Centre of the Earth
|
|
310
310
|
zero = np.array([0.0, 0.0, 0.0])
|
|
@@ -379,6 +379,19 @@ def serialise_mask(mask):
|
|
|
379
379
|
return result
|
|
380
380
|
|
|
381
381
|
|
|
382
|
+
def nearest_grid_points(source_latitudes, source_longitudes, target_latitudes, target_longitudes):
    """For each target point, return the index of the closest source point.

    Points are converted with ``latlon_to_xyz`` and matched with a k-d tree
    (``k=1``), so the result has one source index per target point.
    """
    from scipy.spatial import cKDTree

    def _as_points(lats, lons):
        # One row per point, one column per Cartesian coordinate
        return np.array(latlon_to_xyz(lats, lons)).transpose()

    source_points = _as_points(source_latitudes, source_longitudes)
    target_points = _as_points(target_latitudes, target_longitudes)

    tree = cKDTree(source_points)
    _, nearest = tree.query(target_points, k=1)
    return nearest
|
|
393
|
+
|
|
394
|
+
|
|
382
395
|
if __name__ == "__main__":
|
|
383
396
|
global_lats, global_lons = np.meshgrid(
|
|
384
397
|
np.linspace(90, -90, 90),
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: anemoi-datasets
|
|
3
|
-
Version: 0.5.
|
|
3
|
+
Version: 0.5.13
|
|
4
4
|
Summary: A package to hold various functions to support training of ML models on ECMWF data.
|
|
5
5
|
Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
|
|
6
|
-
License:
|
|
6
|
+
License: Apache License
|
|
7
7
|
Version 2.0, January 2004
|
|
8
8
|
http://www.apache.org/licenses/
|
|
9
9
|
|
|
@@ -224,40 +224,39 @@ Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
|
224
224
|
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
225
225
|
Requires-Python: >=3.9
|
|
226
226
|
License-File: LICENSE
|
|
227
|
-
Requires-Dist: anemoi-transform
|
|
228
|
-
Requires-Dist: anemoi-utils[provenance]
|
|
227
|
+
Requires-Dist: anemoi-transform>=0.1
|
|
228
|
+
Requires-Dist: anemoi-utils[provenance]>=0.4.9
|
|
229
229
|
Requires-Dist: cfunits
|
|
230
230
|
Requires-Dist: numpy
|
|
231
231
|
Requires-Dist: pyyaml
|
|
232
232
|
Requires-Dist: semantic-version
|
|
233
233
|
Requires-Dist: tqdm
|
|
234
|
-
Requires-Dist: zarr
|
|
234
|
+
Requires-Dist: zarr<=2.17
|
|
235
235
|
Provides-Extra: all
|
|
236
|
-
Requires-Dist: anemoi-datasets[create,remote,xarray]
|
|
236
|
+
Requires-Dist: anemoi-datasets[create,remote,xarray]; extra == "all"
|
|
237
237
|
Provides-Extra: create
|
|
238
|
-
Requires-Dist: earthkit-data[mars]
|
|
239
|
-
Requires-Dist: earthkit-geo
|
|
240
|
-
Requires-Dist: earthkit-meteo
|
|
241
|
-
Requires-Dist: eccodes
|
|
242
|
-
Requires-Dist: entrypoints
|
|
243
|
-
Requires-Dist: pyproj
|
|
238
|
+
Requires-Dist: earthkit-data[mars]>=0.10.7; extra == "create"
|
|
239
|
+
Requires-Dist: earthkit-geo>=0.2; extra == "create"
|
|
240
|
+
Requires-Dist: earthkit-meteo; extra == "create"
|
|
241
|
+
Requires-Dist: eccodes>=2.38.1; extra == "create"
|
|
242
|
+
Requires-Dist: entrypoints; extra == "create"
|
|
243
|
+
Requires-Dist: pyproj; extra == "create"
|
|
244
244
|
Provides-Extra: dev
|
|
245
|
-
Requires-Dist: anemoi-datasets[all,docs,tests]
|
|
245
|
+
Requires-Dist: anemoi-datasets[all,docs,tests]; extra == "dev"
|
|
246
246
|
Provides-Extra: docs
|
|
247
|
-
Requires-Dist: nbsphinx
|
|
248
|
-
Requires-Dist: pandoc
|
|
249
|
-
Requires-Dist: sphinx
|
|
250
|
-
Requires-Dist: sphinx-argparse
|
|
251
|
-
Requires-Dist: sphinx-rtd-theme
|
|
247
|
+
Requires-Dist: nbsphinx; extra == "docs"
|
|
248
|
+
Requires-Dist: pandoc; extra == "docs"
|
|
249
|
+
Requires-Dist: sphinx; extra == "docs"
|
|
250
|
+
Requires-Dist: sphinx-argparse; extra == "docs"
|
|
251
|
+
Requires-Dist: sphinx-rtd-theme; extra == "docs"
|
|
252
252
|
Provides-Extra: remote
|
|
253
|
-
Requires-Dist: boto3
|
|
254
|
-
Requires-Dist: requests
|
|
253
|
+
Requires-Dist: boto3; extra == "remote"
|
|
254
|
+
Requires-Dist: requests; extra == "remote"
|
|
255
255
|
Provides-Extra: tests
|
|
256
|
-
Requires-Dist: pytest
|
|
256
|
+
Requires-Dist: pytest; extra == "tests"
|
|
257
257
|
Provides-Extra: xarray
|
|
258
|
-
Requires-Dist: gcsfs
|
|
259
|
-
Requires-Dist: kerchunk
|
|
260
|
-
Requires-Dist: pandas
|
|
261
|
-
Requires-Dist: planetary-computer
|
|
262
|
-
Requires-Dist: pystac-client
|
|
263
|
-
|
|
258
|
+
Requires-Dist: gcsfs; extra == "xarray"
|
|
259
|
+
Requires-Dist: kerchunk; extra == "xarray"
|
|
260
|
+
Requires-Dist: pandas; extra == "xarray"
|
|
261
|
+
Requires-Dist: planetary-computer; extra == "xarray"
|
|
262
|
+
Requires-Dist: pystac-client; extra == "xarray"
|