anemoi-datasets 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/compare.py +59 -0
- anemoi/datasets/commands/create.py +84 -3
- anemoi/datasets/commands/inspect.py +3 -3
- anemoi/datasets/create/__init__.py +44 -17
- anemoi/datasets/create/check.py +6 -5
- anemoi/datasets/create/chunks.py +1 -1
- anemoi/datasets/create/config.py +5 -26
- anemoi/datasets/create/functions/filters/rename.py +9 -1
- anemoi/datasets/create/functions/filters/rotate_winds.py +10 -1
- anemoi/datasets/create/functions/sources/__init__.py +39 -0
- anemoi/datasets/create/functions/sources/accumulations.py +11 -41
- anemoi/datasets/create/functions/sources/constants.py +3 -0
- anemoi/datasets/create/functions/sources/grib.py +4 -0
- anemoi/datasets/create/functions/sources/hindcasts.py +32 -377
- anemoi/datasets/create/functions/sources/mars.py +53 -22
- anemoi/datasets/create/functions/sources/netcdf.py +2 -60
- anemoi/datasets/create/functions/sources/opendap.py +3 -2
- anemoi/datasets/create/functions/sources/xarray/__init__.py +73 -0
- anemoi/datasets/create/functions/sources/xarray/coordinates.py +234 -0
- anemoi/datasets/create/functions/sources/xarray/field.py +109 -0
- anemoi/datasets/create/functions/sources/xarray/fieldlist.py +171 -0
- anemoi/datasets/create/functions/sources/xarray/flavour.py +330 -0
- anemoi/datasets/create/functions/sources/xarray/grid.py +46 -0
- anemoi/datasets/create/functions/sources/xarray/metadata.py +161 -0
- anemoi/datasets/create/functions/sources/xarray/time.py +98 -0
- anemoi/datasets/create/functions/sources/xarray/variable.py +198 -0
- anemoi/datasets/create/functions/sources/xarray_kerchunk.py +42 -0
- anemoi/datasets/create/functions/sources/xarray_zarr.py +15 -0
- anemoi/datasets/create/functions/sources/zenodo.py +40 -0
- anemoi/datasets/create/input.py +290 -172
- anemoi/datasets/create/loaders.py +120 -71
- anemoi/datasets/create/patch.py +17 -14
- anemoi/datasets/create/persistent.py +1 -1
- anemoi/datasets/create/size.py +4 -5
- anemoi/datasets/create/statistics/__init__.py +49 -16
- anemoi/datasets/create/template.py +11 -61
- anemoi/datasets/create/trace.py +91 -0
- anemoi/datasets/create/utils.py +0 -48
- anemoi/datasets/create/zarr.py +24 -10
- anemoi/datasets/data/misc.py +9 -37
- anemoi/datasets/data/stores.py +29 -14
- anemoi/datasets/dates/__init__.py +7 -1
- anemoi/datasets/dates/groups.py +3 -0
- {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/METADATA +18 -3
- anemoi_datasets-0.4.2.dist-info/RECORD +86 -0
- {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/WHEEL +1 -1
- anemoi_datasets-0.4.0.dist-info/RECORD +0 -73
- {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/statistics/__init__.py
CHANGED

@@ -89,20 +89,23 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
             continue
         print("---")
         print(f"❗ Negative variance for {name=}, variance={y}")
-        print(f"
+        print(f"    min={minimum[i]} max={maximum[i]} mean={mean[i]} count={count[i]} sums={sums[i]} squares={squares[i]}")
         print(f"  -> sums: min={np.min(sums[i])}, max={np.max(sums[i])}, argmin={np.argmin(sums[i])}")
         print(f"  -> squares: min={np.min(squares[i])}, max={np.max(squares[i])}, argmin={np.argmin(squares[i])}")
         print(f"  -> count: min={np.min(count[i])}, max={np.max(count[i])}, argmin={np.argmin(count[i])}")
+        print(
+            f"  squares / count - mean * mean = {squares[i] / count[i]} - {mean[i] * mean[i]} = {squares[i] / count[i] - mean[i] * mean[i]}"
+        )

     raise ValueError("Negative variance")


-def compute_statistics(array, check_variables_names=None,
+def compute_statistics(array, check_variables_names=None, allow_nans=False):
     """Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary."""

     nvars = array.shape[1]

-    LOG.
+    LOG.debug(f"Stats {nvars}, {array.shape}, {check_variables_names}")
     if check_variables_names:
         assert nvars == len(check_variables_names), (nvars, check_variables_names)
     stats_shape = (array.shape[0], nvars)

@@ -118,7 +121,7 @@ def compute_statistics(array, check_variables_names=None, allow_nans=False):
         values = chunk.reshape((nvars, -1))

         for j, name in enumerate(check_variables_names):
-            check_data_values(values[j, :], name=name,
+            check_data_values(values[j, :], name=name, allow_nans=allow_nans)
             if np.isnan(values[j, :]).all():
                 # LOG.warning(f"All NaN values for {name} ({j}) for date {i}")
                 raise ValueError(f"All NaN values for {name} ({j}) for date {i}")
@@ -179,12 +182,12 @@ class TmpStatistics:
             pickle.dump((key, dates, data), f)
         shutil.move(tmp_path, path)

-        LOG.
+        LOG.debug(f"Written statistics data for {len(dates)} dates in {path} ({dates})")

     def _gather_data(self):
         # use glob to read all pickles
         files = glob.glob(self.dirname + "/*.npz")
-        LOG.
+        LOG.debug(f"Reading stats data, found {len(files)} files in {self.dirname}")
         assert len(files) > 0, f"No files found in {self.dirname}"
         for f in files:
             with open(f, "rb") as f:
@@ -211,17 +214,17 @@ def normalise_dates(dates):
 class StatAggregator:
     NAMES = ["minimum", "maximum", "sums", "squares", "count", "has_nans"]

-    def __init__(self, owner, dates, variables_names,
+    def __init__(self, owner, dates, variables_names, allow_nans):
         dates = sorted(dates)
         dates = to_datetimes(dates)
         assert dates, "No dates selected"
         self.owner = owner
         self.dates = dates
         self.variables_names = variables_names
-        self.
+        self.allow_nans = allow_nans

         self.shape = (len(self.dates), len(self.variables_names))
-        LOG.
+        LOG.debug(f"Aggregating statistics on shape={self.shape}. Variables : {self.variables_names}")

         self.minimum = np.full(self.shape, np.nan, dtype=np.float64)
         self.maximum = np.full(self.shape, np.nan, dtype=np.float64)
@@ -284,7 +287,7 @@ class StatAggregator:
             assert d in found, f"Statistics for date {d} not precomputed."
         assert len(self.dates) == len(found), "Not all dates found in precomputed statistics"
         assert len(self.dates) == offset, "Not all dates found in precomputed statistics."
-        LOG.
+        LOG.debug(f"Statistics for {len(found)} dates found.")

     def aggregate(self):
         minimum = np.nanmin(self.minimum, axis=0)
@@ -298,13 +301,43 @@ class StatAggregator:
         assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape

         x = squares / count - mean * mean
-        # remove negative variance due to numerical errors
-        # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0
-        check_variance(x, self.variables_names, minimum, maximum, mean, count, sums, squares)
-        stdev = np.sqrt(x)

-
-
+        # def fix_variance(x, name, minimum, maximum, mean, count, sums, squares):
+        #     assert x.shape == minimum.shape == maximum.shape == mean.shape == count.shape == sums.shape == squares.shape
+        #     assert x.shape == (1,)
+        #     x, minimum, maximum, mean, count, sums, squares = x[0], minimum[0], maximum[0], mean[0], count[0], sums[0], squares[0]
+        #     if x >= 0:
+        #         return x
+        #
+        #     order = np.sqrt((squares / count + mean * mean) / 2)
+        #     range = maximum - minimum
+        #     LOG.warning(f"Negative variance for {name=}, variance={x}")
+        #     LOG.warning(f"square / count - mean * mean = {squares / count} - {mean * mean} = {squares / count - mean * mean}")
+        #     LOG.warning(f"Variable order of magnitude is {order}.")
+        #     LOG.warning(f"Range is {range} ({maximum=} - {minimum=}).")
+        #     LOG.warning(f"Count is {count}.")
+        #     if abs(x) < order * 1e-6 and abs(x) < range * 1e-6:
+        #         LOG.warning(f"Variance is negative but very small, setting to 0.")
+        #         return x * 0
+        #     return x
+
+        for i, name in enumerate(self.variables_names):
+            # remove negative variance due to numerical errors
+            # Not needed for now, fix_variance is disabled
+            # x[i] = fix_variance(x[i:i+1], name, minimum[i:i+1], maximum[i:i+1], mean[i:i+1], count[i:i+1], sums[i:i+1], squares[i:i+1])
+            check_variance(
+                x[i : i + 1],
+                [name],
+                minimum[i : i + 1],
+                maximum[i : i + 1],
+                mean[i : i + 1],
+                count[i : i + 1],
+                sums[i : i + 1],
+                squares[i : i + 1],
+            )
+            check_data_values(np.array([mean[i]]), name=name, allow_nans=False)
+
+        stdev = np.sqrt(x)

         return Summary(
             minimum=minimum,
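The hunk above replaces the single whole-array check_variance call with a per-variable loop, so a failure now names exactly which variable produced the negative variance. A standalone sketch (illustrative, not package code) of why the one-pass formula squares / count - mean * mean can come out negative in the first place:

    import numpy as np

    def variance_from_moments(sums, squares, count):
        # One-pass variance: E[x^2] - E[x]^2. When the variance is tiny
        # relative to the mean, the two terms are nearly equal and float64
        # cancellation can push the result slightly below zero.
        mean = sums / count
        return squares / count - mean * mean

    rng = np.random.default_rng(0)
    x = 1e8 + rng.normal(0.0, 1e-4, 10_000)  # huge mean, tiny spread
    print(variance_from_moments(x.sum(), (x * x).sum(), x.size))
    # true variance is ~1e-8; the computed value can be negative
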
anemoi/datasets/create/input.py
CHANGED

@@ -8,72 +8,16 @@
 #

 import logging
-import os
 import re
 import textwrap
 from functools import wraps

-
-
-TRACE_INDENT = 0
-
-
-def step(action_path):
-    return f"[{'.'.join(action_path)}]"
+from anemoi.utils.humanize import plural

+from .trace import step
+from .trace import trace

-def trace(emoji, *args):
-    if os.environ.get("ANEMOI_DATASET_TRACE_CREATE") is None:
-        return
-    print(emoji, " " * TRACE_INDENT, *args)
-
-
-def trace_datasource(method):
-    @wraps(method)
-    def wrapper(self, *args, **kwargs):
-        global TRACE_INDENT
-        trace(
-            "🌍",
-            "=>",
-            step(self.action_path),
-            self._trace_datasource(*args, **kwargs),
-        )
-        TRACE_INDENT += 1
-        result = method(self, *args, **kwargs)
-        TRACE_INDENT -= 1
-        trace(
-            "🍎",
-            "<=",
-            step(self.action_path),
-            textwrap.shorten(repr(result), 256),
-        )
-        return result
-
-    return wrapper
-
-
-def trace_select(method):
-    @wraps(method)
-    def wrapper(self, *args, **kwargs):
-        global TRACE_INDENT
-        trace(
-            "👓",
-            "=>",
-            ".".join(self.action_path),
-            self._trace_select(*args, **kwargs),
-        )
-        TRACE_INDENT += 1
-        result = method(self, *args, **kwargs)
-        TRACE_INDENT -= 1
-        trace(
-            "🍍",
-            "<=",
-            ".".join(self.action_path),
-            textwrap.shorten(repr(result), 256),
-        )
-        return result
-
-    return wrapper
+LOG = logging.getLogger(__name__)


 def notify_result(method):
@@ -99,7 +43,13 @@ class Context:
         self.used_references.add(key)

     def notify_result(self, key, result):
-        trace(
+        trace(
+            "🎯",
+            step(key),
+            "notify result",
+            textwrap.shorten(repr(result).replace(",", ", "), width=40),
+            plural(len(result), "field"),
+        )
         assert isinstance(key, (list, tuple)), key
         key = tuple(key)
         if key in self.used_references:
anemoi/datasets/create/trace.py
ADDED

@@ -0,0 +1,91 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+import logging
+import textwrap
+import threading
+from functools import wraps
+
+LOG = logging.getLogger(__name__)
+
+
+thread_local = threading.local()
+TRACE = 0
+
+
+def enable_trace(on_off):
+    global TRACE
+    TRACE = on_off
+
+
+def step(action_path):
+    return f"[{'.'.join(action_path)}]"
+
+
+def trace(emoji, *args):
+
+    if not TRACE:
+        return
+
+    if not hasattr(thread_local, "TRACE_INDENT"):
+        thread_local.TRACE_INDENT = 0
+
+    print(emoji, " " * thread_local.TRACE_INDENT, *args)
+
+
+def trace_datasource(method):
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+
+        if not hasattr(thread_local, "TRACE_INDENT"):
+            thread_local.TRACE_INDENT = 0
+
+        trace(
+            "🌍",
+            "=>",
+            step(self.action_path),
+            self._trace_datasource(*args, **kwargs),
+        )
+        thread_local.TRACE_INDENT += 1
+        result = method(self, *args, **kwargs)
+        thread_local.TRACE_INDENT -= 1
+        trace(
+            "🍎",
+            "<=",
+            step(self.action_path),
+            textwrap.shorten(repr(result), 256),
+        )
+        return result
+
+    return wrapper
+
+
+def trace_select(method):
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        if not hasattr(thread_local, "TRACE_INDENT"):
+            thread_local.TRACE_INDENT = 0
+        trace(
+            "👓",
+            "=>",
+            ".".join(self.action_path),
+            self._trace_select(*args, **kwargs),
+        )
+        thread_local.TRACE_INDENT += 1
+        result = method(self, *args, **kwargs)
+        thread_local.TRACE_INDENT -= 1
+        trace(
+            "🍍",
+            "<=",
+            ".".join(self.action_path),
+            textwrap.shorten(repr(result), 256),
+        )
+        return result
+
+    return wrapper
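The new trace.py keeps one indent counter per thread (thread_local.TRACE_INDENT) and gates all output behind enable_trace(), replacing the old module-global TRACE_INDENT and the ANEMOI_DATASET_TRACE_CREATE environment check removed from input.py above. A hypothetical usage sketch; the Action class and its methods are stand-ins, not real package classes:

    from anemoi.datasets.create.trace import enable_trace, trace_datasource

    class Action:
        # The decorator expects an action_path and a _trace_datasource()
        # helper on the decorated object.
        action_path = ["input", "mars"]

        def _trace_datasource(self, *args, **kwargs):
            return "datasource()"

        @trace_datasource
        def datasource(self):
            return ["field_1", "field_2"]

    enable_trace(1)
    Action().datasource()
    # prints an indented call tree, one indent level per nesting depth:
    # 🌍  => [input.mars] datasource()
    # 🍎  <= [input.mars] ['field_1', 'field_2']
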
anemoi/datasets/create/utils.py
CHANGED

@@ -7,15 +7,11 @@
 # nor does it submit to any jurisdiction.
 #

-import json
 import os
 from contextlib import contextmanager

 import numpy as np
-import yaml
 from earthkit.data import settings
-from earthkit.data.utils.humanize import seconds  # noqa: F401
-from tqdm.auto import tqdm


 def cache_context(dirname):

@@ -31,26 +27,6 @@ def cache_context(dirname):
     return settings.temporary({"cache-policy": "user", "user-cache-directory": dirname})


-def bytes(n):
-    """>>> bytes(4096)
-    '4 KiB'
-    >>> bytes(4000)
-    '3.9 KiB'
-    """
-    if n < 0:
-        sign = "-"
-        n -= 0
-    else:
-        sign = ""
-
-    u = ["", " KiB", " MiB", " GiB", " TiB", " PiB", " EiB", " ZiB", " YiB"]
-    i = 0
-    while n >= 1024:
-        n /= 1024.0
-        i += 1
-    return "%s%g%s" % (sign, int(n * 10 + 0.5) / 10.0, u[i])
-
-
 def to_datetime_list(*args, **kwargs):
     from earthkit.data.utils.dates import to_datetime_list as to_datetime_list_

@@ -63,15 +39,6 @@ def to_datetime(*args, **kwargs):
     return to_datetime_(*args, **kwargs)


-def load_json_or_yaml(path):
-    with open(path, "r") as f:
-        if path.endswith(".json"):
-            return json.load(f)
-        if path.endswith(".yaml") or path.endswith(".yml"):
-            return yaml.safe_load(f)
-    raise ValueError(f"Cannot read file {path}. Need json or yaml with appropriate extension.")
-
-
 def make_list_int(value):
     if isinstance(value, str):
         if "/" not in value:

@@ -118,18 +85,3 @@ def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"):
         assert d1 == d2, (i, d1, d2)

     return dates_
-
-
-def progress_bar(*, iterable=None, total=None, initial=0, desc=None):
-    return tqdm(
-        iterable=iterable,
-        total=total,
-        initial=initial,
-        unit_scale=True,
-        unit_divisor=1024,
-        unit="B",
-        disable=False,
-        leave=False,
-        desc=desc,
-        # dynamic_ncols=True, # make this the default?
-    )
anemoi/datasets/create/zarr.py
CHANGED

@@ -24,8 +24,12 @@ def add_zarr_dataset(
     shape=None,
     array=None,
     overwrite=True,
+    dimensions=None,
     **kwargs,
 ):
+    assert dimensions is not None, "Please pass dimensions to add_zarr_dataset."
+    assert isinstance(dimensions, (tuple, list))
+
     if dtype is None:
         assert array is not None, (name, shape, array, dtype, zarr_root)
         dtype = array.dtype

@@ -44,6 +48,7 @@ def add_zarr_dataset(
             **kwargs,
         )
        a[...] = array
+        a.attrs["_ARRAY_DIMENSIONS"] = dimensions
         return a

     if "fill_value" not in kwargs:

@@ -69,6 +74,7 @@ def add_zarr_dataset(
         overwrite=overwrite,
         **kwargs,
     )
+    a.attrs["_ARRAY_DIMENSIONS"] = dimensions
    return a

@@ -79,22 +85,27 @@ class ZarrBuiltRegistry:
     flags = None
     z = None

-    def __init__(self, path, synchronizer_path=None):
+    def __init__(self, path, synchronizer_path=None, use_threads=False):
         import zarr

         assert isinstance(path, str), path
         self.zarr_path = path

-        if
-
-
-
+        if use_threads:
+            self.synchronizer = zarr.ThreadSynchronizer()
+            self.synchronizer_path = None
+        else:
+            if synchronizer_path is None:
+                synchronizer_path = self.zarr_path + ".sync"
+            self.synchronizer_path = synchronizer_path
+            self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path)

     def clean(self):
-
-
-
+        if self.synchronizer_path is not None:
+            try:
+                shutil.rmtree(self.synchronizer_path)
+            except FileNotFoundError:
+                pass

     def _open_write(self):
         import zarr

@@ -112,7 +123,7 @@ class ZarrBuiltRegistry:
     def new_dataset(self, *args, **kwargs):
         z = self._open_write()
         zarr_root = z["_build"]
-        add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, **kwargs)
+        add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs)

     def add_to_history(self, action, **kwargs):
         new = dict(

@@ -143,6 +154,9 @@ class ZarrBuiltRegistry:
         z.attrs["latest_write_timestamp"] = datetime.datetime.utcnow().isoformat()
         z["_build"][self.name_flags][i] = value

+    def ready(self):
+        return all(self.get_flags())
+
     def create(self, lengths, overwrite=False):
         self.new_dataset(name=self.name_lengths, array=np.array(lengths, dtype="i4"))
         self.new_dataset(name=self.name_flags, array=np.array([False] * len(lengths), dtype=bool))
anemoi/datasets/data/misc.py
CHANGED

@@ -8,63 +8,35 @@
 import calendar
 import datetime
 import logging
-import os
 import re
 from pathlib import PurePath

 import numpy as np
 import zarr
+from anemoi.utils.config import load_config as load_settings

 from .dataset import Dataset

 LOG = logging.getLogger(__name__)

-CONFIG = None

-try:
-    import tomllib
-except ImportError:
-    import tomli as tomllib
+def load_config():
+    return load_settings(defaults={"datasets": {"named": {}, "path": []}})


 def add_named_dataset(name, path, **kwargs):
-    load_config()
-    if name
+    config = load_config()
+    if name in config["datasets"]["named"]:
         raise ValueError(f"Dataset {name} already exists")

-
+    config["datasets"]["named"][name] = path


 def add_dataset_path(path):
-    load_config()
-
-    if path not in CONFIG["datasets"]["path"]:
-        CONFIG["datasets"]["path"].append(path)
-
-    # save_config()
-
-
-def load_config():
-    global CONFIG
-    if CONFIG is not None:
-        return CONFIG
-
-    conf = os.path.expanduser("~/.config/anemoi/settings.toml")
-    if not os.path.exists(conf):
-        conf = os.path.expanduser("~/.anemoi.toml")
-
-    if os.path.exists(conf):
-
-        with open(conf, "rb") as f:
-            CONFIG = tomllib.load(f)
-    else:
-        CONFIG = {}
-
-    CONFIG.setdefault("datasets", {})
-    CONFIG["datasets"].setdefault("path", [])
-    CONFIG["datasets"].setdefault("named", {})
+    config = load_config()

-
+    if path not in config["datasets"]["path"]:
+        config["datasets"]["path"].append(path)


 def _frequency_to_hours(frequency):
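The hand-rolled TOML loading is gone; configuration now comes from anemoi-utils, which merges user settings over the defaults passed in. A short sketch of the call as used above (the appended path is hypothetical):

    from anemoi.utils.config import load_config as load_settings

    config = load_settings(defaults={"datasets": {"named": {}, "path": []}})
    # The defaults guarantee both keys exist even without a user settings file:
    config["datasets"]["path"].append("/data/my-datasets")
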
anemoi/datasets/data/stores.py
CHANGED

@@ -9,6 +9,7 @@ import logging
 import os
 import warnings
 from functools import cached_property
+from urllib.parse import urlparse

 import numpy as np
 import zarr

@@ -40,7 +41,9 @@ class ReadOnlyStore(zarr.storage.BaseStore):


 class HTTPStore(ReadOnlyStore):
-    """We write our own HTTPStore because the one used by zarr (
+    """We write our own HTTPStore because the one used by zarr (s3fs)
+    does not play well with fork() and multiprocessing.
+    """

     def __init__(self, url):
         self.url = url

@@ -58,17 +61,16 @@ class HTTPStore(ReadOnlyStore):


 class S3Store(ReadOnlyStore):
-    """We write our own S3Store because the one used by zarr (
-    does not play well with fork()
-
+    """We write our own S3Store because the one used by zarr (s3fs)
+    does not play well with fork(). We also get to control the s3 client
+    options using the anemoi configs.
     """

-    def __init__(self, url):
+    def __init__(self, url, region=None):
         from anemoi.utils.s3 import s3_client

         _, _, self.bucket, self.key = url.split("/", 3)
-
-        self.s3 = s3_client(self.bucket)
+        self.s3 = s3_client(self.bucket, region=region)

     def __getitem__(self, key):
         try:

@@ -101,15 +103,27 @@ class DebugStore(ReadOnlyStore):
         return key in self.store


-def
-
-
+def name_to_zarr_store(path_or_url):
+    store = path_or_url
+
+    if store.startswith("s3://"):
+        store = S3Store(store)

-
+    elif store.startswith("http://") or store.startswith("https://"):
+        parsed = urlparse(store)
+        bits = parsed.netloc.split(".")
+        if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"):
+            s3_url = f"s3://{bits[0]}{parsed.path}"
+            store = S3Store(s3_url, region=bits[2])
+        else:
             store = HTTPStore(store)

-
-
+    return store
+
+
+def open_zarr(path, dont_fail=False, cache=None):
+    try:
+        store = name_to_zarr_store(path)

     if DEBUG_ZARR_LOADING:
         if isinstance(store, str):

@@ -117,7 +131,8 @@ def open_zarr(path, dont_fail=False, cache=None):

         if not os.path.isdir(store):
             raise NotImplementedError(
-                "DEBUG_ZARR_LOADING is only implemented for DirectoryStore.
+                "DEBUG_ZARR_LOADING is only implemented for DirectoryStore. "
+                "Please disable it for other backends."
             )
         store = zarr.storage.DirectoryStore(store)
         store = DebugStore(store)
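The new name_to_zarr_store recognises virtual-hosted-style S3 URLs of the form https://<bucket>.s3.<region>.amazonaws.com/<key> and rewrites them to s3:// so the dedicated S3Store (with the correct region) is used instead of plain HTTP. A standalone sketch of just that URL classification, without the store classes; the bucket and key are illustrative:

    from urllib.parse import urlparse

    def classify(url):
        # Mirrors the detection logic added above.
        parsed = urlparse(url)
        bits = parsed.netloc.split(".")
        if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"):
            return ("s3", f"s3://{bits[0]}{parsed.path}", bits[2])
        return ("http", url, None)

    print(classify("https://my-bucket.s3.eu-west-1.amazonaws.com/datasets/x.zarr"))
    # -> ('s3', 's3://my-bucket/datasets/x.zarr', 'eu-west-1')
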
anemoi/datasets/dates/__init__.py
CHANGED

@@ -96,7 +96,7 @@ class ValuesDates(Dates):


 class StartEndDates(Dates):
-    def __init__(self, start, end, frequency=1, **kwargs):
+    def __init__(self, start, end, frequency=1, months=None, **kwargs):
         frequency = frequency_to_hours(frequency)

         def _(x):

@@ -128,6 +128,12 @@ class StartEndDates(Dates):
         date = start
         self.values = []
         while date <= end:
+
+            if months is not None:
+                if date.month not in months:
+                    date += increment
+                    continue
+
             self.values.append(date)
             date += increment
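StartEndDates now accepts a months filter that skips dates outside the listed months while still stepping by the regular increment. A standalone sketch of the same loop (standard library only; the 6-hourly frequency and JJA selection are illustrative):

    import datetime

    start, end = datetime.datetime(2020, 1, 1), datetime.datetime(2020, 12, 31)
    increment = datetime.timedelta(hours=6)
    months = {6, 7, 8}  # keep June, July, August only

    values, date = [], start
    while date <= end:
        if date.month not in months:
            date += increment
            continue
        values.append(date)
        date += increment

    print(values[0], values[-1], len(values))
    # 2020-06-01 00:00:00 2020-08-31 18:00:00 368
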