anemoi-datasets 0.5.12__py3-none-any.whl → 0.5.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/create.py +4 -2
- anemoi/datasets/create/__init__.py +22 -6
- anemoi/datasets/create/check.py +1 -1
- anemoi/datasets/create/functions/__init__.py +15 -1
- anemoi/datasets/create/functions/filters/orog_to_z.py +58 -0
- anemoi/datasets/create/functions/filters/sum.py +71 -0
- anemoi/datasets/create/functions/filters/wz_to_w.py +79 -0
- anemoi/datasets/create/functions/sources/accumulations.py +7 -2
- anemoi/datasets/create/functions/sources/eccc_fstd.py +16 -0
- anemoi/datasets/create/functions/sources/mars.py +5 -1
- anemoi/datasets/create/functions/sources/xarray/__init__.py +3 -3
- anemoi/datasets/create/functions/sources/xarray/field.py +5 -1
- anemoi/datasets/create/functions/sources/xarray/fieldlist.py +10 -1
- anemoi/datasets/create/functions/sources/xarray/metadata.py +5 -11
- anemoi/datasets/create/functions/sources/xarray/patch.py +44 -0
- anemoi/datasets/create/functions/sources/xarray/time.py +15 -0
- anemoi/datasets/create/functions/sources/xarray/variable.py +18 -2
- anemoi/datasets/create/input/repeated_dates.py +18 -0
- anemoi/datasets/create/input/result.py +1 -1
- anemoi/datasets/create/statistics/__init__.py +7 -4
- anemoi/datasets/create/utils.py +4 -0
- anemoi/datasets/data/complement.py +164 -0
- anemoi/datasets/data/dataset.py +68 -5
- anemoi/datasets/data/ensemble.py +55 -0
- anemoi/datasets/data/join.py +1 -2
- anemoi/datasets/data/merge.py +3 -0
- anemoi/datasets/data/misc.py +34 -1
- anemoi/datasets/grids.py +29 -10
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/METADATA +2 -2
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/RECORD +35 -29
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/WHEEL +1 -1
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/top_level.txt +0 -0
|
@@ -37,7 +37,7 @@ class Variable:
|
|
|
37
37
|
self.coordinates = coordinates
|
|
38
38
|
|
|
39
39
|
self._metadata = metadata.copy()
|
|
40
|
-
self._metadata.update({"variable": variable.name})
|
|
40
|
+
self._metadata.update({"variable": variable.name, "param": variable.name})
|
|
41
41
|
|
|
42
42
|
self.time = time
|
|
43
43
|
|
|
@@ -45,6 +45,9 @@ class Variable:
|
|
|
45
45
|
self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid}
|
|
46
46
|
self.by_name = {c.variable.name: c for c in coordinates}
|
|
47
47
|
|
|
48
|
+
# We need that alias for the time dimension
|
|
49
|
+
self._aliases = dict(valid_datetime="time")
|
|
50
|
+
|
|
48
51
|
self.length = math.prod(self.shape)
|
|
49
52
|
|
|
50
53
|
@property
|
|
@@ -96,15 +99,28 @@ class Variable:
|
|
|
96
99
|
|
|
97
100
|
k, v = kwargs.popitem()
|
|
98
101
|
|
|
102
|
+
user_provided_k = k
|
|
103
|
+
|
|
104
|
+
if k == "valid_datetime":
|
|
105
|
+
# Ask the Time object to select the valid datetime
|
|
106
|
+
k = self.time.select_valid_datetime(self)
|
|
107
|
+
if k is None:
|
|
108
|
+
return None
|
|
109
|
+
|
|
99
110
|
c = self.by_name.get(k)
|
|
100
111
|
|
|
112
|
+
# assert c is not None, f"Could not find coordinate {k} in {self.variable.name} {self.coordinates} {list(self.by_name)}"
|
|
113
|
+
|
|
101
114
|
if c is None:
|
|
102
115
|
missing[k] = v
|
|
103
116
|
return self.sel(missing, **kwargs)
|
|
104
117
|
|
|
105
118
|
i = c.index(v)
|
|
106
119
|
if i is None:
|
|
107
|
-
|
|
120
|
+
if k != user_provided_k:
|
|
121
|
+
LOG.warning(f"Could not find {user_provided_k}={v} in {c} (alias of {k})")
|
|
122
|
+
else:
|
|
123
|
+
LOG.warning(f"Could not find {k}={v} in {c}")
|
|
108
124
|
return None
|
|
109
125
|
|
|
110
126
|
coordinates = [x.reduced(i) if c is x else x for x in self.coordinates]
|
|
@@ -72,6 +72,11 @@ class DateMapperClosest(DateMapper):
|
|
|
72
72
|
end += self.frequency
|
|
73
73
|
|
|
74
74
|
to_try = sorted(to_try - self.tried)
|
|
75
|
+
info = {k: "no-data" for k in to_try}
|
|
76
|
+
|
|
77
|
+
if not to_try:
|
|
78
|
+
LOG.warning(f"No new dates to try for {group_of_dates} in {self.source}")
|
|
79
|
+
# return []
|
|
75
80
|
|
|
76
81
|
if to_try:
|
|
77
82
|
result = self.source.select(
|
|
@@ -82,19 +87,32 @@ class DateMapperClosest(DateMapper):
|
|
|
82
87
|
)
|
|
83
88
|
)
|
|
84
89
|
|
|
90
|
+
cnt = 0
|
|
85
91
|
for f in result.datasource:
|
|
92
|
+
cnt += 1
|
|
86
93
|
# We could keep the fields in a dictionary, but we don't want to keep the fields in memory
|
|
87
94
|
date = as_datetime(f.metadata("valid_datetime"))
|
|
88
95
|
|
|
89
96
|
if self.skip_all_nans:
|
|
90
97
|
if np.isnan(f.to_numpy()).all():
|
|
91
98
|
LOG.warning(f"Skipping {date} because all values are NaN")
|
|
99
|
+
info[date] = "all-nans"
|
|
92
100
|
continue
|
|
93
101
|
|
|
102
|
+
info[date] = "ok"
|
|
94
103
|
self.found.add(date)
|
|
95
104
|
|
|
105
|
+
if cnt == 0:
|
|
106
|
+
raise ValueError(f"No data found for {group_of_dates} in {self.source}")
|
|
107
|
+
|
|
96
108
|
self.tried.update(to_try)
|
|
97
109
|
|
|
110
|
+
if not self.found:
|
|
111
|
+
for k, v in info.items():
|
|
112
|
+
LOG.warning(f"{k}: {v}")
|
|
113
|
+
|
|
114
|
+
raise ValueError(f"No matching data found for {asked_dates} in {self.source}")
|
|
115
|
+
|
|
98
116
|
new_dates = defaultdict(list)
|
|
99
117
|
|
|
100
118
|
for date in asked_dates:
|
|
@@ -459,7 +459,7 @@ class Result:
|
|
|
459
459
|
if self.group_of_dates is not None:
|
|
460
460
|
dates = f" {len(self.group_of_dates)} dates"
|
|
461
461
|
dates += " ("
|
|
462
|
-
dates += "/".join(d.strftime("%Y-%m-%
|
|
462
|
+
dates += "/".join(d.strftime("%Y-%m-%dT%H:%M") for d in self.group_of_dates)
|
|
463
463
|
if len(dates) > 100:
|
|
464
464
|
dates = dates[:100] + "..."
|
|
465
465
|
dates += ")"
|
|
@@ -18,6 +18,7 @@ import shutil
|
|
|
18
18
|
import socket
|
|
19
19
|
|
|
20
20
|
import numpy as np
|
|
21
|
+
import tqdm
|
|
21
22
|
from anemoi.utils.provenance import gather_provenance_info
|
|
22
23
|
|
|
23
24
|
from ..check import check_data_values
|
|
@@ -98,7 +99,7 @@ def fix_variance(x, name, count, sums, squares):
|
|
|
98
99
|
|
|
99
100
|
variances = squares / count - mean * mean
|
|
100
101
|
assert variances.shape == squares.shape == mean.shape
|
|
101
|
-
if all(variances >= 0):
|
|
102
|
+
if np.all(variances >= 0):
|
|
102
103
|
LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.")
|
|
103
104
|
return 0
|
|
104
105
|
|
|
@@ -108,7 +109,7 @@ def fix_variance(x, name, count, sums, squares):
|
|
|
108
109
|
# return 0
|
|
109
110
|
|
|
110
111
|
LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).")
|
|
111
|
-
return
|
|
112
|
+
return 0
|
|
112
113
|
|
|
113
114
|
|
|
114
115
|
def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
|
|
@@ -134,7 +135,7 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squa
|
|
|
134
135
|
|
|
135
136
|
def compute_statistics(array, check_variables_names=None, allow_nans=False):
|
|
136
137
|
"""Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary."""
|
|
137
|
-
|
|
138
|
+
LOG.info(f"Computing statistics for {array.shape} array")
|
|
138
139
|
nvars = array.shape[1]
|
|
139
140
|
|
|
140
141
|
LOG.debug(f"Stats {nvars}, {array.shape}, {check_variables_names}")
|
|
@@ -149,7 +150,7 @@ def compute_statistics(array, check_variables_names=None, allow_nans=False):
|
|
|
149
150
|
maximum = np.zeros(stats_shape, dtype=np.float64)
|
|
150
151
|
has_nans = np.zeros(stats_shape, dtype=np.bool_)
|
|
151
152
|
|
|
152
|
-
for i, chunk in enumerate(array):
|
|
153
|
+
for i, chunk in tqdm.tqdm(enumerate(array), delay=1, total=array.shape[0], desc="Computing statistics"):
|
|
153
154
|
values = chunk.reshape((nvars, -1))
|
|
154
155
|
|
|
155
156
|
for j, name in enumerate(check_variables_names):
|
|
@@ -166,6 +167,8 @@ def compute_statistics(array, check_variables_names=None, allow_nans=False):
|
|
|
166
167
|
count[i] = np.sum(~np.isnan(values), axis=1)
|
|
167
168
|
has_nans[i] = np.isnan(values).any()
|
|
168
169
|
|
|
170
|
+
LOG.info(f"Statistics computed for {nvars} variables.")
|
|
171
|
+
|
|
169
172
|
return {
|
|
170
173
|
"minimum": minimum,
|
|
171
174
|
"maximum": maximum,
|
anemoi/datasets/create/utils.py
CHANGED
|
@@ -54,6 +54,10 @@ def to_datetime(*args, **kwargs):
|
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
def make_list_int(value):
|
|
57
|
+
# Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers.
|
|
58
|
+
# Moved to anemoi.utils.humanize
|
|
59
|
+
# replace with from anemoi.utils.humanize import make_list_int
|
|
60
|
+
# when anemoi-utils is released and pyproject.toml is updated
|
|
57
61
|
if isinstance(value, str):
|
|
58
62
|
if "/" not in value:
|
|
59
63
|
return [value]
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
# (C) Copyright 2024 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from functools import cached_property
|
|
13
|
+
|
|
14
|
+
from ..grids import nearest_grid_points
|
|
15
|
+
from .debug import Node
|
|
16
|
+
from .forwards import Combined
|
|
17
|
+
from .indexing import apply_index_to_slices_changes
|
|
18
|
+
from .indexing import index_to_slices
|
|
19
|
+
from .indexing import update_tuple
|
|
20
|
+
from .misc import _auto_adjust
|
|
21
|
+
from .misc import _open
|
|
22
|
+
|
|
23
|
+
LOG = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Complement(Combined):
|
|
27
|
+
|
|
28
|
+
def __init__(self, target, source, what="variables", interpolation="nearest"):
|
|
29
|
+
super().__init__([target, source])
|
|
30
|
+
|
|
31
|
+
# We had the variables of dataset[1] to dataset[0]
|
|
32
|
+
# interpoated on the grid of dataset[0]
|
|
33
|
+
|
|
34
|
+
self.target = target
|
|
35
|
+
self.source = source
|
|
36
|
+
|
|
37
|
+
self._variables = []
|
|
38
|
+
|
|
39
|
+
# Keep the same order as the original dataset
|
|
40
|
+
for v in self.source.variables:
|
|
41
|
+
if v not in self.target.variables:
|
|
42
|
+
self._variables.append(v)
|
|
43
|
+
|
|
44
|
+
if not self._variables:
|
|
45
|
+
raise ValueError("Augment: no missing variables")
|
|
46
|
+
|
|
47
|
+
@property
|
|
48
|
+
def variables(self):
|
|
49
|
+
return self._variables
|
|
50
|
+
|
|
51
|
+
@property
|
|
52
|
+
def name_to_index(self):
|
|
53
|
+
return {v: i for i, v in enumerate(self.variables)}
|
|
54
|
+
|
|
55
|
+
@property
|
|
56
|
+
def shape(self):
|
|
57
|
+
shape = self.target.shape
|
|
58
|
+
return (shape[0], len(self._variables)) + shape[2:]
|
|
59
|
+
|
|
60
|
+
@property
|
|
61
|
+
def variables_metadata(self):
|
|
62
|
+
return {k: v for k, v in self.source.variables_metadata.items() if k in self._variables}
|
|
63
|
+
|
|
64
|
+
def check_same_variables(self, d1, d2):
|
|
65
|
+
pass
|
|
66
|
+
|
|
67
|
+
@cached_property
|
|
68
|
+
def missing(self):
|
|
69
|
+
missing = self.source.missing.copy()
|
|
70
|
+
missing = missing | self.target.missing
|
|
71
|
+
return set(missing)
|
|
72
|
+
|
|
73
|
+
def tree(self):
|
|
74
|
+
"""Generates a hierarchical tree structure for the `Cutout` instance and
|
|
75
|
+
its associated datasets.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
Node: A `Node` object representing the `Cutout` instance as the root
|
|
79
|
+
node, with each dataset in `self.datasets` represented as a child
|
|
80
|
+
node.
|
|
81
|
+
"""
|
|
82
|
+
return Node(self, [d.tree() for d in (self.target, self.source)])
|
|
83
|
+
|
|
84
|
+
def __getitem__(self, index):
|
|
85
|
+
if isinstance(index, (int, slice)):
|
|
86
|
+
index = (index, slice(None), slice(None), slice(None))
|
|
87
|
+
return self._get_tuple(index)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class ComplementNone(Complement):
|
|
91
|
+
|
|
92
|
+
def __init__(self, target, source):
|
|
93
|
+
super().__init__(target, source)
|
|
94
|
+
|
|
95
|
+
def _get_tuple(self, index):
|
|
96
|
+
index, changes = index_to_slices(index, self.shape)
|
|
97
|
+
result = self.source[index]
|
|
98
|
+
return apply_index_to_slices_changes(result, changes)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class ComplementNearest(Complement):
|
|
102
|
+
|
|
103
|
+
def __init__(self, target, source):
|
|
104
|
+
super().__init__(target, source)
|
|
105
|
+
|
|
106
|
+
self._nearest_grid_points = nearest_grid_points(
|
|
107
|
+
self.source.latitudes,
|
|
108
|
+
self.source.longitudes,
|
|
109
|
+
self.target.latitudes,
|
|
110
|
+
self.target.longitudes,
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
def check_compatibility(self, d1, d2):
|
|
114
|
+
pass
|
|
115
|
+
|
|
116
|
+
def _get_tuple(self, index):
|
|
117
|
+
variable_index = 1
|
|
118
|
+
index, changes = index_to_slices(index, self.shape)
|
|
119
|
+
index, previous = update_tuple(index, variable_index, slice(None))
|
|
120
|
+
source_index = [self.source.name_to_index[x] for x in self.variables[previous]]
|
|
121
|
+
source_data = self.source[index[0], source_index, index[2], ...]
|
|
122
|
+
target_data = source_data[..., self._nearest_grid_points]
|
|
123
|
+
|
|
124
|
+
result = target_data[..., index[3]]
|
|
125
|
+
|
|
126
|
+
return apply_index_to_slices_changes(result, changes)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def complement_factory(args, kwargs):
|
|
130
|
+
from .select import Select
|
|
131
|
+
|
|
132
|
+
assert len(args) == 0, args
|
|
133
|
+
|
|
134
|
+
source = kwargs.pop("source")
|
|
135
|
+
target = kwargs.pop("complement")
|
|
136
|
+
what = kwargs.pop("what", "variables")
|
|
137
|
+
interpolation = kwargs.pop("interpolation", "none")
|
|
138
|
+
|
|
139
|
+
if what != "variables":
|
|
140
|
+
raise NotImplementedError(f"Complement what={what} not implemented")
|
|
141
|
+
|
|
142
|
+
if interpolation not in ("none", "nearest"):
|
|
143
|
+
raise NotImplementedError(f"Complement method={interpolation} not implemented")
|
|
144
|
+
|
|
145
|
+
source = _open(source)
|
|
146
|
+
target = _open(target)
|
|
147
|
+
# `select` is the same as `variables`
|
|
148
|
+
(source, target), kwargs = _auto_adjust([source, target], kwargs, exclude=["select"])
|
|
149
|
+
|
|
150
|
+
Class = {
|
|
151
|
+
None: ComplementNone,
|
|
152
|
+
"none": ComplementNone,
|
|
153
|
+
"nearest": ComplementNearest,
|
|
154
|
+
}[interpolation]
|
|
155
|
+
|
|
156
|
+
complement = Class(target=target, source=source)._subset(**kwargs)
|
|
157
|
+
|
|
158
|
+
# Will join the datasets along the variables axis
|
|
159
|
+
reorder = source.variables
|
|
160
|
+
complemented = _open([target, complement])
|
|
161
|
+
ordered = (
|
|
162
|
+
Select(complemented, complemented._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate()
|
|
163
|
+
)
|
|
164
|
+
return ordered
|
anemoi/datasets/data/dataset.py
CHANGED
|
@@ -168,6 +168,16 @@ class Dataset:
|
|
|
168
168
|
bbox = kwargs.pop("area")
|
|
169
169
|
return Cropping(self, bbox)._subset(**kwargs).mutate()
|
|
170
170
|
|
|
171
|
+
if "number" in kwargs or "numbers" or "member" in kwargs or "members" in kwargs:
|
|
172
|
+
from .ensemble import Number
|
|
173
|
+
|
|
174
|
+
members = {}
|
|
175
|
+
for key in ["number", "numbers", "member", "members"]:
|
|
176
|
+
if key in kwargs:
|
|
177
|
+
members[key] = kwargs.pop(key)
|
|
178
|
+
|
|
179
|
+
return Number(self, **members)._subset(**kwargs).mutate()
|
|
180
|
+
|
|
171
181
|
if "set_missing_dates" in kwargs:
|
|
172
182
|
from .missing import MissingDates
|
|
173
183
|
|
|
@@ -251,13 +261,19 @@ class Dataset:
|
|
|
251
261
|
return sorted([v for k, v in self.name_to_index.items() if k not in vars])
|
|
252
262
|
|
|
253
263
|
def _reorder_to_columns(self, vars):
|
|
264
|
+
if isinstance(vars, str) and vars == "sort":
|
|
265
|
+
# Sorting the variables alphabetically.
|
|
266
|
+
# This is cruical for pre-training then transfer learning in combination with
|
|
267
|
+
# cutout and adjust = 'all'
|
|
268
|
+
|
|
269
|
+
indices = [self.name_to_index[k] for k, v in sorted(self.name_to_index.items(), key=lambda x: x[0])]
|
|
270
|
+
assert set(indices) == set(range(len(self.name_to_index)))
|
|
271
|
+
return indices
|
|
272
|
+
|
|
254
273
|
if isinstance(vars, (list, tuple)):
|
|
255
274
|
vars = {k: i for i, k in enumerate(vars)}
|
|
256
275
|
|
|
257
|
-
indices = []
|
|
258
|
-
|
|
259
|
-
for k, v in sorted(vars.items(), key=lambda x: x[1]):
|
|
260
|
-
indices.append(self.name_to_index[k])
|
|
276
|
+
indices = [self.name_to_index[k] for k, v in sorted(vars.items(), key=lambda x: x[1])]
|
|
261
277
|
|
|
262
278
|
# Make sure we don't forget any variables
|
|
263
279
|
assert set(indices) == set(range(len(self.name_to_index)))
|
|
@@ -469,7 +485,7 @@ class Dataset:
|
|
|
469
485
|
sample_count = min(4, len(indices))
|
|
470
486
|
count = len(indices)
|
|
471
487
|
|
|
472
|
-
p = slice(0, count, count // (sample_count - 1))
|
|
488
|
+
p = slice(0, count, count // max(1, sample_count - 1))
|
|
473
489
|
samples = list(range(*p.indices(count)))
|
|
474
490
|
|
|
475
491
|
samples.append(count - 1) # Add last
|
|
@@ -502,3 +518,50 @@ class Dataset:
|
|
|
502
518
|
result.append(v)
|
|
503
519
|
|
|
504
520
|
return result
|
|
521
|
+
|
|
522
|
+
def plot(self, date, variable, member=0, **kwargs):
|
|
523
|
+
"""For debugging purposes, plot a field.
|
|
524
|
+
|
|
525
|
+
Parameters
|
|
526
|
+
----------
|
|
527
|
+
date : int or datetime.datetime or numpy.datetime64 or str
|
|
528
|
+
The date to plot.
|
|
529
|
+
variable : int or str
|
|
530
|
+
The variable to plot.
|
|
531
|
+
member : int, optional
|
|
532
|
+
The ensemble member to plot.
|
|
533
|
+
|
|
534
|
+
**kwargs:
|
|
535
|
+
Additional arguments to pass to matplotlib.pyplot.tricontourf
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
Returns
|
|
539
|
+
-------
|
|
540
|
+
matplotlib.pyplot.Axes
|
|
541
|
+
"""
|
|
542
|
+
|
|
543
|
+
from anemoi.utils.devtools import plot_values
|
|
544
|
+
from earthkit.data.utils.dates import to_datetime
|
|
545
|
+
|
|
546
|
+
if not isinstance(date, int):
|
|
547
|
+
date = np.datetime64(to_datetime(date)).astype(self.dates[0].dtype)
|
|
548
|
+
index = np.where(self.dates == date)[0]
|
|
549
|
+
if len(index) == 0:
|
|
550
|
+
raise ValueError(
|
|
551
|
+
f"Date {date} not found in the dataset {self.dates[0]} to {self.dates[-1]} by {self.frequency}"
|
|
552
|
+
)
|
|
553
|
+
date_index = index[0]
|
|
554
|
+
else:
|
|
555
|
+
date_index = date
|
|
556
|
+
|
|
557
|
+
if isinstance(variable, int):
|
|
558
|
+
variable_index = variable
|
|
559
|
+
else:
|
|
560
|
+
if variable not in self.variables:
|
|
561
|
+
raise ValueError(f"Unknown variable {variable} (available: {self.variables})")
|
|
562
|
+
|
|
563
|
+
variable_index = self.name_to_index[variable]
|
|
564
|
+
|
|
565
|
+
values = self[date_index, variable_index, member]
|
|
566
|
+
|
|
567
|
+
return plot_values(values, self.latitudes, self.longitudes, **kwargs)
|
anemoi/datasets/data/ensemble.py
CHANGED
|
@@ -10,13 +10,68 @@
|
|
|
10
10
|
|
|
11
11
|
import logging
|
|
12
12
|
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
13
15
|
from .debug import Node
|
|
16
|
+
from .forwards import Forwards
|
|
14
17
|
from .forwards import GivenAxis
|
|
18
|
+
from .indexing import apply_index_to_slices_changes
|
|
19
|
+
from .indexing import index_to_slices
|
|
20
|
+
from .indexing import update_tuple
|
|
15
21
|
from .misc import _auto_adjust
|
|
16
22
|
from .misc import _open
|
|
17
23
|
|
|
18
24
|
LOG = logging.getLogger(__name__)
|
|
19
25
|
|
|
26
|
+
OFFSETS = dict(number=1, numbers=1, member=0, members=0)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Number(Forwards):
|
|
30
|
+
def __init__(self, forward, **kwargs):
|
|
31
|
+
super().__init__(forward)
|
|
32
|
+
|
|
33
|
+
self.members = []
|
|
34
|
+
for key, values in kwargs.items():
|
|
35
|
+
if not isinstance(values, (list, tuple)):
|
|
36
|
+
values = [values]
|
|
37
|
+
self.members.extend([int(v) - OFFSETS[key] for v in values])
|
|
38
|
+
|
|
39
|
+
self.members = sorted(set(self.members))
|
|
40
|
+
for n in self.members:
|
|
41
|
+
if not (0 <= n < forward.shape[2]):
|
|
42
|
+
raise ValueError(f"Member {n} is out of range. `number(s)` is one-based, `member(s)` is zero-based.")
|
|
43
|
+
|
|
44
|
+
self.mask = np.array([n in self.members for n in range(forward.shape[2])], dtype=bool)
|
|
45
|
+
self._shape, _ = update_tuple(forward.shape, 2, len(self.members))
|
|
46
|
+
|
|
47
|
+
@property
|
|
48
|
+
def shape(self):
|
|
49
|
+
return self._shape
|
|
50
|
+
|
|
51
|
+
def __getitem__(self, index):
|
|
52
|
+
if isinstance(index, int):
|
|
53
|
+
result = self.forward[index]
|
|
54
|
+
result = result[:, self.mask, :]
|
|
55
|
+
return result
|
|
56
|
+
|
|
57
|
+
if isinstance(index, slice):
|
|
58
|
+
result = self.forward[index]
|
|
59
|
+
result = result[:, :, self.mask, :]
|
|
60
|
+
return result
|
|
61
|
+
|
|
62
|
+
index, changes = index_to_slices(index, self.shape)
|
|
63
|
+
result = self.forward[index]
|
|
64
|
+
result = result[:, :, self.mask, :]
|
|
65
|
+
return apply_index_to_slices_changes(result, changes)
|
|
66
|
+
|
|
67
|
+
def tree(self):
|
|
68
|
+
return Node(self, [self.forward.tree()], numbers=[n + 1 for n in self.members])
|
|
69
|
+
|
|
70
|
+
def metadata_specific(self):
|
|
71
|
+
return {
|
|
72
|
+
"numbers": [n + 1 for n in self.members],
|
|
73
|
+
}
|
|
74
|
+
|
|
20
75
|
|
|
21
76
|
class Ensemble(GivenAxis):
|
|
22
77
|
def tree(self):
|
anemoi/datasets/data/join.py
CHANGED
|
@@ -118,6 +118,7 @@ class Join(Combined):
|
|
|
118
118
|
def variables_metadata(self):
|
|
119
119
|
result = {}
|
|
120
120
|
variables = [v for v in self.variables if not (v.startswith("(") and v.endswith(")"))]
|
|
121
|
+
|
|
121
122
|
for d in self.datasets:
|
|
122
123
|
md = d.variables_metadata
|
|
123
124
|
for v in variables:
|
|
@@ -130,8 +131,6 @@ class Join(Combined):
|
|
|
130
131
|
if v not in result:
|
|
131
132
|
LOG.error("Missing metadata for %r.", v)
|
|
132
133
|
|
|
133
|
-
raise ValueError("Some variables are missing metadata.")
|
|
134
|
-
|
|
135
134
|
return result
|
|
136
135
|
|
|
137
136
|
@cached_property
|
anemoi/datasets/data/merge.py
CHANGED
|
@@ -134,6 +134,9 @@ class Merge(Combined):
|
|
|
134
134
|
def tree(self):
|
|
135
135
|
return Node(self, [d.tree() for d in self.datasets], allow_gaps_in_dates=self.allow_gaps_in_dates)
|
|
136
136
|
|
|
137
|
+
def metadata_specific(self):
|
|
138
|
+
return {"allow_gaps_in_dates": self.allow_gaps_in_dates}
|
|
139
|
+
|
|
137
140
|
@debug_indexing
|
|
138
141
|
def __getitem__(self, n):
|
|
139
142
|
if isinstance(n, tuple):
|
anemoi/datasets/data/misc.py
CHANGED
|
@@ -103,6 +103,30 @@ def _as_date(d, dates, last):
|
|
|
103
103
|
|
|
104
104
|
if isinstance(d, str):
|
|
105
105
|
|
|
106
|
+
def isfloat(s):
|
|
107
|
+
try:
|
|
108
|
+
float(s)
|
|
109
|
+
return True
|
|
110
|
+
except ValueError:
|
|
111
|
+
return False
|
|
112
|
+
|
|
113
|
+
if d.endswith("%") and isfloat(d[:-1]):
|
|
114
|
+
x = float(d[:-1])
|
|
115
|
+
if not 0 <= x <= 100:
|
|
116
|
+
raise ValueError(f"Invalid date: {d}")
|
|
117
|
+
i_float = x * len(dates) / 100
|
|
118
|
+
|
|
119
|
+
epsilon = 2 ** (-30)
|
|
120
|
+
if len(dates) > 1 / epsilon:
|
|
121
|
+
LOG.warning("Too many dates to use percentage, one date may be lost in rounding")
|
|
122
|
+
|
|
123
|
+
if last:
|
|
124
|
+
index = int(i_float + epsilon) - 1
|
|
125
|
+
else:
|
|
126
|
+
index = int(i_float - epsilon)
|
|
127
|
+
index = max(0, min(len(dates) - 1, index))
|
|
128
|
+
return dates[index]
|
|
129
|
+
|
|
106
130
|
if "-" in d and ":" in d:
|
|
107
131
|
date, time = d.replace(" ", "T").split("T")
|
|
108
132
|
year, month, day = [int(_) for _ in date.split("-")]
|
|
@@ -194,7 +218,7 @@ def _open(a):
|
|
|
194
218
|
raise NotImplementedError(f"Unsupported argument: {type(a)}")
|
|
195
219
|
|
|
196
220
|
|
|
197
|
-
def _auto_adjust(datasets, kwargs):
|
|
221
|
+
def _auto_adjust(datasets, kwargs, exclude=None):
|
|
198
222
|
|
|
199
223
|
if "adjust" not in kwargs:
|
|
200
224
|
return datasets, kwargs
|
|
@@ -214,6 +238,9 @@ def _auto_adjust(datasets, kwargs):
|
|
|
214
238
|
for a in adjust_list:
|
|
215
239
|
adjust_set.update(ALIASES.get(a, [a]))
|
|
216
240
|
|
|
241
|
+
if exclude is not None:
|
|
242
|
+
adjust_set -= set(exclude)
|
|
243
|
+
|
|
217
244
|
extra = set(adjust_set) - set(ALIASES["all"])
|
|
218
245
|
if extra:
|
|
219
246
|
raise ValueError(f"Invalid adjust keys: {extra}")
|
|
@@ -335,6 +362,12 @@ def _open_dataset(*args, **kwargs):
|
|
|
335
362
|
assert not sets, sets
|
|
336
363
|
return cutout_factory(args, kwargs).mutate()
|
|
337
364
|
|
|
365
|
+
if "complement" in kwargs:
|
|
366
|
+
from .complement import complement_factory
|
|
367
|
+
|
|
368
|
+
assert not sets, sets
|
|
369
|
+
return complement_factory(args, kwargs).mutate()
|
|
370
|
+
|
|
338
371
|
for name in ("datasets", "dataset"):
|
|
339
372
|
if name in kwargs:
|
|
340
373
|
datasets = kwargs.pop(name)
|
anemoi/datasets/grids.py
CHANGED
|
@@ -62,6 +62,8 @@ def plot_mask(path, mask, lats, lons, global_lats, global_lons):
|
|
|
62
62
|
plt.savefig(path + "-global-zoomed.png")
|
|
63
63
|
|
|
64
64
|
|
|
65
|
+
# TODO: Use the one from anemoi.utils.grids instead
|
|
66
|
+
# from anemoi.utils.grids import ...
|
|
65
67
|
def xyz_to_latlon(x, y, z):
|
|
66
68
|
return (
|
|
67
69
|
np.rad2deg(np.arcsin(np.minimum(1.0, np.maximum(-1.0, z)))),
|
|
@@ -69,6 +71,8 @@ def xyz_to_latlon(x, y, z):
|
|
|
69
71
|
)
|
|
70
72
|
|
|
71
73
|
|
|
74
|
+
# TODO: Use the one from anemoi.utils.grids instead
|
|
75
|
+
# from anemoi.utils.grids import ...
|
|
72
76
|
def latlon_to_xyz(lat, lon, radius=1.0):
|
|
73
77
|
# https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates
|
|
74
78
|
# We assume that the Earth is a sphere of radius 1 so N(phi) = 1
|
|
@@ -152,7 +156,7 @@ def cutout_mask(
|
|
|
152
156
|
plot=None,
|
|
153
157
|
):
|
|
154
158
|
"""Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]"""
|
|
155
|
-
from scipy.spatial import
|
|
159
|
+
from scipy.spatial import cKDTree
|
|
156
160
|
|
|
157
161
|
# TODO: transform min_distance from lat/lon to xyz
|
|
158
162
|
|
|
@@ -195,13 +199,13 @@ def cutout_mask(
|
|
|
195
199
|
min_distance = min_distance_km / 6371.0
|
|
196
200
|
else:
|
|
197
201
|
points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km]
|
|
198
|
-
distances, _ =
|
|
202
|
+
distances, _ = cKDTree(points).query(points, k=2)
|
|
199
203
|
min_distance = np.min(distances[:, 1])
|
|
200
204
|
|
|
201
205
|
LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km")
|
|
202
206
|
|
|
203
|
-
# Use a
|
|
204
|
-
distances, indices =
|
|
207
|
+
# Use a cKDTree to find the nearest points
|
|
208
|
+
distances, indices = cKDTree(lam_points).query(global_points, k=neighbours)
|
|
205
209
|
|
|
206
210
|
# Centre of the Earth
|
|
207
211
|
zero = np.array([0.0, 0.0, 0.0])
|
|
@@ -255,7 +259,7 @@ def thinning_mask(
|
|
|
255
259
|
cropping_distance=2.0,
|
|
256
260
|
):
|
|
257
261
|
"""Return the list of points in [lats, lons] closest to [global_lats, global_lons]"""
|
|
258
|
-
from scipy.spatial import
|
|
262
|
+
from scipy.spatial import cKDTree
|
|
259
263
|
|
|
260
264
|
assert global_lats.ndim == 1
|
|
261
265
|
assert global_lons.ndim == 1
|
|
@@ -291,20 +295,20 @@ def thinning_mask(
|
|
|
291
295
|
xyx = latlon_to_xyz(lats, lons)
|
|
292
296
|
points = np.array(xyx).transpose()
|
|
293
297
|
|
|
294
|
-
# Use a
|
|
295
|
-
_, indices =
|
|
298
|
+
# Use a cKDTree to find the nearest points
|
|
299
|
+
_, indices = cKDTree(points).query(global_points, k=1)
|
|
296
300
|
|
|
297
301
|
return np.array([i for i in indices])
|
|
298
302
|
|
|
299
303
|
|
|
300
304
|
def outline(lats, lons, neighbours=5):
|
|
301
|
-
from scipy.spatial import
|
|
305
|
+
from scipy.spatial import cKDTree
|
|
302
306
|
|
|
303
307
|
xyx = latlon_to_xyz(lats, lons)
|
|
304
308
|
grid_points = np.array(xyx).transpose()
|
|
305
309
|
|
|
306
|
-
# Use a
|
|
307
|
-
_, indices =
|
|
310
|
+
# Use a cKDTree to find the nearest points
|
|
311
|
+
_, indices = cKDTree(grid_points).query(grid_points, k=neighbours)
|
|
308
312
|
|
|
309
313
|
# Centre of the Earth
|
|
310
314
|
zero = np.array([0.0, 0.0, 0.0])
|
|
@@ -379,6 +383,21 @@ def serialise_mask(mask):
|
|
|
379
383
|
return result
|
|
380
384
|
|
|
381
385
|
|
|
386
|
+
def nearest_grid_points(source_latitudes, source_longitudes, target_latitudes, target_longitudes):
|
|
387
|
+
# TODO: Use the one from anemoi.utils.grids instead
|
|
388
|
+
# from anemoi.utils.grids import ...
|
|
389
|
+
from scipy.spatial import cKDTree
|
|
390
|
+
|
|
391
|
+
source_xyz = latlon_to_xyz(source_latitudes, source_longitudes)
|
|
392
|
+
source_points = np.array(source_xyz).transpose()
|
|
393
|
+
|
|
394
|
+
target_xyz = latlon_to_xyz(target_latitudes, target_longitudes)
|
|
395
|
+
target_points = np.array(target_xyz).transpose()
|
|
396
|
+
|
|
397
|
+
_, indices = cKDTree(source_points).query(target_points, k=1)
|
|
398
|
+
return indices
|
|
399
|
+
|
|
400
|
+
|
|
382
401
|
if __name__ == "__main__":
|
|
383
402
|
global_lats, global_lons = np.meshgrid(
|
|
384
403
|
np.linspace(90, -90, 90),
|