anemoi-datasets 0.4.4__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/cleanup.py +44 -0
- anemoi/datasets/commands/create.py +50 -20
- anemoi/datasets/commands/finalise-additions.py +45 -0
- anemoi/datasets/commands/finalise.py +39 -0
- anemoi/datasets/commands/init-additions.py +45 -0
- anemoi/datasets/commands/init.py +67 -0
- anemoi/datasets/commands/inspect.py +1 -1
- anemoi/datasets/commands/load-additions.py +47 -0
- anemoi/datasets/commands/load.py +47 -0
- anemoi/datasets/commands/patch.py +39 -0
- anemoi/datasets/create/__init__.py +961 -146
- anemoi/datasets/create/check.py +5 -3
- anemoi/datasets/create/config.py +53 -2
- anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
- anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
- anemoi/datasets/create/functions/sources/xarray/field.py +1 -1
- anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
- anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
- anemoi/datasets/create/functions/sources/xarray/metadata.py +27 -29
- anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
- anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
- anemoi/datasets/create/input.py +23 -22
- anemoi/datasets/create/statistics/__init__.py +39 -23
- anemoi/datasets/create/utils.py +3 -2
- anemoi/datasets/data/__init__.py +1 -0
- anemoi/datasets/data/concat.py +46 -2
- anemoi/datasets/data/dataset.py +109 -34
- anemoi/datasets/data/forwards.py +17 -8
- anemoi/datasets/data/grids.py +17 -3
- anemoi/datasets/data/interpolate.py +133 -0
- anemoi/datasets/data/misc.py +56 -66
- anemoi/datasets/data/missing.py +240 -0
- anemoi/datasets/data/select.py +7 -1
- anemoi/datasets/data/stores.py +3 -3
- anemoi/datasets/data/subset.py +47 -5
- anemoi/datasets/data/unchecked.py +20 -22
- anemoi/datasets/data/xy.py +125 -0
- anemoi/datasets/dates/__init__.py +13 -66
- anemoi/datasets/dates/groups.py +2 -2
- anemoi/datasets/grids.py +66 -48
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/METADATA +5 -5
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/RECORD +47 -37
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/loaders.py +0 -936
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/top_level.txt +0 -0
|
@@ -79,6 +79,37 @@ def to_datetimes(dates):
|
|
|
79
79
|
return [to_datetime(d) for d in dates]
|
|
80
80
|
|
|
81
81
|
|
|
82
|
+
def fix_variance(x, name, count, sums, squares):
    """Return a usable variance for one variable, clamping rounding noise to 0.

    ``x`` is the aggregated variance of variable ``name``.  ``count``, ``sums``
    and ``squares`` are same-shaped arrays of partial statistics from which
    individual variances are recomputed as ``squares / count - mean * mean``.

    Returns:
        ``x`` unchanged when it is non-negative; ``0`` when ``x`` is negative
        but none of the individual variances is negative (the negative value is
        then attributed to numerical error); otherwise ``x`` unchanged, after
        logging the smallest individual variance.
    """
    assert count.shape == sums.shape == squares.shape
    assert isinstance(x, float)

    mean = sums / count
    assert mean.shape == count.shape

    if x >= 0:
        return x

    LOG.warning(f"Negative variance for {name=}, variance={x}")
    magnitude = np.sqrt((squares / count + mean * mean) / 2)
    LOG.warning(f"square / count - mean * mean = {squares/count} - {mean*mean} = {squares/count - mean*mean}")
    LOG.warning(f"Variable span order of magnitude is {magnitude}.")
    LOG.warning(f"Count is {count}.")

    variances = squares / count - mean * mean
    assert variances.shape == squares.shape == mean.shape
    # np.all reduces over every axis; the builtin all() fails on 0-d input and
    # raises "truth value is ambiguous" on arrays with more than one dimension.
    if np.all(variances >= 0):
        LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.")
        return 0

    # A relative-tolerance clamp (|x| tiny compared with `magnitude`) is not
    # applied here; a negative individual variance is only reported.
    LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).")
    return x
