anemoi-datasets 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/cleanup.py +44 -0
- anemoi/datasets/commands/create.py +50 -20
- anemoi/datasets/commands/finalise-additions.py +45 -0
- anemoi/datasets/commands/finalise.py +39 -0
- anemoi/datasets/commands/init-additions.py +45 -0
- anemoi/datasets/commands/init.py +67 -0
- anemoi/datasets/commands/inspect.py +1 -1
- anemoi/datasets/commands/load-additions.py +47 -0
- anemoi/datasets/commands/load.py +47 -0
- anemoi/datasets/commands/patch.py +39 -0
- anemoi/datasets/compute/recentre.py +1 -1
- anemoi/datasets/create/__init__.py +961 -146
- anemoi/datasets/create/check.py +5 -3
- anemoi/datasets/create/config.py +53 -2
- anemoi/datasets/create/functions/sources/accumulations.py +6 -22
- anemoi/datasets/create/functions/sources/hindcasts.py +27 -12
- anemoi/datasets/create/functions/sources/tendencies.py +1 -1
- anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
- anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
- anemoi/datasets/create/functions/sources/xarray/field.py +1 -1
- anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
- anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
- anemoi/datasets/create/functions/sources/xarray/metadata.py +27 -29
- anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
- anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
- anemoi/datasets/create/input.py +62 -25
- anemoi/datasets/create/statistics/__init__.py +39 -23
- anemoi/datasets/create/utils.py +3 -2
- anemoi/datasets/data/__init__.py +1 -0
- anemoi/datasets/data/concat.py +46 -2
- anemoi/datasets/data/dataset.py +109 -34
- anemoi/datasets/data/forwards.py +17 -8
- anemoi/datasets/data/grids.py +17 -3
- anemoi/datasets/data/interpolate.py +133 -0
- anemoi/datasets/data/misc.py +56 -66
- anemoi/datasets/data/missing.py +240 -0
- anemoi/datasets/data/select.py +7 -1
- anemoi/datasets/data/stores.py +3 -3
- anemoi/datasets/data/subset.py +47 -5
- anemoi/datasets/data/unchecked.py +20 -22
- anemoi/datasets/data/xy.py +125 -0
- anemoi/datasets/dates/__init__.py +33 -20
- anemoi/datasets/dates/groups.py +2 -2
- anemoi/datasets/grids.py +66 -48
- {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/METADATA +5 -5
- {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/RECORD +51 -41
- {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/loaders.py +0 -924
- {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/functions/sources/xarray/variable.py
CHANGED

@@ -14,34 +14,32 @@ from functools import cached_property
 import numpy as np
 from earthkit.data.utils.array import ensure_backend
 
-from anemoi.datasets.create.functions.sources.xarray.metadata import MDMapping
-
 from .field import XArrayField
 
 LOG = logging.getLogger(__name__)
 
 
 class Variable:
-    def __init__(
+    def __init__(
+        self,
+        *,
+        ds,
+        var,
+        coordinates,
+        grid,
+        time,
+        metadata,
+        array_backend=None,
+    ):
         self.ds = ds
         self.var = var
 
         self.grid = grid
         self.coordinates = coordinates
 
-        # print("Variable", var.name)
-        # for c in coordinates:
-        #     print(" ", c)
-
         self._metadata = metadata.copy()
-        # self._metadata.update(var.attrs)
         self._metadata.update({"variable": var.name})
 
-        # self._metadata.setdefault("level", None)
-        # self._metadata.setdefault("number", 0)
-        # self._metadata.setdefault("levtype", "sfc")
-        self._mapping = mapping
-
         self.time = time
 
         self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid)
@@ -51,23 +49,6 @@ class Variable:
         self.length = math.prod(self.shape)
         self.array_backend = ensure_backend(array_backend)
 
-    def update_metadata_mapping(self, kwargs):
-
-        result = {}
-
-        for k, v in kwargs.items():
-            if k == "param":
-                result[k] = "variable"
-                continue
-
-            for c in self.coordinates:
-                if k in c.mars_names:
-                    for v in c.mars_names:
-                        result[v] = c.variable.name
-                    break
-
-        self._mapping = MDMapping(result)
-
     @property
     def name(self):
         return self.var.name
@@ -111,17 +92,11 @@ class Variable:
         kwargs = {k: v for k, v in zip(self.names, coords)}
         return XArrayField(self, self.var.isel(kwargs))
 
-    @property
-    def mapping(self):
-        return self._mapping
-
     def sel(self, missing, **kwargs):
 
         if not kwargs:
             return self
 
-        kwargs = self._mapping.from_user(kwargs)
-
         k, v = kwargs.popitem()
 
         c = self.by_name.get(k)
@@ -147,13 +122,15 @@
             grid=self.grid,
             time=self.time,
             metadata=metadata,
-            mapping=self.mapping,
         )
 
         return variable.sel(missing, **kwargs)
 
     def match(self, **kwargs):
-
+
+        if "param" in kwargs:
+            assert "variable" not in kwargs
+            kwargs["variable"] = kwargs.pop("param")
 
         if "variable" in kwargs:
             name = kwargs.pop("variable")
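The `MDMapping` indirection is gone: `match()` now simply treats `param` as an alias for `variable` before matching. A standalone sketch of the new alias handling (a simplified stand-in, not the real method, which goes on to compare against coordinate values):

```python
# Simplified stand-in for the alias normalisation added to Variable.match().
def normalise_match_kwargs(**kwargs):
    if "param" in kwargs:
        assert "variable" not in kwargs
        kwargs["variable"] = kwargs.pop("param")
    return kwargs


print(normalise_match_kwargs(param="2t", level=850))
# -> {'level': 850, 'variable': '2t'}
```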
anemoi/datasets/create/input.py
CHANGED
@@ -106,30 +106,32 @@ def _data_request(data):
     area = grid = None
 
     for field in data:
-
-
-
-        if date is None:
-            date = field.datetime()["valid_time"]
-
-        if field.datetime()["valid_time"] != date:
-            continue
+        try:
+            if date is None:
+                date = field.datetime()["valid_time"]
 
-
-
-        levtype = as_mars.get("levtype", "sfc")
-        param = as_mars["param"]
-        levelist = as_mars.get("levelist", None)
-        area = field.mars_area
-        grid = field.mars_grid
+            if field.datetime()["valid_time"] != date:
+                continue
 
-
-
-
-
+            as_mars = field.metadata(namespace="mars")
+            if not as_mars:
+                continue
+            step = as_mars.get("step")
+            levtype = as_mars.get("levtype", "sfc")
+            param = as_mars["param"]
+            levelist = as_mars.get("levelist", None)
+            area = field.mars_area
+            grid = field.mars_grid
+
+            if levelist is None:
+                params_levels[levtype].add(param)
+            else:
+                params_levels[levtype].add((param, levelist))
 
-
-
+            if step:
+                params_steps[levtype].add((param, step))
+        except Exception:
+            LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True)
 
     def sort(old_dic):
         new_dic = {}
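The rewritten loop wraps all metadata access in a `try/except`, so one malformed field no longer aborts the whole data-request summary, and it now records `(param, step)` pairs as well. A self-contained sketch of the collection pattern, with plain dicts standing in for earthkit-data fields:

```python
from collections import defaultdict

# Stand-in fields; the real code reads field.metadata(namespace="mars").
fields = [
    {"param": "2t", "levtype": "sfc", "step": 0},
    {"param": "t", "levtype": "pl", "levelist": 850, "step": 0},
    {"param": "t", "levtype": "pl", "levelist": 500, "step": 6},
]

params_levels = defaultdict(set)
params_steps = defaultdict(set)

for as_mars in fields:
    levtype = as_mars.get("levtype", "sfc")
    param = as_mars["param"]
    levelist = as_mars.get("levelist")
    step = as_mars.get("step")

    if levelist is None:
        params_levels[levtype].add(param)
    else:
        params_levels[levtype].add((param, levelist))

    if step:  # a falsy step of 0 is not recorded
        params_steps[levtype].add((param, step))

print(dict(params_levels))  # {'sfc': {'2t'}, 'pl': {('t', 850), ('t', 500)}} (set order may vary)
print(dict(params_steps))   # {'pl': {('t', 6)}}
```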
@@ -277,6 +279,9 @@ class Result:
         if len(args) == 1 and isinstance(args[0], (list, tuple)):
             args = args[0]
 
+        # print("Executing", self.action_path)
+        # print("Dates:", compress_dates(self.dates))
+
         names = []
         for a in args:
             if isinstance(a, str):
@@ -285,14 +290,13 @@
                 names += list(a.keys())
 
         print(f"Building a {len(names)}D hypercube using", names)
-
         ds = ds.order_by(*args, remapping=remapping, patches=patches)
-        user_coords = ds.unique_values(*names, remapping=remapping, patches=patches)
+        user_coords = ds.unique_values(*names, remapping=remapping, patches=patches, progress_bar=False)
 
         print()
         print("Number of unique values found for each coordinate:")
         for k, v in user_coords.items():
-            print(f" {k:20}:", len(v))
+            print(f" {k:20}:", len(v), shorten_list(v, max_length=10))
         print()
         user_shape = tuple(len(v) for k, v in user_coords.items())
         print("Shape of the hypercube :", user_shape)
@@ -305,13 +309,18 @@
 
         remapping = build_remapping(remapping, patches)
         expected = set(itertools.product(*user_coords.values()))
+        extra = set()
 
         if math.prod(user_shape) > len(ds):
             print(f"This means that all the fields in the datasets do not exists for all combinations of {names}.")
 
             for f in ds:
                 metadata = remapping(f.metadata)
-
+                key = tuple(metadata(n, default=None) for n in names)
+                if key in expected:
+                    expected.remove(key)
+                else:
+                    extra.add(key)
 
             print("Missing fields:")
             print()
@@ -321,7 +330,35 @@
                     print("...", len(expected) - i - 1, "more")
                     break
 
+            print("Extra fields:")
+            print()
+            for i, f in enumerate(sorted(extra)):
+                print(" ", f)
+                if i >= 9 and len(extra) > 10:
+                    print("...", len(extra) - i - 1, "more")
+                    break
+
             print()
+            print("Missing values:")
+            per_name = defaultdict(set)
+            for e in expected:
+                for n, v in zip(names, e):
+                    per_name[n].add(v)
+
+            for n, v in per_name.items():
+                print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
+            print()
+
+            print("Extra values:")
+            per_name = defaultdict(set)
+            for e in extra:
+                for n, v in zip(names, e):
+                    per_name[n].add(v)
+
+            for n, v in per_name.items():
+                print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
+            print()
+
             print("To solve this issue, you can:")
             print(
                 " - Provide a better selection, like 'step: 0' or 'level: 1000' to "
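The diagnostics added here work by set difference: the full Cartesian product of coordinate values forms the set of expected keys, each field seen removes its key, and anything seen that was never expected is "extra". A minimal standalone illustration:

```python
import itertools

user_coords = {"variable": ["2t", "msl"], "date": ["2020-01-01", "2020-01-02"]}
names = list(user_coords)

# Every combination of coordinate values is expected exactly once.
expected = set(itertools.product(*user_coords.values()))
extra = set()

seen = [("2t", "2020-01-01"), ("2t", "2020-01-02"), ("msl", "2020-01-01"), ("q", "2020-01-01")]
for key in seen:
    if key in expected:
        expected.remove(key)
    else:
        extra.add(key)

print("missing:", sorted(expected))  # [('msl', '2020-01-02')]
print("extra:  ", sorted(extra))     # [('q', '2020-01-01')]
```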
anemoi/datasets/create/statistics/__init__.py
CHANGED

@@ -79,6 +79,37 @@ def to_datetimes(dates):
     return [to_datetime(d) for d in dates]
 
 
+def fix_variance(x, name, count, sums, squares):
+    assert count.shape == sums.shape == squares.shape
+    assert isinstance(x, float)
+
+    mean = sums / count
+    assert mean.shape == count.shape
+
+    if x >= 0:
+        return x
+
+    LOG.warning(f"Negative variance for {name=}, variance={x}")
+    magnitude = np.sqrt((squares / count + mean * mean) / 2)
+    LOG.warning(f"square / count - mean * mean = {squares/count} - {mean*mean} = {squares/count - mean*mean}")
+    LOG.warning(f"Variable span order of magnitude is {magnitude}.")
+    LOG.warning(f"Count is {count}.")
+
+    variances = squares / count - mean * mean
+    assert variances.shape == squares.shape == mean.shape
+    if all(variances >= 0):
+        LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.")
+        return 0
+
+    # if abs(x) < magnitude * 1e-6 and abs(x) < range * 1e-6:
+    #     LOG.warning("Variance is negative but very small.")
+    #     variances = squares / count - mean * mean
+    #     return 0
+
+    LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).")
+    return x
+
+
 def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
     if (x >= 0).all():
         return
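For context on why `fix_variance` exists: the one-pass formula var = E[x²] − E[x]² subtracts two nearly equal numbers when the spread is small relative to the mean, so rounding can drive the result below zero. A small demonstration of the cancellation (the exact value, and whether it dips negative, depends on the accumulation order):

```python
import numpy as np

rng = np.random.default_rng(0)
# Mean ~1e4, standard deviation ~1e-3: the spread sits near the limit of
# float32 resolution, which is where cancellation bites hardest.
x = (1.0e4 + 1.0e-3 * rng.standard_normal(10_000)).astype(np.float32)

count = np.float32(x.size)
sums = np.sum(x, dtype=np.float32)
squares = np.sum(x * x, dtype=np.float32)

mean = sums / count
naive = squares / count - mean * mean    # one-pass: dominated by rounding noise
print(naive)                             # far from the truth, possibly negative
print(np.var(x.astype(np.float64)))      # two-pass in float64: the true variance
```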
@@ -292,39 +323,24 @@ class StatAggregator:
     def aggregate(self):
         minimum = np.nanmin(self.minimum, axis=0)
         maximum = np.nanmax(self.maximum, axis=0)
+
         sums = np.nansum(self.sums, axis=0)
         squares = np.nansum(self.squares, axis=0)
         count = np.nansum(self.count, axis=0)
         has_nans = np.any(self.has_nans, axis=0)
-
+        assert sums.shape == count.shape == squares.shape == minimum.shape == maximum.shape
 
-
+        mean = sums / count
+        assert mean.shape == minimum.shape
 
         x = squares / count - mean * mean
-
-        # def fix_variance(x, name, minimum, maximum, mean, count, sums, squares):
-        #     assert x.shape == minimum.shape == maximum.shape == mean.shape == count.shape == sums.shape == squares.shape
-        #     assert x.shape == (1,)
-        #     x, minimum, maximum, mean, count, sums, squares = x[0], minimum[0], maximum[0], mean[0], count[0], sums[0], squares[0]
-        #     if x >= 0:
-        #         return x
-        #
-        #     order = np.sqrt((squares / count + mean * mean)/2)
-        #     range = maximum - minimum
-        #     LOG.warning(f"Negative variance for {name=}, variance={x}")
-        #     LOG.warning(f"square / count - mean * mean = {squares / count} - {mean * mean} = {squares / count - mean * mean}")
-        #     LOG.warning(f"Variable order of magnitude is {order}.")
-        #     LOG.warning(f"Range is {range} ({maximum=} - {minimum=}).")
-        #     LOG.warning(f"Count is {count}.")
-        #     if abs(x) < order * 1e-6 and abs(x) < range * 1e-6:
-        #         LOG.warning(f"Variance is negative but very small, setting to 0.")
-        #         return x*0
-        #     return x
+        assert x.shape == minimum.shape
 
         for i, name in enumerate(self.variables_names):
             # remove negative variance due to numerical errors
-
-
+            x[i] = fix_variance(x[i], name, self.count[i : i + 1], self.sums[i : i + 1], self.squares[i : i + 1])
+
+        for i, name in enumerate(self.variables_names):
             check_variance(
                 x[i : i + 1],
                 [name],
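`aggregate()` reduces per-chunk partial statistics along axis 0 before deriving the mean and variance, which is what makes `fix_variance` applicable per variable. A sketch of the same reduction with synthetic chunks:

```python
import numpy as np

rng = np.random.default_rng(1)
chunks = [rng.normal(5.0, 2.0, size=(100, 3)) for _ in range(4)]  # 4 chunks, 3 variables

# Per-chunk partials, then a reduction over the chunk axis, as in aggregate().
sums = np.nansum([c.sum(axis=0) for c in chunks], axis=0)
squares = np.nansum([(c * c).sum(axis=0) for c in chunks], axis=0)
count = np.nansum([np.full(3, len(c)) for c in chunks], axis=0)

mean = sums / count
variance = squares / count - mean * mean  # a candidate for fix_variance if negative

print(mean.round(2))               # close to [5. 5. 5.]
print(np.sqrt(variance).round(2))  # close to [2. 2. 2.]
```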
anemoi/datasets/create/utils.py
CHANGED
@@ -7,6 +7,7 @@
 # nor does it submit to any jurisdiction.
 #
 
+import datetime
 import os
 from contextlib import contextmanager
 
@@ -61,10 +62,10 @@ def make_list_int(value):
 
 
 def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"):
-    assert isinstance(frequency,
+    assert isinstance(frequency, datetime.timedelta), frequency
     start = np.datetime64(start)
     end = np.datetime64(end)
-    delta = np.timedelta64(frequency
+    delta = np.timedelta64(frequency)
 
     res = []
     while start <= end:
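The fix here is that `frequency` now arrives as a `datetime.timedelta`, which `np.timedelta64` accepts directly, so no unit handling is needed. A runnable sketch of the date loop under that contract:

```python
import datetime

import numpy as np

frequency = datetime.timedelta(hours=6)
delta = np.timedelta64(frequency)  # a timedelta converts without a unit string

start = np.datetime64("2020-01-01T00:00:00")
end = np.datetime64("2020-01-01T18:00:00")

res = []
while start <= end:
    res.append(start)
    start += delta

print(res)  # four dates, 6 hours apart
```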
anemoi/datasets/data/__init__.py
CHANGED
anemoi/datasets/data/concat.py
CHANGED
@@ -9,6 +9,7 @@ import logging
 from functools import cached_property
 
 import numpy as np
+from anemoi.utils.dates import frequency_to_timedelta
 
 from .debug import Node
 from .debug import debug_indexing
@@ -102,20 +103,63 @@ class Concat(ConcatMixin, Combined):
     def tree(self):
         return Node(self, [d.tree() for d in self.datasets])
 
+    @classmethod
+    def check_dataset_compatibility(cls, datasets, fill_missing_gaps=False):
+        # Study the dates
+        ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets]
 
-def concat_factory(args, kwargs):
+        # Make sure the dates are disjoint
+        for i in range(len(ranges)):
+            r = ranges[i]
+            for j in range(i + 1, len(ranges)):
+                s = ranges[j]
+                if r[0] <= s[0] <= r[1] or r[0] <= s[1] <= r[1]:
+                    raise ValueError(f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})")
+
+        # For now we should have the datasets in order with no gaps
+
+        frequency = frequency_to_timedelta(datasets[0].frequency)
+        result = []
+
+        for i in range(len(ranges) - 1):
+            result.append(datasets[i])
+            r = ranges[i]
+            s = ranges[i + 1]
+            if r[1] + frequency != s[0]:
+                if fill_missing_gaps:
+                    from .missing import MissingDataset
+
+                    result.append(MissingDataset(datasets[i], r[1] + frequency, s[0] - frequency))
+                else:
+                    r = [str(e) for e in r]
+                    s = [str(e) for e in s]
+                    raise ValueError(
+                        "Datasets must be sorted by dates, with no gaps: "
+                        f"{r} and {s} ({datasets[i]} {datasets[i+1]})"
+                    )
+
+        result.append(datasets[-1])
+        assert len(result) >= len(datasets), (len(result), len(datasets))
+
+        return result
+
+
+def concat_factory(args, kwargs):
 
     datasets = kwargs.pop("concat")
+    fill_missing_gaps = kwargs.pop("fill_missing_gaps", False)
     assert isinstance(datasets, (list, tuple))
     assert len(args) == 0
 
     assert isinstance(datasets, (list, tuple))
 
-    datasets = [_open(e
+    datasets = [_open(e) for e in datasets]
 
     if len(datasets) == 1:
         return datasets[0]._subset(**kwargs)
 
     datasets, kwargs = _auto_adjust(datasets, kwargs)
 
+    datasets = Concat.check_dataset_compatibility(datasets, fill_missing_gaps)
+
     return Concat(datasets)._subset(**kwargs)
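A hedged usage sketch of the new option (the dataset names below are hypothetical): with `fill_missing_gaps=True`, a gap between two concatenated datasets is padded with a `MissingDataset` placeholder instead of raising the "Datasets must be sorted by dates, with no gaps" error.

```python
from anemoi.datasets import open_dataset

ds = open_dataset(
    concat=["dataset-2019", "dataset-2021"],  # hypothetical names with a gap in 2020
    fill_missing_gaps=True,                   # pad the gap instead of raising
)
```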
anemoi/datasets/data/dataset.py
CHANGED
@@ -5,24 +5,37 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+import datetime
+import json
 import logging
 import os
+import pprint
 import warnings
 from functools import cached_property
 
+from anemoi.utils.dates import frequency_to_seconds
+from anemoi.utils.dates import frequency_to_string
+from anemoi.utils.dates import frequency_to_timedelta
+
 LOG = logging.getLogger(__name__)
 
 
 class Dataset:
     arguments = {}
 
+    def mutate(self):
+        return self
+
+    def swap_with_parent(self, parent):
+        return parent
+
     @cached_property
     def _len(self):
         return len(self)
 
     def _subset(self, **kwargs):
         if not kwargs:
-            return self
+            return self.mutate()
 
         if "start" in kwargs or "end" in kwargs:
             start = kwargs.pop("start", None)
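Every `_subset` branch now ends in `.mutate()`, a hook that lets a wrapper swap itself out of the chain after construction; the base class makes it a no-op. A simplified stand-in for the pattern (not the real classes):

```python
class Node:
    def mutate(self):
        return self  # default: stay in the chain


class RedundantWrapper(Node):
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def mutate(self):
        return self.wrapped.mutate()  # remove myself, keep what I wrap


leaf = Node()
assert RedundantWrapper(leaf).mutate() is leaf
```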
@@ -30,37 +43,52 @@ class Dataset:
 
             from .subset import Subset
 
-            return
+            return (
+                Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs).mutate()
+            )
 
         if "frequency" in kwargs:
             from .subset import Subset
 
+            if "interpolate_frequency" in kwargs:
+                raise ValueError("Cannot use both `frequency` and `interpolate_frequency`")
+
             frequency = kwargs.pop("frequency")
-            return
+            return (
+                Subset(self, self._frequency_to_indices(frequency), dict(frequency=frequency))
+                ._subset(**kwargs)
+                .mutate()
+            )
+
+        if "interpolate_frequency" in kwargs:
+            from .interpolate import InterpolateFrequency
+
+            interpolate_frequency = kwargs.pop("interpolate_frequency")
+            return InterpolateFrequency(self, interpolate_frequency)._subset(**kwargs).mutate()
 
         if "select" in kwargs:
             from .select import Select
 
             select = kwargs.pop("select")
-            return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs)
+            return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs).mutate()
 
         if "drop" in kwargs:
             from .select import Select
 
             drop = kwargs.pop("drop")
-            return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs)
+            return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs).mutate()
 
         if "reorder" in kwargs:
             from .select import Select
 
             reorder = kwargs.pop("reorder")
-            return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs)
+            return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate()
 
         if "rename" in kwargs:
             from .select import Rename
 
             rename = kwargs.pop("rename")
-            return Rename(self, rename)._subset(**kwargs)
+            return Rename(self, rename)._subset(**kwargs).mutate()
 
         if "statistics" in kwargs:
             from ..data import open_dataset
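A hedged usage sketch contrasting the two options (the dataset name is hypothetical): `frequency` subsamples to a coarser time step, while the new `interpolate_frequency` produces intermediate dates, and the diff makes them mutually exclusive.

```python
from anemoi.datasets import open_dataset

coarse = open_dataset("some-6h-dataset", frequency="12h")           # subsample
fine = open_dataset("some-6h-dataset", interpolate_frequency="1h")  # interpolate

# Passing both raises:
# ValueError: Cannot use both `frequency` and `interpolate_frequency`
```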
@@ -68,20 +96,38 @@ class Dataset:
 
             statistics = kwargs.pop("statistics")
 
-            return Statistics(self, open_dataset(statistics))._subset(**kwargs)
+            return Statistics(self, open_dataset(statistics))._subset(**kwargs).mutate()
 
         if "thinning" in kwargs:
             from .masked import Thinning
 
             thinning = kwargs.pop("thinning")
             method = kwargs.pop("method", "every-nth")
-            return Thinning(self, thinning, method)._subset(**kwargs)
+            return Thinning(self, thinning, method)._subset(**kwargs).mutate()
 
         if "area" in kwargs:
             from .masked import Cropping
 
             bbox = kwargs.pop("area")
-            return Cropping(self, bbox)._subset(**kwargs)
+            return Cropping(self, bbox)._subset(**kwargs).mutate()
+
+        if "missing_dates" in kwargs:
+            from .missing import MissingDates
+
+            missing_dates = kwargs.pop("missing_dates")
+            return MissingDates(self, missing_dates)._subset(**kwargs).mutate()
+
+        if "skip_missing_dates" in kwargs:
+            from .missing import SkipMissingDates
+
+            if "expected_access" not in kwargs:
+                raise ValueError("`expected_access` is required with `skip_missing_dates`")
+
+            skip_missing_dates = kwargs.pop("skip_missing_dates")
+            expected_access = kwargs.pop("expected_access")
+
+            if skip_missing_dates:
+                return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate()
 
         # Keep last
         if "shuffle" in kwargs:
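A hedged usage sketch of the two new options (the dataset name is hypothetical, and the exact shape of `expected_access` is not visible in this diff):

```python
from anemoi.datasets import open_dataset

# Declare extra dates as missing on top of what the dataset reports.
ds = open_dataset("some-dataset", missing_dates=["2020-01-01T06:00:00"])

# Skip over missing dates, given how a sample reads consecutive dates
# (`expected_access`); the slice form below is assumed, not confirmed here.
ds = open_dataset("some-dataset", skip_missing_dates=True, expected_access=slice(0, 2))
```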
@@ -90,15 +136,14 @@ class Dataset:
         shuffle = kwargs.pop("shuffle")
 
         if shuffle:
-            return Subset(self, self._shuffle_indices(), dict(shuffle=True))._subset(**kwargs)
+            return Subset(self, self._shuffle_indices(), dict(shuffle=True))._subset(**kwargs).mutate()
 
         raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs))
 
     def _frequency_to_indices(self, frequency):
-        from .misc import _frequency_to_hours
 
-        requested_frequency =
-        dataset_frequency =
+        requested_frequency = frequency_to_seconds(frequency)
+        dataset_frequency = frequency_to_seconds(self.frequency)
         assert requested_frequency % dataset_frequency == 0
         # Question: where do we start? first date, or first date that is a multiple of the frequency?
         step = requested_frequency // dataset_frequency
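`_frequency_to_indices` now converts both frequencies to seconds via `anemoi.utils.dates.frequency_to_seconds` and subsamples by the integer ratio. A standalone sketch with plain integers standing in for the converted values:

```python
requested_frequency = 12 * 3600  # e.g. frequency_to_seconds("12h")
dataset_frequency = 6 * 3600     # e.g. frequency_to_seconds(self.frequency)

assert requested_frequency % dataset_frequency == 0
step = requested_frequency // dataset_frequency

n_dates = 8  # length of the dataset
print(list(range(0, n_dates, step)))  # [0, 2, 4, 6]: every second date
```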
@@ -171,37 +216,71 @@ class Dataset:
         import anemoi
 
         def tidy(v):
-            if isinstance(v, (list, tuple)):
+            if isinstance(v, (list, tuple, set)):
                 return [tidy(i) for i in v]
             if isinstance(v, dict):
                 return {k: tidy(v) for k, v in v.items()}
             if isinstance(v, str) and v.startswith("/"):
                 return os.path.basename(v)
+            if isinstance(v, datetime.datetime):
+                return v.isoformat()
+            if isinstance(v, datetime.date):
+                return v.isoformat()
+            if isinstance(v, datetime.timedelta):
+                return frequency_to_string(v)
+
+            if isinstance(v, Dataset):
+                # That can happen in the `arguments`
+                # if a dataset is passed as an argument
+                return repr(v)
+
+            if isinstance(v, slice):
+                return (v.start, v.stop, v.step)
+
             return v
 
-
-
-
-
-
-
-
-
-
-
-        )
+        md = dict(
+            version=anemoi.datasets.__version__,
+            arguments=self.arguments,
+            **self.dataset_metadata(),
+        )
+
+        try:
+            return json.loads(json.dumps(tidy(md)))
+        except Exception:
+            LOG.exception("Failed to serialize metadata")
+            pprint.pprint(md)
+
+            raise
+
+    @property
+    def start_date(self):
+        return self.dates[0]
+
+    @property
+    def end_date(self):
+        return self.dates[-1]
+
+    def dataset_metadata(self):
+        return dict(
+            specific=self.metadata_specific(),
+            frequency=self.frequency,
+            variables=self.variables,
+            shape=self.shape,
+            start_date=self.start_date.astype(str),
+            end_date=self.end_date.astype(str),
         )
 
     def metadata_specific(self, **kwargs):
         action = self.__class__.__name__.lower()
-        assert isinstance(self.frequency,
+        # assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action)
         return dict(
             action=action,
             variables=self.variables,
             shape=self.shape,
-            frequency=self.frequency,
-            start_date=self.
-            end_date=self.
+            frequency=frequency_to_string(frequency_to_timedelta(self.frequency)),
+            start_date=self.start_date.astype(str),
+            end_date=self.end_date.astype(str),
             **kwargs,
         )
 
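The point of `tidy()` is to make the metadata JSON-safe before the `json.dumps`/`json.loads` round-trip. A trimmed, standalone version (the timedelta branch, which uses `frequency_to_string`, is omitted here):

```python
import datetime
import json


def tidy(v):
    if isinstance(v, (list, tuple, set)):
        return [tidy(i) for i in v]
    if isinstance(v, dict):
        return {k: tidy(v) for k, v in v.items()}
    if isinstance(v, (datetime.datetime, datetime.date)):
        return v.isoformat()
    if isinstance(v, slice):
        return (v.start, v.stop, v.step)
    return v


md = {"start": datetime.date(2020, 1, 1), "vars": {"msl", "2t"}, "window": slice(0, 2)}
print(json.dumps(tidy(md)))
# e.g. {"start": "2020-01-01", "vars": ["msl", "2t"], "window": [0, 2, null]}
```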
@@ -220,10 +299,6 @@ class Dataset:
             if n.startswith("_") and not n.startswith("__"):
                 warnings.warn(f"Private method {n} is overriden in {ds.__class__.__name__}")
 
-        # for n in ('metadata_specific', 'tree', 'source'):
-        #     if n not in overriden:
-        #         warnings.warn(f"Method {n} is not overriden in {ds.__class__.__name__}")
-
     def _repr_html_(self):
         return self.tree().html()
 