|
|
111
|
+
|
|
112
|
+
|
|
82
113
|
def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
|
|
83
114
|
if (x >= 0).all():
|
|
84
115
|
return
|
|
@@ -292,39 +323,24 @@ class StatAggregator:
|
|
|
292
323
|
def aggregate(self):
|
|
293
324
|
minimum = np.nanmin(self.minimum, axis=0)
|
|
294
325
|
maximum = np.nanmax(self.maximum, axis=0)
|
|
326
|
+
|
|
295
327
|
sums = np.nansum(self.sums, axis=0)
|
|
296
328
|
squares = np.nansum(self.squares, axis=0)
|
|
297
329
|
count = np.nansum(self.count, axis=0)
|
|
298
330
|
has_nans = np.any(self.has_nans, axis=0)
|
|
299
|
-
|
|
331
|
+
assert sums.shape == count.shape == squares.shape == minimum.shape == maximum.shape
|
|
300
332
|
|
|
301
|
-
|
|
333
|
+
mean = sums / count
|
|
334
|
+
assert mean.shape == minimum.shape
|
|
302
335
|
|
|
303
336
|
x = squares / count - mean * mean
|
|
304
|
-
|
|
305
|
-
# def fix_variance(x, name, minimum, maximum, mean, count, sums, squares):
|
|
306
|
-
# assert x.shape == minimum.shape == maximum.shape == mean.shape == count.shape == sums.shape == squares.shape
|
|
307
|
-
# assert x.shape == (1,)
|
|
308
|
-
# x, minimum, maximum, mean, count, sums, squares = x[0], minimum[0], maximum[0], mean[0], count[0], sums[0], squares[0]
|
|
309
|
-
# if x >= 0:
|
|
310
|
-
# return x
|
|
311
|
-
#
|
|
312
|
-
# order = np.sqrt((squares / count + mean * mean)/2)
|
|
313
|
-
# range = maximum - minimum
|
|
314
|
-
# LOG.warning(f"Negative variance for {name=}, variance={x}")
|
|
315
|
-
# LOG.warning(f"square / count - mean * mean = {squares / count} - {mean * mean} = {squares / count - mean * mean}")
|
|
316
|
-
# LOG.warning(f"Variable order of magnitude is {order}.")
|
|
317
|
-
# LOG.warning(f"Range is {range} ({maximum=} - {minimum=}).")
|
|
318
|
-
# LOG.warning(f"Count is {count}.")
|
|
319
|
-
# if abs(x) < order * 1e-6 and abs(x) < range * 1e-6:
|
|
320
|
-
# LOG.warning(f"Variance is negative but very small, setting to 0.")
|
|
321
|
-
# return x*0
|
|
322
|
-
# return x
|
|
337
|
+
assert x.shape == minimum.shape
|
|
323
338
|
|
|
324
339
|
for i, name in enumerate(self.variables_names):
|
|
325
340
|
# remove negative variance due to numerical errors
|
|
326
|
-
|
|
327
|
-
|
|
341
|
+
x[i] = fix_variance(x[i], name, self.count[i : i + 1], self.sums[i : i + 1], self.squares[i : i + 1])
|
|
342
|
+
|
|
343
|
+
for i, name in enumerate(self.variables_names):
|
|
328
344
|
check_variance(
|
|
329
345
|
x[i : i + 1],
|
|
330
346
|
[name],
|
anemoi/datasets/create/utils.py
CHANGED
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
# nor does it submit to any jurisdiction.
|
|
8
8
|
#
|
|
9
9
|
|
|
10
|
+
import datetime
|
|
10
11
|
import os
|
|
11
12
|
from contextlib import contextmanager
|
|
12
13
|
|
|
@@ -61,10 +62,10 @@ def make_list_int(value):
|
|
|
61
62
|
|
|
62
63
|
|
|
63
64
|
def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"):
|
|
64
|
-
assert isinstance(frequency,
|
|
65
|
+
assert isinstance(frequency, datetime.timedelta), frequency
|
|
65
66
|
start = np.datetime64(start)
|
|
66
67
|
end = np.datetime64(end)
|
|
67
|
-
delta = np.timedelta64(frequency
|
|
68
|
+
delta = np.timedelta64(frequency)
|
|
68
69
|
|
|
69
70
|
res = []
|
|
70
71
|
while start <= end:
|
anemoi/datasets/data/__init__.py
CHANGED
anemoi/datasets/data/concat.py
CHANGED
|
@@ -9,6 +9,7 @@ import logging
|
|
|
9
9
|
from functools import cached_property
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
|
+
from anemoi.utils.dates import frequency_to_timedelta
|
|
12
13
|
|
|
13
14
|
from .debug import Node
|
|
14
15
|
from .debug import debug_indexing
|
|
@@ -102,20 +103,63 @@ class Concat(ConcatMixin, Combined):
|
|
|
102
103
|
def tree(self):
    """Return a `Node` for this dataset whose children are the trees of each dataset in `self.datasets`."""
    return Node(self, [d.tree() for d in self.datasets])
|
|
104
105
|
|
|
106
|
+
@classmethod
def check_dataset_compatibility(cls, datasets, fill_missing_gaps=False):
    """Validate the date ranges of `datasets` and return the list to concatenate.

    Args:
        datasets: datasets exposing `dates` (first and last entries are read)
            and `frequency`.
        fill_missing_gaps: when true, a gap between two consecutive datasets is
            filled with a `MissingDataset` instead of being an error.

    Returns:
        The input datasets in order, with a `MissingDataset` inserted after
        each dataset that is followed by a gap (only when `fill_missing_gaps`).

    Raises:
        ValueError: if two date ranges overlap, if consecutive datasets are not
            in increasing date order, or if there is a gap and
            `fill_missing_gaps` is false.
    """
    # Study the dates
    ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets]

    # Make sure the dates are disjoint
    for i, r in enumerate(ranges):
        for j in range(i + 1, len(ranges)):
            s = ranges[j]
            # Two closed intervals intersect iff each one starts no later than
            # the other ends; this also catches one range containing the other.
            if r[0] <= s[1] and s[0] <= r[1]:
                raise ValueError(f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})")

    # For now we should have the datasets in order with no gaps
    frequency = frequency_to_timedelta(datasets[0].frequency)
    result = []

    for i in range(len(ranges) - 1):
        result.append(datasets[i])
        r = ranges[i]
        s = ranges[i + 1]
        expected_start = r[1] + frequency

        if s[0] == expected_start:
            continue

        # Only a real gap (next dataset starts later than expected) can be
        # filled; an earlier start means the datasets are unsorted or misaligned
        # and would yield a filler whose start date is after its end date.
        if fill_missing_gaps and s[0] > expected_start:
            from .missing import MissingDataset

            result.append(MissingDataset(datasets[i], expected_start, s[0] - frequency))
            continue

        r = [str(e) for e in r]
        s = [str(e) for e in s]
        raise ValueError(
            "Datasets must be sorted by dates, with no gaps: "
            f"{r} and {s} ({datasets[i]} {datasets[i+1]})"
        )

    result.append(datasets[-1])
    assert len(result) >= len(datasets), (len(result), len(datasets))

    return result
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def concat_factory(args, kwargs):
|
|
107
148
|
|
|
108
149
|
datasets = kwargs.pop("concat")
|
|
150
|
+
fill_missing_gaps = kwargs.pop("fill_missing_gaps", False)
|
|
109
151
|
assert isinstance(datasets, (list, tuple))
|
|
110
152
|
assert len(args) == 0
|
|
111
153
|
|
|
112
154
|
assert isinstance(datasets, (list, tuple))
|
|
113
155
|
|
|
114
|
-
datasets = [_open(e
|
|
156
|
+
datasets = [_open(e) for e in datasets]
|
|
115
157
|
|
|
116
158
|
if len(datasets) == 1:
|
|
117
159
|
return datasets[0]._subset(**kwargs)
|
|
118
160
|
|
|
119
161
|
datasets, kwargs = _auto_adjust(datasets, kwargs)
|
|
120
162
|
|
|
163
|
+
datasets = Concat.check_dataset_compatibility(datasets, fill_missing_gaps)
|
|
164
|
+
|
|
121
165
|
return Concat(datasets)._subset(**kwargs)
|
anemoi/datasets/data/dataset.py
CHANGED
|
@@ -5,24 +5,37 @@
|
|
|
5
5
|
# granted to it by virtue of its status as an intergovernmental organisation
|
|
6
6
|
# nor does it submit to any jurisdiction.
|
|
7
7
|
|
|
8
|
+
import datetime
|
|
9
|
+
import json
|
|
8
10
|
import logging
|
|
9
11
|
import os
|
|
12
|
+
import pprint
|
|
10
13
|
import warnings
|
|
11
14
|
from functools import cached_property
|
|
12
15
|
|
|
16
|
+
from anemoi.utils.dates import frequency_to_seconds
|
|
17
|
+
from anemoi.utils.dates import frequency_to_string
|
|
18
|
+
from anemoi.utils.dates import frequency_to_timedelta
|
|
19
|
+
|
|
13
20
|
LOG = logging.getLogger(__name__)
|
|
14
21
|
|
|
15
22
|
|
|
16
23
|
class Dataset:
|
|
17
24
|
arguments = {}
|
|
18
25
|
|
|
26
|
+
def mutate(self):
    """Return the dataset object to use for this dataset.

    `_subset` calls this on every dataset it returns; the base implementation
    returns `self` unchanged.
    """
    return self
|
|
28
|
+
|
|
29
|
+
def swap_with_parent(self, parent):
    """Return `parent` unchanged.

    NOTE(review): presumably a hook that subclasses override to change how they
    nest under a wrapping dataset -- confirm against the overriding classes.
    """
    return parent
|
|
31
|
+
|
|
19
32
|
@cached_property
def _len(self):
    """Length of the dataset, computed once with `len(self)` and then cached on the instance."""
    return len(self)
|
|
22
35
|
|
|
23
36
|
def _subset(self, **kwargs):
|
|
24
37
|
if not kwargs:
|
|
25
|
-
return self
|
|
38
|
+
return self.mutate()
|
|
26
39
|
|
|
27
40
|
if "start" in kwargs or "end" in kwargs:
|
|
28
41
|
start = kwargs.pop("start", None)
|
|
@@ -30,37 +43,52 @@ class Dataset:
|
|
|
30
43
|
|
|
31
44
|
from .subset import Subset
|
|
32
45
|
|
|
33
|
-
return
|
|
46
|
+
return (
|
|
47
|
+
Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs).mutate()
|
|
48
|
+
)
|
|
34
49
|
|
|
35
50
|
if "frequency" in kwargs:
|
|
36
51
|
from .subset import Subset
|
|
37
52
|
|
|
53
|
+
if "interpolate_frequency" in kwargs:
|
|
54
|
+
raise ValueError("Cannot use both `frequency` and `interpolate_frequency`")
|
|
55
|
+
|
|
38
56
|
frequency = kwargs.pop("frequency")
|
|
39
|
-
return
|
|
57
|
+
return (
|
|
58
|
+
Subset(self, self._frequency_to_indices(frequency), dict(frequency=frequency))
|
|
59
|
+
._subset(**kwargs)
|
|
60
|
+
.mutate()
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
if "interpolate_frequency" in kwargs:
|
|
64
|
+
from .interpolate import InterpolateFrequency
|
|
65
|
+
|
|
66
|
+
interpolate_frequency = kwargs.pop("interpolate_frequency")
|
|
67
|
+
return InterpolateFrequency(self, interpolate_frequency)._subset(**kwargs).mutate()
|
|
40
68
|
|
|
41
69
|
if "select" in kwargs:
|
|
42
70
|
from .select import Select
|
|
43
71
|
|
|
44
72
|
select = kwargs.pop("select")
|
|
45
|
-
return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs)
|
|
73
|
+
return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs).mutate()
|
|
46
74
|
|
|
47
75
|
if "drop" in kwargs:
|
|
48
76
|
from .select import Select
|
|
49
77
|
|
|
50
78
|
drop = kwargs.pop("drop")
|
|
51
|
-
return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs)
|
|
79
|
+
return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs).mutate()
|
|
52
80
|
|
|
53
81
|
if "reorder" in kwargs:
|
|
54
82
|
from .select import Select
|
|
55
83
|
|
|
56
84
|
reorder = kwargs.pop("reorder")
|
|
57
|
-
return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs)
|
|
85
|
+
return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate()
|
|
58
86
|
|
|
59
87
|
if "rename" in kwargs:
|
|
60
88
|
from .select import Rename
|
|
61
89
|
|
|
62
90
|
rename = kwargs.pop("rename")
|
|
63
|
-
return Rename(self, rename)._subset(**kwargs)
|
|
91
|
+
return Rename(self, rename)._subset(**kwargs).mutate()
|
|
64
92
|
|
|
65
93
|
if "statistics" in kwargs:
|
|
66
94
|
from ..data import open_dataset
|
|
@@ -68,20 +96,38 @@ class Dataset:
|
|
|
68
96
|
|
|
69
97
|
statistics = kwargs.pop("statistics")
|
|
70
98
|
|
|
71
|
-
return Statistics(self, open_dataset(statistics))._subset(**kwargs)
|
|
99
|
+
return Statistics(self, open_dataset(statistics))._subset(**kwargs).mutate()
|
|
72
100
|
|
|
73
101
|
if "thinning" in kwargs:
|
|
74
102
|
from .masked import Thinning
|
|
75
103
|
|
|
76
104
|
thinning = kwargs.pop("thinning")
|
|
77
105
|
method = kwargs.pop("method", "every-nth")
|
|
78
|
-
return Thinning(self, thinning, method)._subset(**kwargs)
|
|
106
|
+
return Thinning(self, thinning, method)._subset(**kwargs).mutate()
|
|
79
107
|
|
|
80
108
|
if "area" in kwargs:
|
|
81
109
|
from .masked import Cropping
|
|
82
110
|
|
|
83
111
|
bbox = kwargs.pop("area")
|
|
84
|
-
return Cropping(self, bbox)._subset(**kwargs)
|
|
112
|
+
return Cropping(self, bbox)._subset(**kwargs).mutate()
|
|
113
|
+
|
|
114
|
+
if "missing_dates" in kwargs:
|
|
115
|
+
from .missing import MissingDates
|
|
116
|
+
|
|
117
|
+
missing_dates = kwargs.pop("missing_dates")
|
|
118
|
+
return MissingDates(self, missing_dates)._subset(**kwargs).mutate()
|
|
119
|
+
|
|
120
|
+
if "skip_missing_dates" in kwargs:
|
|
121
|
+
from .missing import SkipMissingDates
|
|
122
|
+
|
|
123
|
+
if "expected_access" not in kwargs:
|
|
124
|
+
raise ValueError("`expected_access` is required with `skip_missing_dates`")
|
|
125
|
+
|
|
126
|
+
skip_missing_dates = kwargs.pop("skip_missing_dates")
|
|
127
|
+
expected_access = kwargs.pop("expected_access")
|
|
128
|
+
|
|
129
|
+
if skip_missing_dates:
|
|
130
|
+
return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate()
|
|
85
131
|
|
|
86
132
|
# Keep last
|
|
87
133
|
if "shuffle" in kwargs:
|
|
@@ -90,15 +136,14 @@ class Dataset:
|
|
|
90
136
|
shuffle = kwargs.pop("shuffle")
|
|
91
137
|
|
|
92
138
|
if shuffle:
|
|
93
|
-
return Subset(self, self._shuffle_indices(), dict(shuffle=True))._subset(**kwargs)
|
|
139
|
+
return Subset(self, self._shuffle_indices(), dict(shuffle=True))._subset(**kwargs).mutate()
|
|
94
140
|
|
|
95
141
|
raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs))
|
|
96
142
|
|
|
97
143
|
def _frequency_to_indices(self, frequency):
|
|
98
|
-
from .misc import _frequency_to_hours
|
|
99
144
|
|
|
100
|
-
requested_frequency =
|
|
101
|
-
dataset_frequency =
|
|
145
|
+
requested_frequency = frequency_to_seconds(frequency)
|
|
146
|
+
dataset_frequency = frequency_to_seconds(self.frequency)
|
|
102
147
|
assert requested_frequency % dataset_frequency == 0
|
|
103
148
|
# Question: where do we start? first date, or first date that is a multiple of the frequency?
|
|
104
149
|
step = requested_frequency // dataset_frequency
|
|
@@ -171,37 +216,71 @@ class Dataset:
|
|
|
171
216
|
import anemoi
|
|
172
217
|
|
|
173
218
|
def tidy(v):
|
|
174
|
-
if isinstance(v, (list, tuple)):
|
|
219
|
+
if isinstance(v, (list, tuple, set)):
|
|
175
220
|
return [tidy(i) for i in v]
|
|
176
221
|
if isinstance(v, dict):
|
|
177
222
|
return {k: tidy(v) for k, v in v.items()}
|
|
178
223
|
if isinstance(v, str) and v.startswith("/"):
|
|
179
224
|
return os.path.basename(v)
|
|
225
|
+
if isinstance(v, datetime.datetime):
|
|
226
|
+
return v.isoformat()
|
|
227
|
+
if isinstance(v, datetime.date):
|
|
228
|
+
return v.isoformat()
|
|
229
|
+
if isinstance(v, datetime.timedelta):
|
|
230
|
+
return frequency_to_string(v)
|
|
231
|
+
|
|
232
|
+
if isinstance(v, Dataset):
|
|
233
|
+
# That can happen in the `arguments`
|
|
234
|
+
# if a dataset is passed as an argument
|
|
235
|
+
return repr(v)
|
|
236
|
+
|
|
237
|
+
if isinstance(v, slice):
|
|
238
|
+
return (v.start, v.stop, v.step)
|
|
239
|
+
|
|
180
240
|
return v
|
|
181
241
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
)
|
|
242
|
+
md = dict(
|
|
243
|
+
version=anemoi.datasets.__version__,
|
|
244
|
+
arguments=self.arguments,
|
|
245
|
+
**self.dataset_metadata(),
|
|
246
|
+
)
|
|
247
|
+
|
|
248
|
+
try:
|
|
249
|
+
return json.loads(json.dumps(tidy(md)))
|
|
250
|
+
except Exception:
|
|
251
|
+
LOG.exception("Failed to serialize metadata")
|
|
252
|
+
pprint.pprint(md)
|
|
253
|
+
|
|
254
|
+
raise
|
|
255
|
+
|
|
256
|
+
@property
def start_date(self):
    """First entry of `self.dates`."""
    return self.dates[0]
|
|
259
|
+
|
|
260
|
+
@property
def end_date(self):
    """Last entry of `self.dates`."""
    return self.dates[-1]
|
|
263
|
+
|
|
264
|
+
def dataset_metadata(self):
    """Return the dataset-level metadata entries as a dict.

    The dict holds the result of `metadata_specific()` under ``specific``,
    followed by ``frequency``, ``variables``, ``shape`` and the first/last
    dates converted with ``astype(str)``.
    """
    return {
        "specific": self.metadata_specific(),
        "frequency": self.frequency,
        "variables": self.variables,
        "shape": self.shape,
        "start_date": self.start_date.astype(str),
        "end_date": self.end_date.astype(str),
    }
|
|
194
273
|
|
|
195
274
|
def metadata_specific(self, **kwargs):
|
|
196
275
|
action = self.__class__.__name__.lower()
|
|
197
|
-
assert isinstance(self.frequency,
|
|
276
|
+
# assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action)
|
|
198
277
|
return dict(
|
|
199
278
|
action=action,
|
|
200
279
|
variables=self.variables,
|
|
201
280
|
shape=self.shape,
|
|
202
|
-
frequency=self.frequency,
|
|
203
|
-
start_date=self.
|
|
204
|
-
end_date=self.
|
|
281
|
+
frequency=frequency_to_string(frequency_to_timedelta(self.frequency)),
|
|
282
|
+
start_date=self.start_date.astype(str),
|
|
283
|
+
end_date=self.end_date.astype(str),
|
|
205
284
|
**kwargs,
|
|
206
285
|
)
|
|
207
286
|
|
|
@@ -220,10 +299,6 @@ class Dataset:
|
|
|
220
299
|
if n.startswith("_") and not n.startswith("__"):
|
|
221
300
|
warnings.warn(f"Private method {n} is overriden in {ds.__class__.__name__}")
|
|
222
301
|
|
|
223
|
-
# for n in ('metadata_specific', 'tree', 'source'):
|
|
224
|
-
# if n not in overriden:
|
|
225
|
-
# warnings.warn(f"Method {n} is not overriden in {ds.__class__.__name__}")
|
|
226
|
-
|
|
227
302
|
def _repr_html_(self):
    """Return the HTML rendering of this dataset's tree (`_repr_html_` is the rich-display hook name used by IPython/Jupyter)."""
    return self.tree().html()
|
|
229
304
|
|
anemoi/datasets/data/forwards.py
CHANGED
|
@@ -23,7 +23,7 @@ LOG = logging.getLogger(__name__)
|
|
|
23
23
|
|
|
24
24
|
class Forwards(Dataset):
|
|
25
25
|
def __init__(self, forward):
|
|
26
|
-
self.forward = forward
|
|
26
|
+
self.forward = forward.mutate()
|
|
27
27
|
|
|
28
28
|
def __len__(self):
    """Return the length of the wrapped `forward` dataset."""
    return len(self.forward)
|
|
@@ -118,6 +118,9 @@ class Combined(Forwards):
|
|
|
118
118
|
# Forward most properties to the first dataset
|
|
119
119
|
super().__init__(datasets[0])
|
|
120
120
|
|
|
121
|
+
def mutate(self):
    """Return `self` unchanged.

    NOTE(review): presumably overrides an inherited `mutate` so that combined
    datasets are never replaced by another object -- confirm against the
    parent class.
    """
    return self
|
|
123
|
+
|
|
121
124
|
def check_same_resolution(self, d1, d2):
|
|
122
125
|
if d1.resolution != d2.resolution:
|
|
123
126
|
raise ValueError(f"Incompatible resolutions: {d1.resolution} and {d2.resolution} ({d1} {d2})")
|
|
@@ -187,14 +190,9 @@ class Combined(Forwards):
|
|
|
187
190
|
**kwargs,
|
|
188
191
|
)
|
|
189
192
|
|
|
190
|
-
@
|
|
193
|
+
@property
|
|
191
194
|
def missing(self):
|
|
192
|
-
|
|
193
|
-
result = set()
|
|
194
|
-
for d in self.datasets:
|
|
195
|
-
result.update(offset + m for m in d.missing)
|
|
196
|
-
offset += len(d)
|
|
197
|
-
return result
|
|
195
|
+
raise NotImplementedError("missing() not implemented for Combined")
|
|
198
196
|
|
|
199
197
|
def get_dataset_names(self, names):
|
|
200
198
|
for d in self.datasets:
|
|
@@ -249,3 +247,14 @@ class GivenAxis(Combined):
|
|
|
249
247
|
return self._get_slice(n)
|
|
250
248
|
|
|
251
249
|
return np.concatenate([d[n] for d in self.datasets], axis=self.axis - 1)
|
|
250
|
+
|
|
251
|
+
@cached_property
def missing(self):
    """Return the set of missing indices gathered from all member datasets.

    Each dataset's `missing` indices are shifted by a running offset.  The
    offset only advances (by `len(d)`) when `self.axis == 0`, i.e. when the
    datasets are stacked along the time axis; for any other axis the indices
    are merged unshifted.
    """
    offset = 0
    result = set()
    for d in self.datasets:
        result.update(offset + m for m in d.missing)
        if self.axis == 0:  # Advance if axis is time
            offset += len(d)
    return result
|
anemoi/datasets/data/grids.py
CHANGED
|
@@ -128,7 +128,7 @@ class Grids(GridsBase):
|
|
|
128
128
|
|
|
129
129
|
|
|
130
130
|
class Cutout(GridsBase):
|
|
131
|
-
def __init__(self, datasets, axis):
|
|
131
|
+
def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, neighbours=5, plot=False):
|
|
132
132
|
from anemoi.datasets.grids import cutout_mask
|
|
133
133
|
|
|
134
134
|
super().__init__(datasets, axis)
|
|
@@ -144,7 +144,10 @@ class Cutout(GridsBase):
|
|
|
144
144
|
self.lam.longitudes,
|
|
145
145
|
self.globe.latitudes,
|
|
146
146
|
self.globe.longitudes,
|
|
147
|
-
|
|
147
|
+
plot=plot,
|
|
148
|
+
min_distance_km=min_distance_km,
|
|
149
|
+
cropping_distance=cropping_distance,
|
|
150
|
+
neighbours=neighbours,
|
|
148
151
|
)
|
|
149
152
|
assert len(self.mask) == self.globe.shape[3], (
|
|
150
153
|
len(self.mask),
|
|
@@ -229,6 +232,10 @@ def cutout_factory(args, kwargs):
|
|
|
229
232
|
|
|
230
233
|
cutout = kwargs.pop("cutout")
|
|
231
234
|
axis = kwargs.pop("axis", 3)
|
|
235
|
+
plot = kwargs.pop("plot", None)
|
|
236
|
+
min_distance_km = kwargs.pop("min_distance_km", None)
|
|
237
|
+
cropping_distance = kwargs.pop("cropping_distance", 2.0)
|
|
238
|
+
neighbours = kwargs.pop("neighbours", 5)
|
|
232
239
|
|
|
233
240
|
assert len(args) == 0
|
|
234
241
|
assert isinstance(cutout, (list, tuple))
|
|
@@ -236,4 +243,11 @@ def cutout_factory(args, kwargs):
|
|
|
236
243
|
datasets = [_open(e) for e in cutout]
|
|
237
244
|
datasets, kwargs = _auto_adjust(datasets, kwargs)
|
|
238
245
|
|
|
239
|
-
return Cutout(
|
|
246
|
+
return Cutout(
|
|
247
|
+
datasets,
|
|
248
|
+
axis=axis,
|
|
249
|
+
neighbours=neighbours,
|
|
250
|
+
min_distance_km=min_distance_km,
|
|
251
|
+
cropping_distance=cropping_distance,
|
|
252
|
+
plot=plot,
|
|
253
|
+
)._subset(**kwargs)
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
|
|
2
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
3
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
4
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
5
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
6
|
+
# nor does it submit to any jurisdiction.
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from functools import cached_property
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
from anemoi.utils.dates import frequency_to_timedelta
|
|
13
|
+
|
|
14
|
+
from .debug import Node
|
|
15
|
+
from .debug import debug_indexing
|
|
16
|
+
from .forwards import Forwards
|
|
17
|
+
from .indexing import apply_index_to_slices_changes
|
|
18
|
+
from .indexing import expand_list_indexing
|
|
19
|
+
from .indexing import index_to_slices
|
|
20
|
+
from .indexing import update_tuple
|
|
21
|
+
|
|
22
|
+
LOG = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class InterpolateFrequency(Forwards):
    """Expose a dataset at a higher frequency by linear interpolation in time.

    Every step of the wrapped dataset is split into `ratio` sub-steps.  Indices
    that fall on an original date return the original sample; the others are a
    weighted mix of the two surrounding samples.
    """

    def __init__(self, dataset, frequency):
        """Wrap `dataset` so that it appears to have the (finer) `frequency`.

        Raises:
            ValueError: if `frequency` is not positive, is not strictly finer
                than the dataset frequency, or does not divide it exactly.
        """
        super().__init__(dataset)
        self._frequency = frequency_to_timedelta(frequency)
        # Normalise the dataset frequency the same way as the requested one, so
        # that both are timedelta objects before calling total_seconds().
        dataset_frequency = frequency_to_timedelta(dataset.frequency)

        self.seconds = int(self._frequency.total_seconds())
        assert self.seconds == self._frequency.total_seconds()

        other_seconds = int(dataset_frequency.total_seconds())
        assert other_seconds == dataset_frequency.total_seconds()

        if self.seconds <= 0:
            # Guards the modulo below against a zero divisor.
            raise ValueError(f"Interpolate frequency {self._frequency} must be positive")

        if self.seconds >= other_seconds:
            raise ValueError(
                f"Interpolate frequency {self._frequency} must be more frequent than dataset frequency {dataset.frequency}"
            )

        if other_seconds % self.seconds != 0:
            raise ValueError(
                f"Dataset frequency {dataset.frequency} must be a multiple of the interpolate frequency {self._frequency}"
            )

        self.ratio = other_seconds // self.seconds
        # alphas[x] is the weight of the right-hand sample at sub-step x.
        self.alphas = np.linspace(0, 1, self.ratio + 1)
        self.other_len = len(dataset)

    @debug_indexing
    @expand_list_indexing
    def _get_tuple(self, index):
        """Resolve a tuple index: fetch along the first axis, then apply the remaining indices."""
        index, changes = index_to_slices(index, self.shape)
        index, previous = update_tuple(index, 0, slice(None))
        result = self._get_slice(previous)
        return apply_index_to_slices_changes(result[index], changes)

    def _get_slice(self, s):
        """Stack the samples selected by slice `s` along a new first axis."""
        return np.stack([self[i] for i in range(*s.indices(self._len))])

    @debug_indexing
    def __getitem__(self, n):
        """Return the sample at index `n` (int, slice or tuple), interpolating when needed.

        Raises:
            IndexError: if an integer `n` is outside ``[-len(self), len(self))``.
        """
        if isinstance(n, tuple):
            return self._get_tuple(n)

        if isinstance(n, slice):
            return self._get_slice(n)

        if n < 0:
            n += self._len

        if not 0 <= n < self._len:
            # Without this check a too-negative index would wrap around in the
            # wrapped dataset and silently return unrelated data.
            raise IndexError(f"Index {n} out of range for dataset of length {self._len}")

        if n == self._len - 1:
            # Special case for the last element
            return self.forward[-1]

        i, x = divmod(n, self.ratio)

        if x == 0:
            # No interpolation needed
            return self.forward[i]

        alpha = self.alphas[x]

        assert 0 < alpha < 1, alpha
        return self.forward[i] * (1 - alpha) + self.forward[i + 1] * alpha

    def __len__(self):
        """Number of samples: `ratio` per original interval plus the final original sample."""
        return (self.other_len - 1) * self.ratio + 1

    @property
    def frequency(self):
        """The interpolated frequency, as returned by `frequency_to_timedelta`."""
        return self._frequency

    @cached_property
    def dates(self):
        """Dates of the interpolated dataset: each original date but the last, expanded into `ratio` sub-steps, then the last date."""
        deltas = [np.timedelta64(self.seconds * i, "s") for i in range(self.ratio)]
        result = [d + delta for d in self.forward.dates[:-1] for delta in deltas]
        result.append(self.forward.dates[-1])
        return np.array(result)

    @property
    def shape(self):
        """Shape of the wrapped dataset with the first dimension replaced by the interpolated length."""
        return (self._len,) + self.forward.shape[1:]

    def tree(self):
        """Return a `Node` for this dataset with the wrapped dataset's tree as its only child."""
        return Node(self, [self.forward.tree()], frequency=self.frequency)

    @cached_property
    def missing(self):
        """Return the set of interpolated indices that cannot be computed.

        An index is missing when the original sample on its left is missing,
        or when it lies strictly between two original samples and the one on
        its right is missing (interpolation reads both neighbours).
        """
        forward_missing = self.forward.missing
        result = set()
        for i in range(self.other_len):
            left_missing = i in forward_missing
            right_missing = (i + 1) in forward_missing
            for x in range(self.ratio):
                if left_missing or (x > 0 and right_missing):
                    result.add(i * self.ratio + x)

        # Sub-steps generated after the last original sample do not exist.
        return {j for j in result if j < self._len}

    def subclass_metadata_specific(self):
        """Return the metadata entries specific to this class (currently none)."""
        return {
            # "frequency": frequency_to_string(self._frequency),
        }
|