anemoi-datasets 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/cleanup.py +44 -0
- anemoi/datasets/commands/create.py +52 -21
- anemoi/datasets/commands/finalise-additions.py +45 -0
- anemoi/datasets/commands/finalise.py +39 -0
- anemoi/datasets/commands/init-additions.py +45 -0
- anemoi/datasets/commands/init.py +67 -0
- anemoi/datasets/commands/inspect.py +1 -1
- anemoi/datasets/commands/load-additions.py +47 -0
- anemoi/datasets/commands/load.py +47 -0
- anemoi/datasets/commands/patch.py +39 -0
- anemoi/datasets/create/__init__.py +959 -146
- anemoi/datasets/create/check.py +5 -3
- anemoi/datasets/create/config.py +54 -2
- anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +57 -0
- anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +57 -0
- anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +54 -0
- anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +59 -0
- anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +115 -0
- anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +390 -0
- anemoi/datasets/create/functions/filters/speeddir_to_uv.py +77 -0
- anemoi/datasets/create/functions/filters/uv_to_speeddir.py +55 -0
- anemoi/datasets/create/functions/sources/grib.py +86 -1
- anemoi/datasets/create/functions/sources/hindcasts.py +14 -73
- anemoi/datasets/create/functions/sources/mars.py +9 -3
- anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
- anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
- anemoi/datasets/create/functions/sources/xarray/field.py +8 -2
- anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
- anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
- anemoi/datasets/create/functions/sources/xarray/metadata.py +40 -40
- anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
- anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
- anemoi/datasets/create/input.py +62 -39
- anemoi/datasets/create/persistent.py +1 -1
- anemoi/datasets/create/statistics/__init__.py +39 -23
- anemoi/datasets/create/utils.py +6 -2
- anemoi/datasets/data/__init__.py +1 -0
- anemoi/datasets/data/concat.py +46 -2
- anemoi/datasets/data/dataset.py +119 -34
- anemoi/datasets/data/debug.py +5 -1
- anemoi/datasets/data/forwards.py +17 -8
- anemoi/datasets/data/grids.py +17 -3
- anemoi/datasets/data/interpolate.py +133 -0
- anemoi/datasets/data/masked.py +2 -2
- anemoi/datasets/data/misc.py +56 -66
- anemoi/datasets/data/missing.py +240 -0
- anemoi/datasets/data/rescale.py +147 -0
- anemoi/datasets/data/select.py +7 -1
- anemoi/datasets/data/stores.py +23 -10
- anemoi/datasets/data/subset.py +47 -5
- anemoi/datasets/data/unchecked.py +20 -22
- anemoi/datasets/data/xy.py +125 -0
- anemoi/datasets/dates/__init__.py +124 -95
- anemoi/datasets/dates/groups.py +85 -20
- anemoi/datasets/grids.py +66 -48
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/METADATA +8 -17
- anemoi_datasets-0.5.0.dist-info/RECORD +105 -0
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/loaders.py +0 -936
- anemoi_datasets-0.4.4.dist-info/RECORD +0 -86
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/top_level.txt +0 -0
|
@@ -7,196 +7,1009 @@
|
|
|
7
7
|
# nor does it submit to any jurisdiction.
|
|
8
8
|
#
|
|
9
9
|
|
|
10
|
+
import datetime
|
|
11
|
+
import json
|
|
10
12
|
import logging
|
|
11
13
|
import os
|
|
14
|
+
import time
|
|
15
|
+
import uuid
|
|
16
|
+
import warnings
|
|
17
|
+
from functools import cached_property
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import tqdm
|
|
21
|
+
from anemoi.utils.config import DotDict as DotDict
|
|
22
|
+
from anemoi.utils.dates import as_datetime
|
|
23
|
+
from anemoi.utils.dates import frequency_to_string
|
|
24
|
+
from anemoi.utils.dates import frequency_to_timedelta
|
|
25
|
+
from anemoi.utils.humanize import compress_dates
|
|
26
|
+
from anemoi.utils.humanize import seconds_to_human
|
|
27
|
+
|
|
28
|
+
from anemoi.datasets import MissingDateError
|
|
29
|
+
from anemoi.datasets import open_dataset
|
|
30
|
+
from anemoi.datasets.create.persistent import build_storage
|
|
31
|
+
from anemoi.datasets.data.misc import as_first_date
|
|
32
|
+
from anemoi.datasets.data.misc import as_last_date
|
|
33
|
+
from anemoi.datasets.dates.groups import Groups
|
|
34
|
+
|
|
35
|
+
from .check import DatasetName
|
|
36
|
+
from .check import check_data_values
|
|
37
|
+
from .chunks import ChunkFilter
|
|
38
|
+
from .config import build_output
|
|
39
|
+
from .config import loader_config
|
|
40
|
+
from .input import build_input
|
|
41
|
+
from .statistics import Summary
|
|
42
|
+
from .statistics import TmpStatistics
|
|
43
|
+
from .statistics import check_variance
|
|
44
|
+
from .statistics import compute_statistics
|
|
45
|
+
from .statistics import default_statistics_dates
|
|
46
|
+
from .statistics import fix_variance
|
|
47
|
+
from .utils import normalize_and_check_dates
|
|
48
|
+
from .writer import ViewCacheArray
|
|
12
49
|
|
|
13
50
|
# Module-level logger for the dataset creation steps.
LOG = logging.getLogger(__name__)

# Version string written into each new dataset's metadata under the "version" key.
VERSION = "0.20"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def json_tidy(o):
    """``default=`` hook for :func:`json.dumps` handling date/time objects.

    - ``datetime.date`` and ``datetime.datetime`` (a subclass of ``date``) are
      converted to ISO 8601 strings.
    - ``datetime.timedelta`` is converted with ``frequency_to_string``.

    Raises:
        TypeError: for any other object, as required by the ``json`` protocol.
    """
    # A single ``date`` check covers ``datetime`` too; the previous code tested
    # ``datetime.datetime`` twice, leaving plain dates unserialisable.
    if isinstance(o, datetime.date):
        return o.isoformat()

    if isinstance(o, datetime.timedelta):
        return frequency_to_string(o)

    raise TypeError(repr(o) + " is not JSON serializable")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def build_statistics_dates(dates, start, end):
    """Compute the start and end dates for the statistics.

    The dates are based on:

    - the start and end dates given in the config (``start``/``end``), or
    - the default statistics dates convention when they are ``None``.

    They are then adapted to the actual dates in the dataset.

    Returns:
        A tuple ``(start, end)`` of ISO 8601 strings.
    """

    def _to_datetime(value):
        # NumPy returns a plain int (not a datetime) when converting a
        # datetime64 with sub-microsecond precision (e.g. "ns") to
        # datetime.datetime, which would break ``.isoformat()`` below.
        unit, _ = np.datetime_data(value.dtype)
        if unit in ("ns", "ps", "fs", "as"):
            value = value.astype("datetime64[us]")
        return value.astype(datetime.datetime)

    # if not specified, use the default statistics dates
    default_start, default_end = default_statistics_dates(dates)
    if start is None:
        start = default_start
    if end is None:
        end = default_end

    # in any case, adapt to the actual dates in the dataset
    start = as_first_date(start, dates)
    end = as_last_date(end, dates)

    # and convert to datetime to isoformat
    start = _to_datetime(start)
    end = _to_datetime(end)
    return (start.isoformat(), end.isoformat())
|
|
91
|
+
|
|
15
92
|
|
|
16
93
|
def _ignore(*args, **kwargs):
|
|
17
94
|
pass
|
|
18
95
|
|
|
19
96
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
97
|
+
def _path_readable(path):
    """Tell whether ``zarr`` can open ``path`` read-only.

    Returns False only when zarr reports ``PathNotFoundError``; any other
    error raised by ``zarr.open`` propagates to the caller.
    """
    import zarr

    try:
        zarr.open(path, "r")
    except zarr.errors.PathNotFoundError:
        return False
    return True
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class Dataset:
    """Wrapper around the Zarr store at ``path`` used by the creation steps.

    Only paths with a ``.zarr`` extension are accepted.
    """

    def __init__(self, path):
        self.path = path

        _, ext = os.path.splitext(self.path)
        if ext != ".zarr":
            raise ValueError(f"Unsupported extension={ext} for path={self.path}")

    def add_dataset(self, mode="r+", **kwargs):
        """Open the store with ``mode`` and add an array via ``add_zarr_dataset``.

        ``kwargs`` are forwarded unchanged to ``add_zarr_dataset``.
        """
        import zarr

        z = zarr.open(self.path, mode=mode)
        from .zarr import add_zarr_dataset

        return add_zarr_dataset(zarr_root=z, **kwargs)

    def update_metadata(self, **kwargs):
        """Store each keyword argument as a JSON-compatible Zarr attribute.

        ``np.datetime64`` and ``datetime.date``/``datetime`` values are turned
        into ISO 8601 strings; other values go through ``json`` with
        ``json_tidy`` as the fallback serialiser.
        """
        import zarr

        LOG.debug(f"Updating metadata {kwargs}")
        # "r+" (read/write, must exist) is the documented zarr mode; the previous
        # "w+" is not a zarr mode and only worked through zarr v2's fallback path.
        z = zarr.open(self.path, mode="r+")
        for k, v in kwargs.items():
            if isinstance(v, np.datetime64):
                # Sub-microsecond datetime64 values convert to a plain int rather
                # than a datetime, so reduce their precision first.
                unit, _ = np.datetime_data(v.dtype)
                if unit in ("ns", "ps", "fs", "as"):
                    v = v.astype("datetime64[us]")
                v = v.astype(datetime.datetime)
            if isinstance(v, datetime.date):
                v = v.isoformat()
            z.attrs[k] = json.loads(json.dumps(v, default=json_tidy))

    @cached_property
    def anemoi_dataset(self):
        """The store opened with ``open_dataset`` (cached)."""
        return open_dataset(self.path)

    @cached_property
    def zarr_metadata(self):
        """A dict copy of the store's Zarr attributes (cached)."""
        import zarr

        return dict(zarr.open(self.path, mode="r").attrs)

    def print_info(self):
        """Log the ``info`` of the ``data`` array, or the error if that fails."""
        import zarr

        z = zarr.open(self.path, mode="r")
        try:
            LOG.info(z["data"].info)
        except Exception as e:
            LOG.info(e)

    def get_zarr_chunks(self):
        """Return the chunk shape of the ``data`` array."""
        import zarr

        z = zarr.open(self.path, mode="r")
        return z["data"].chunks

    def check_name(self, resolution, dates, frequency, raise_exception=True, is_test=False):
        """Validate the file name against resolution, first/last date and frequency.

        The error is re-raised when ``raise_exception`` is true and ``is_test``
        is false; otherwise it is only logged as a warning.
        """
        basename, _ = os.path.splitext(os.path.basename(self.path))
        try:
            DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid()
        except Exception as e:
            if raise_exception and not is_test:
                # Bare ``raise`` keeps the original traceback intact.
                raise
            LOG.warning(f"Dataset name error: {e}")

    def get_main_config(self):
        """Returns None if the config is not found."""
        import zarr

        z = zarr.open(self.path, mode="r")
        return loader_config(z.attrs.get("_create_yaml_config"))
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class WritableDataset(Dataset):
    """Dataset whose Zarr store is opened read/write (zarr mode ``"r+"``)."""

    def __init__(self, path):
        super().__init__(path)

        import zarr

        # ``Dataset.__init__`` has already stored ``path`` in ``self.path``.
        self.z = zarr.open(self.path, mode="r+")

    @cached_property
    def data_array(self):
        """The ``data`` array of a fresh read/write handle on the store (cached)."""
        import zarr

        root = zarr.open(self.path, mode="r+")
        return root["data"]
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class NewDataset(Dataset):
    """Dataset whose store is opened with zarr mode ``"w"``, plus a ``_build`` group."""

    def __init__(self, path, overwrite=False):
        # NOTE(review): ``overwrite`` is accepted but not used here.
        super().__init__(path)

        import zarr

        root = zarr.open(self.path, mode="w")
        self.z = root
        root.create_group("_build")
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class Actor:  # TODO: rename to Creator
    """Base class of the dataset creation steps.

    On construction the dataset at ``path`` is wrapped in ``dataset_class``.
    Subclasses implement :meth:`run`.
    """

    dataset_class = WritableDataset

    def __init__(self, path, cache=None):
        # Catch all floating point errors, including overflow, sqrt(<0), etc
        np.seterr(all="raise", under="warn")

        self.path = path
        self.cache = cache
        self.dataset = self.dataset_class(path)

    def run(self):
        """Perform the step; must be overridden by sub-classes."""
        raise NotImplementedError()

    def update_metadata(self, **kwargs):
        """Forward metadata updates to the wrapped dataset."""
        self.dataset.update_metadata(**kwargs)

    def _cache_context(self):
        """Return the ``cache_context`` built from ``self.cache``."""
        from .utils import cache_context

        return cache_context(self.cache)

    def check_unkown_kwargs(self, kwargs):
        # remove this latter
        LOG.warning(f"💬 Unknown kwargs for {self.__class__.__name__}: {kwargs}")

    def read_dataset_metadata(self, path):
        """Read shape, variables, dates and missing dates of the dataset at ``path``.

        Raises:
            ValueError: if the ``missing_dates`` attribute stored in the Zarr
                store differs from the missing dates reported by ``open_dataset``.
        """
        ds = open_dataset(path)
        self.dataset_shape = ds.shape
        self.variables_names = ds.variables
        assert len(self.variables_names) == ds.shape[1], self.dataset_shape
        self.dates = ds.dates

        self.missing_dates = sorted(self.dates[index] for index in ds.missing)

        def check_missing_dates(actual):
            import zarr

            attributes = zarr.open(path, "r").attrs
            recorded = sorted(np.datetime64(d) for d in attributes.get("missing_dates", []))
            if recorded == actual:
                return
            LOG.warning("Missing dates given in recipe do not match the actual missing dates in the dataset.")
            LOG.warning(f"Missing dates in recipe: {sorted(str(x) for x in recorded)}")
            LOG.warning(f"Missing dates in dataset: {sorted(str(x) for x in actual)}")
            raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.")

        check_missing_dates(self.missing_dates)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class Patch(Actor):
    """Apply ``apply_patch`` to the dataset at ``path`` with the given ``options``.

    NOTE(review): ``Actor.__init__`` is not called, so ``self.dataset`` and
    ``self.cache`` are not set on instances of this class.
    """

    def __init__(self, path, options=None, **kwargs):
        self.path = path
        self.options = options or {}

    def run(self):
        from .patch import apply_patch

        options = self.options
        apply_patch(self.path, **options)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class Size(Actor):
    """Store the result of ``compute_directory_sizes(path)`` in the dataset metadata."""

    def __init__(self, path, **kwargs):
        super().__init__(path)

    def run(self):
        from .size import compute_directory_sizes

        self.update_metadata(**compute_directory_sizes(self.path))
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class HasRegistryMixin:
    """Mixin exposing a cached ``ZarrBuiltRegistry`` for ``self.path``.

    The host class must define ``path`` and ``use_threads``.
    """

    @cached_property
    def registry(self):
        from .zarr import ZarrBuiltRegistry

        path = self.path
        return ZarrBuiltRegistry(path, use_threads=self.use_threads)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class HasStatisticTempMixin:
    """Mixin exposing a cached ``TmpStatistics`` storage.

    The directory is ``self.statistics_temp_dir`` when it is set (truthy), and
    otherwise ``<path>.storage_for_statistics.tmp`` next to the dataset.
    """

    @cached_property
    def tmp_statistics(self):
        directory = self.statistics_temp_dir
        if not directory:
            directory = self.path + ".storage_for_statistics.tmp"
        return TmpStatistics(directory)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class HasElementForDataMixin:
    """Mixin that builds ``groups``, ``output`` and ``input`` from a config."""

    def create_elements(self, config):
        # Accessing these cached properties instantiates them (and fails if falsy).
        assert self.registry
        assert self.tmp_statistics

        LOG.info(dict(config.dates))

        groups = Groups(**config.dates)
        self.groups = groups
        LOG.info(groups)

        self.output = build_output(config.output, parent=self)

        builder = build_input_(main_config=config, output_config=self.output)
        self.input = builder
        LOG.info(builder)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def build_input_(main_config, output_config):
    """Build the input builder from the main config and the output settings.

    The output's ``remapping`` is converted with earthkit's ``build_remapping``.
    """
    from earthkit.data.core.order import build_remapping

    input_config = main_config.input
    options = dict(
        data_sources=main_config.get("data_sources", {}),
        order_by=output_config.order_by,
        flatten_grid=output_config.flatten_grid,
        remapping=build_remapping(output_config.remapping),
        use_grib_paramid=main_config.build.use_grib_paramid,
    )
    builder = build_input(input_config, **options)

    LOG.debug("✅ INPUT_BUILDER")
    LOG.debug(builder)
    return builder
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin):
    """Create a new dataset: an empty store of the final shape plus its metadata.

    Construction loads the recipe ``config``, deletes the temporary statistics
    storage and selects a minimal input (a single date) that :meth:`run` uses
    to discover variables, ensembles, grid, resolution and shape.
    """

    dataset_class = NewDataset

    def __init__(self, path, config, check_name=False, overwrite=False, use_threads=False, statistics_temp_dir=None, progress=None, test=False, cache=None, **kwargs):  # fmt: skip
        # Refuse to replace an existing, readable store unless explicitly asked to.
        if _path_readable(path) and not overwrite:
            raise Exception(f"{path} already exists. Use overwrite=True to overwrite.")

        super().__init__(path, cache=cache)
        self.config = config
        self.check_name = check_name
        self.use_threads = use_threads
        self.statistics_temp_dir = statistics_temp_dir
        self.progress = progress
        self.test = test

        self.main_config = loader_config(config, is_test=test)

        # self.registry.delete() ??
        self.tmp_statistics.delete()

        assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by
        self.create_elements(self.main_config)

        LOG.info(f"Groups: {self.groups}")

        one_date = self.groups.one_date()
        self.minimal_input = self.input.select(one_date)
        LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}")
        LOG.info(self.minimal_input)

    def run(self):
        """Run :meth:`_run` inside the cache context and return its result."""
        with self._cache_context():
            return self._run()

    def _run(self):
        """Create an empty dataset of the right final shape.

        Read a small part of the data to get the shape of the data and the
        resolution and more metadata.

        Returns:
            The number of groups to process, so callers can show a progress bar.
        """
        LOG.info("Config loaded ok:")

        dates = self.groups.provider.values
        frequency = self.groups.provider.frequency
        missing = self.groups.provider.missing

        assert isinstance(frequency, datetime.timedelta), frequency

        LOG.info(f"Found {len(dates)} datetimes.")
        LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ")
        LOG.info(f"Missing dates: {len(missing)}")
        lengths = tuple(len(g) for g in self.groups)

        variables = self.minimal_input.variables
        LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.")

        variables_with_nans = self.main_config.statistics.get("allow_nans", [])

        ensembles = self.minimal_input.ensembles
        LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.")

        grid_points = self.minimal_input.grid_points
        LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}")

        resolution = self.minimal_input.resolution
        LOG.info(f"{resolution=}")

        coords = self.minimal_input.coords
        coords["dates"] = dates
        # Copy before overriding the time dimension so the minimal input's own
        # shape is not modified in place.
        total_shape = list(self.minimal_input.shape)
        total_shape[0] = len(dates)
        LOG.info(f"total_shape = {total_shape}")

        chunks = self.output.get_chunking(coords)
        LOG.info(f"{chunks=}")
        dtype = self.output.dtype

        LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}")

        metadata = {}
        metadata["uuid"] = str(uuid.uuid4())

        metadata.update(self.main_config.get("add_metadata", {}))

        metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict()

        metadata["description"] = self.main_config.description
        metadata["licence"] = self.main_config["licence"]
        metadata["attribution"] = self.main_config["attribution"]

        metadata["remapping"] = self.output.remapping
        metadata["order_by"] = self.output.order_by_as_list
        metadata["flatten_grid"] = self.output.flatten_grid

        metadata["ensemble_dimension"] = len(ensembles)
        metadata["variables"] = variables
        metadata["variables_with_nans"] = variables_with_nans
        metadata["allow_nans"] = self.main_config.build.get("allow_nans", False)
        metadata["resolution"] = resolution

        metadata["data_request"] = self.minimal_input.data_request
        metadata["field_shape"] = self.minimal_input.field_shape
        metadata["proj_string"] = self.minimal_input.proj_string

        metadata["start_date"] = dates[0].isoformat()
        metadata["end_date"] = dates[-1].isoformat()
        metadata["frequency"] = frequency
        metadata["missing_dates"] = [_.isoformat() for _ in missing]

        metadata["version"] = VERSION

        self.dataset.check_name(
            raise_exception=self.check_name,
            is_test=self.test,
            resolution=resolution,
            dates=dates,
            frequency=frequency,
        )

        # NOTE: the former check ``len(dates) != total_shape[0]`` was removed: it
        # could never fail because ``total_shape[0]`` is set to ``len(dates)`` above.

        dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"])

        metadata.update(self.main_config.get("force_metadata", {}))

        ###############################################################
        # write metadata
        ###############################################################

        self.update_metadata(**metadata)

        self.dataset.add_dataset(
            name="data",
            chunks=chunks,
            dtype=dtype,
            shape=total_shape,
            dimensions=("time", "variable", "ensemble", "cell"),
        )
        self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",))
        self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",))
        self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",))

        self.registry.create(lengths=lengths)
        self.tmp_statistics.create(exist_ok=False)
        self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version)

        statistics_start, statistics_end = build_statistics_dates(
            dates,
            self.main_config.statistics.get("start"),
            self.main_config.statistics.get("end"),
        )
        self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end)
        LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}")

        self.registry.add_to_history("init finished")

        assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks())

        # Return the number of groups to process, so we can show a nice progress bar
        return len(lengths)
|
|
491
|
+
|
|
64
492
|
|
|
493
|
+
class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin):
    """Fill an initialised dataset with data, one group of dates at a time.

    Groups whose registry flag is already set are skipped, and ``parts`` is
    passed to ``ChunkFilter`` to select the groups handled by this instance.
    Per-group statistics are written to the temporary statistics storage.
    """

    def __init__(self, path, parts=None, use_threads=False, statistics_temp_dir=None, progress=None, cache=None, **kwargs):  # fmt: skip
        super().__init__(path, cache=cache)
        self.use_threads = use_threads
        self.statistics_temp_dir = statistics_temp_dir
        self.progress = progress
        self.parts = parts
        # ``self.dataset`` was already created by ``Actor.__init__`` from the
        # default ``dataset_class`` (WritableDataset); no need to open it again.

        self.main_config = self.dataset.get_main_config()
        self.create_elements(self.main_config)
        self.read_dataset_metadata(self.dataset.path)

        total = len(self.registry.get_flags())
        self.chunk_filter = ChunkFilter(parts=self.parts, total=total)

        self.data_array = self.dataset.data_array
        self.n_groups = len(self.groups)

    def run(self):
        with self._cache_context():
            self._run()

    def _run(self):
        """Load every selected, not-yet-done group, then record provenance."""
        for igroup, group in enumerate(self.groups):
            if not self.chunk_filter(igroup):
                continue
            if self.registry.get_flag(igroup):
                LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)")
                continue

            LOG.debug(f"Building data for group {igroup}/{self.n_groups}")

            result = self.input.select(dates=group)
            assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group)

            # There are several groups.
            # There is one result to load for each group.
            self.load_result(result)
            self.registry.set_flag(igroup)

        self.registry.add_provenance(name="provenance_load")
        self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config)

        self.dataset.print_info()

    def load_result(self, result):
        """Write one result's cube into the data array and store its statistics.

        Raises:
            ValueError: if the cube's first dimension does not match the number
                of requested dates.
        """
        # There is one cube to load for each result.
        dates = list(result.group_of_dates)

        cube = result.get_cube()
        shape = cube.extended_user_shape
        dates_in_data = cube.user_coords["valid_datetime"]

        LOG.debug(f"Loading {shape=} in {self.data_array.shape=}")

        def check_shape(cube, dates, dates_in_data):
            if cube.extended_user_shape[0] != len(dates):
                message = (
                    f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}"
                )
                # Report through logging (not print) so diagnostics follow the logging setup.
                LOG.error(message)
                LOG.error(f"Requested dates {compress_dates(dates)}")
                LOG.error(f"Cube dates {compress_dates(dates_in_data)}")

                a = set(as_datetime(_) for _ in dates)
                b = set(as_datetime(_) for _ in dates_in_data)

                LOG.error(f"Missing dates {compress_dates(a - b)}")
                LOG.error(f"Extra dates {compress_dates(b - a)}")

                raise ValueError(message)

        check_shape(cube, dates, dates_in_data)

        def check_dates_in_data(lst, lst2):
            lst2 = [np.datetime64(_) for _ in lst2]
            lst = [np.datetime64(_) for _ in lst]
            assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2)

        check_dates_in_data(dates_in_data, dates)

        def dates_to_indexes(dates, all_dates):
            # Positions in ``dates`` of the entries that also appear in ``all_dates``.
            x = np.array(dates, dtype=np.datetime64)
            y = np.array(all_dates, dtype=np.datetime64)
            bitmap = np.isin(x, y)
            return np.where(bitmap)[0]

        indexes = dates_to_indexes(self.dates, dates_in_data)

        array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes)
        self.load_cube(cube, array)

        stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans())
        self.tmp_statistics.write(indexes, stats, dates=dates_in_data)

        array.flush()

    def _get_allow_nans(self):
        """Return ``build.allow_nans`` if set, else ``statistics.allow_nans`` (default [])."""
        config = self.main_config
        if "allow_nans" in config.build:
            return config.build.allow_nans

        return config.statistics.get("allow_nans", [])

    def load_cube(self, cube, array):
        """Copy every cubelet of ``cube`` into ``array``, checking values on the way."""
        # There are several cubelets for each cube
        start = time.time()
        load = 0
        save = 0

        reading_chunks = None
        total = cube.count(reading_chunks)
        LOG.debug(f"Loading datacube: {cube}")

        # The NaN policy does not change while a cube is loaded; compute it once.
        allow_nans = self._get_allow_nans()

        def position(x):
            # For ``parts`` given as "i/n", use i as the tqdm bar line offset.
            if isinstance(x, str) and "/" in x:
                x = x.split("/")
                return int(x[0])
            return None

        bar = tqdm.tqdm(
            iterable=cube.iterate_cubelets(reading_chunks),
            total=total,
            desc=f"Loading datacube {cube}",
            position=position(self.parts),
        )
        for i, cubelet in enumerate(bar):
            bar.set_description(f"Loading {i}/{total}")

            now = time.time()
            data = cubelet.to_numpy()
            local_indexes = cubelet.coords
            load += time.time() - now

            name = self.variables_names[local_indexes[1]]
            check_data_values(
                data[:],
                name=name,
                log=[i, data.shape, local_indexes],
                allow_nans=allow_nans,
            )

            now = time.time()
            array[local_indexes] = data
            save += time.time() - now

        LOG.debug(
            f"Elapsed: {seconds_to_human(time.time() - start)}, "
            f"load time: {seconds_to_human(load)}, "
            f"write time: {seconds_to_human(save)}."
        )
|
|
88
|
-
loader.run()
|
|
89
|
-
assert loader.ready()
|
|
90
649
|
|
|
91
|
-
def size(self):
|
|
92
|
-
from .loaders import DatasetHandler
|
|
93
|
-
from .size import compute_directory_sizes
|
|
94
650
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
651
|
+
class Cleanup(Actor, HasRegistryMixin, HasStatisticTempMixin):
    """Remove temporary build state.

    :meth:`run` deletes the temporary statistics storage, calls
    ``registry.clean()`` and then calls ``cleanup()`` on one ``_InitAdditions``
    actor per entry of ``delta``.
    """

    def __init__(self, path, statistics_temp_dir=None, delta=None, use_threads=False, **kwargs):
        super().__init__(path)
        self.use_threads = use_threads
        self.statistics_temp_dir = statistics_temp_dir
        # NOTE: attribute name (including its typo) kept for backward compatibility;
        # it is not read within this class.
        self.additinon_temp_dir = statistics_temp_dir
        # ``None`` instead of a mutable ``[]`` default, shared across calls.
        if delta is None:
            delta = []
        self.actors = [
            _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir)
            for d in delta
        ]

    def run(self):
        self.tmp_statistics.delete()
        self.registry.clean()
        for actor in self.actors:
            actor.cleanup()
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
class Verify(Actor):
    """Log the dataset at ``path`` as seen through ``open_dataset``."""

    def __init__(self, path, **kwargs):
        super().__init__(path)

    def run(self):
        LOG.info(f"Verifying dataset at {self.path}")
        description = str(self.dataset.anemoi_dataset)
        LOG.info(description)
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
class AdditionsMixin:
    """Shared helpers for actors working with a time offset ``self.delta``.

    The host class must provide ``path``, ``dataset`` and ``delta`` (a
    ``datetime.timedelta``).
    """

    def skip(self):
        """Return True when ``self.delta`` is not a whole multiple of the dataset frequency."""
        frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency)
        if self.delta.total_seconds() % frequency.total_seconds() != 0:
            LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.")
            return True
        return False

    @cached_property
    def tmp_storage_path(self):
        """Path of the temporary storage: ``<path>.storage_for_additions[<delta>].tmp``."""
        name = "storage_for_additions"
        if self.delta:
            name += frequency_to_string(self.delta)
        return f"{self.path}.{name}.tmp"

    def read_from_dataset(self):
        """Read variables, frequency and statistics dates, and build ``self.ds``.

        ``self.ds`` is a ``DeltaDataset`` over the statistics period whose items
        are differences between dates ``delta`` apart.

        Raises:
            ValueError: if ``self.delta`` is not a whole multiple of the frequency.
        """
        self.variables = self.dataset.anemoi_dataset.variables
        self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency)
        start = self.dataset.zarr_metadata["statistics_start_date"]
        end = self.dataset.zarr_metadata["statistics_end_date"]
        self.start = datetime.datetime.fromisoformat(start)
        self.end = datetime.datetime.fromisoformat(end)

        ds = open_dataset(self.path, start=self.start, end=self.end)
        self.dates = ds.dates
        self.total = len(self.dates)

        # The previous floor-division + ``int(x) == x`` assert could never fail;
        # check the remainder explicitly instead.
        idelta, remainder = divmod(self.delta.total_seconds(), self.frequency.total_seconds())
        if remainder != 0:
            raise ValueError(f"Delta {self.delta} is not a multiple of frequency {self.frequency}")
        self.ds = DeltaDataset(ds, int(idelta))
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
class DeltaDataset:
    """Index-wise view returning the difference between steps ``idelta`` apart."""

    def __init__(self, ds, idelta):
        self.ds = ds
        self.idelta = idelta

    def __getitem__(self, i):
        """Return ``ds[i:i+1] - ds[i-idelta:i-idelta+1]`` (leading axis of length one).

        Raises MissingDateError when ``i - idelta`` is negative.
        """
        previous = i - self.idelta
        if previous >= 0:
            current_step = self.ds[i : i + 1, ...]
            previous_step = self.ds[previous : previous + 1, ...]
            return current_step - previous_step
        raise MissingDateError(f"Missing date {previous}")
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
class _InitAdditions(Actor, HasRegistryMixin, AdditionsMixin):
    """Prepare, or remove, the temporary storage used to compute additions for one delta."""

    def __init__(self, path, delta, use_threads=False, progress=None, **kwargs):
        super().__init__(path)
        self.delta = frequency_to_timedelta(delta)
        self.use_threads = use_threads
        self.progress = progress

    def run(self):
        """Reset the temporary storage for ``self.delta``, unless ``skip()`` says otherwise."""
        if self.skip():
            LOG.info(f"Skipping delta={self.delta}")
            return

        storage = build_storage(directory=self.tmp_storage_path, create=True)
        self.tmp_storage = storage
        # Delete then recreate, presumably so no leftovers from an earlier run remain.
        storage.delete()
        storage.create()
        LOG.info(f"Dataset {self.tmp_storage_path} additions initialized.")

    def cleanup(self):
        """Delete the temporary storage for ``self.delta``."""
        storage = build_storage(directory=self.tmp_storage_path, create=False)
        self.tmp_storage = storage
        storage.delete()
        LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}")
|
|
101
744
|
|
|
102
|
-
cleaner = DatasetHandlerWithStatistics.from_dataset(
|
|
103
|
-
path=self.path, use_threads=self.use_threads, progress=self.progress, statistics_tmp=self.statistics_tmp
|
|
104
|
-
)
|
|
105
|
-
cleaner.tmp_statistics.delete()
|
|
106
|
-
cleaner.registry.clean()
|
|
107
745
|
|
|
108
|
-
|
|
109
|
-
|
|
746
|
+
class _RunAdditions(Actor, HasRegistryMixin, AdditionsMixin):
    """Compute per-date tendency statistics for one delta into temporary storage."""

    def __init__(self, path, delta, parts=None, use_threads=False, progress=None, **kwargs):
        """
        Parameters
        ----------
        path : str
            Dataset path.
        delta :
            Tendency delta, converted with ``frequency_to_timedelta``.
        parts : optional
            Passed to ``ChunkFilter`` to select which date indices this run processes.
        use_threads, progress :
            Stored on the instance.
        """
        super().__init__(path)
        self.delta = frequency_to_timedelta(delta)
        self.use_threads = use_threads
        self.progress = progress
        self.parts = parts

        self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False)
        LOG.info(f"Writing in {self.tmp_storage_path}")

    def run(self):
        """Store ``[date, index, stats]`` (or ``"missing"``) for every selected date."""
        if self.skip():
            LOG.info(f"Skipping delta={self.delta}")
            return

        self.read_from_dataset()

        chunk_filter = ChunkFilter(parts=self.parts, total=self.total)
        for i in range(self.total):
            if not chunk_filter(i):
                continue
            date = self.dates[i]
            try:
                arr = self.ds[i]
            except MissingDateError:
                # No earlier step ``delta`` before this date: record it as missing.
                self.tmp_storage.add([date, i, "missing"], key=date)
                continue
            stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans)
            self.tmp_storage.add([date, i, stats], key=date)
        self.tmp_storage.flush()
        LOG.debug(f"Dataset {self.path} additions run.")

    @cached_property
    def allow_nans(self):
        """NaN policy passed to ``compute_statistics``.

        Returns True when the metadata sets ``allow_nans``. Otherwise returns the
        ``variables_with_nans`` entry when present. Otherwise warns and returns True.
        A property, not a method, because ``run`` reads it as ``self.allow_nans``;
        as a plain method a (truthy) bound method was passed instead of the value.
        """
        if self.dataset.anemoi_dataset.metadata.get("allow_nans", False):
            return True

        variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None)
        if variables_with_nans is not None:
            return variables_with_nans
        warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.")
        return True
|
|
135
787
|
|
|
136
|
-
if statistics:
|
|
137
|
-
a = StatisticsAddition.from_dataset(path=self.path, use_threads=self.use_threads)
|
|
138
|
-
a.run(parts)
|
|
139
788
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
789
|
+
class _FinaliseAdditions(Actor, HasRegistryMixin, AdditionsMixin):
    """Aggregate per-date tendency statistics for one delta and write them to the dataset."""

    # Per-date statistics read back from the temporary storage filled by ``_RunAdditions``.
    _STAT_KEYS = ("minimum", "maximum", "sums", "squares", "count", "has_nans")

    def __init__(self, path, delta, use_threads=False, progress=None, **kwargs):
        super().__init__(path)
        self.delta = frequency_to_timedelta(delta)
        self.use_threads = use_threads
        self.progress = progress

        self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False)
        LOG.info(f"Reading from {self.tmp_storage_path}.")

    def run(self):
        """Aggregate, write ``statistics_tendencies_*`` arrays and delete the temporary storage.

        If fewer than two dates have statistics, nothing is written. The temporary
        storage is still deleted.
        """
        if self.skip():
            LOG.info(f"Skipping delta={self.delta}.")
            return

        self.read_from_dataset()

        agg, ifound = self._collect()

        if len(ifound) < 2:
            LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.")
            self.tmp_storage.delete()
            return

        # Keep only the rows (dates) for which statistics were actually stored.
        mask = sorted(ifound)
        for k in self._STAT_KEYS:
            agg[k] = agg[k][mask, ...]

        for k in self._STAT_KEYS:
            assert agg[k].shape == agg["count"].shape, (
                agg[k].shape,
                agg["count"].shape,
            )

        self.summary = self._summarise(agg)
        LOG.info(f"Dataset {self.path} additions finalised.")
        self._write(self.summary)
        self.tmp_storage.delete()

    def _collect(self):
        """Read the per-date statistics from temporary storage.

        Returns ``(agg, ifound)``. ``agg`` maps each key in ``_STAT_KEYS`` to an array of
        shape ``(len(self.dates), len(self.variables))``; rows with no data keep their
        fill value. ``ifound`` is the set of row indices that were filled.
        """
        shape = (len(self.dates), len(self.variables))
        agg = dict(
            minimum=np.full(shape, np.nan, dtype=np.float64),
            maximum=np.full(shape, np.nan, dtype=np.float64),
            sums=np.full(shape, np.nan, dtype=np.float64),
            squares=np.full(shape, np.nan, dtype=np.float64),
            count=np.full(shape, -1, dtype=np.int64),
            has_nans=np.full(shape, False, dtype=np.bool_),
        )
        LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}")

        found = set()
        ifound = set()
        missing = set()
        for _date, (date, i, stats) in self.tmp_storage.items():
            assert _date == date
            if stats == "missing":
                missing.add(date)
                continue

            assert date not in found, f"Duplicates found {date}"
            found.add(date)
            ifound.add(i)

            for k in self._STAT_KEYS:
                agg[k][i, ...] = stats[k]

        # Every date in the statistics period must be accounted for exactly once.
        assert len(found) + len(missing) == len(self.dates), (
            len(found),
            len(missing),
            len(self.dates),
        )
        assert found.union(missing) == set(self.dates), (
            found,
            missing,
            set(self.dates),
        )
        return agg, ifound

    def _summarise(self, agg):
        """Reduce the ``(dates, variables)`` arrays over dates into a per-variable ``Summary``."""
        minimum = np.nanmin(agg["minimum"], axis=0)
        maximum = np.nanmax(agg["maximum"], axis=0)
        sums = np.nansum(agg["sums"], axis=0)
        squares = np.nansum(agg["squares"], axis=0)
        count = np.nansum(agg["count"], axis=0)
        has_nans = np.any(agg["has_nans"], axis=0)

        assert sums.shape == count.shape
        assert sums.shape == squares.shape
        assert sums.shape == minimum.shape
        assert sums.shape == maximum.shape
        assert sums.shape == has_nans.shape

        mean = sums / count
        assert sums.shape == mean.shape

        x = squares / count - mean * mean
        # Remove negative variance due to numerical errors.
        # ``agg`` arrays are (dates, variables), so variable ``i``'s per-date values are
        # column ``i``. The former ``[i : i + 1]`` row slice picked date ``i`` instead.
        # NOTE(review): assumes fix_variance expects one variable's per-date values — confirm.
        for i, name in enumerate(self.variables):
            x[i] = fix_variance(x[i], name, agg["count"][:, i], agg["sums"][:, i], agg["squares"][:, i])
        check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares)

        stdev = np.sqrt(x)
        assert sums.shape == stdev.shape

        return Summary(
            minimum=minimum,
            maximum=maximum,
            mean=mean,
            count=count,
            sums=sums,
            squares=squares,
            stdev=stdev,
            variables_names=self.variables,
            has_nans=has_nans,
        )

    def _write(self, summary):
        """Store each statistic as a per-variable array ``statistics_tendencies_<delta>_<stat>``."""
        for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
            name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}"
            self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",))
        self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end")
        LOG.debug(f"Wrote additions in {self.path}")
|
|
187
907
|
|
|
188
|
-
def _path_readable(self):
|
|
189
|
-
import zarr
|
|
190
908
|
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
909
|
+
def multi_addition(cls):
    """Return a class wrapping one ``cls`` actor per requested delta.

    The returned class pops ``delta`` (an iterable, default empty) from its keyword
    arguments and builds ``cls(*args, delta=d, **kwargs)`` for each ``d``.
    Its ``run`` runs those actors in order.
    """

    class MultiAdditions:
        def __init__(self, *args, **kwargs):
            requested = kwargs.pop("delta", [])
            self.actors = [cls(*args, delta=d, **kwargs) for d in requested]
            if len(self.actors) == 0:
                LOG.warning("No delta found in kwargs, no additions will be computed.")

        def run(self):
            """Run every per-delta actor sequentially."""
            for single in self.actors:
                single.run()

    return MultiAdditions
|
|
196
925
|
|
|
197
|
-
def verify(self):
|
|
198
|
-
from .loaders import DatasetVerifier
|
|
199
926
|
|
|
200
|
-
|
|
927
|
+
# Multi-delta wrappers around the single-delta additions actors.
InitAdditions, RunAdditions, FinaliseAdditions = (
    multi_addition(_InitAdditions),
    multi_addition(_RunAdditions),
    multi_addition(_FinaliseAdditions),
)
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin):
    """Aggregate temporary per-date statistics and write them into the dataset."""

    def __init__(self, path, use_threads=False, statistics_temp_dir=None, progress=None, **kwargs):
        super().__init__(path)
        self.use_threads = use_threads
        self.progress = progress
        self.statistics_temp_dir = statistics_temp_dir

    def run(self):
        """Aggregate statistics over the statistics period and store them as datasets.

        Raises
        ------
        Exception
            If the registry flags show the dataset is not fully built.
        """
        start, end = (
            self.dataset.zarr_metadata["statistics_start_date"],
            self.dataset.zarr_metadata["statistics_end_date"],
        )
        start, end = np.datetime64(start), np.datetime64(end)
        all_dates = self.dataset.anemoi_dataset.dates

        assert type(all_dates[0]) is type(start), (type(all_dates[0]), type(start))

        # ``missing`` and ``dates`` belong to the same dataset object, so the missing
        # indices refer to positions in the full ``all_dates``. Apply both filters in
        # one enumeration; the earlier two-step version enumerated the already
        # period-filtered list and misaligned indices when the period started later.
        missing = self.dataset.anemoi_dataset.missing
        dates = [d for i, d in enumerate(all_dates) if start <= d <= end and i not in missing]
        variables = self.dataset.anemoi_dataset.variables
        stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans)

        LOG.info(stats)

        if not all(self.registry.get_flags(sync=False)):
            raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.")

        for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
            self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",))

        self.registry.add_to_history("compute_statistics_end")
        LOG.info(f"Wrote statistics in {self.path}")

    @cached_property
    def allow_nans(self):
        """NaN policy read from the zarr attributes.

        Returns ``allow_nans`` when present. Otherwise returns ``variables_with_nans``
        when present. Otherwise warns and returns True.
        """
        import zarr

        z = zarr.open(self.path, mode="r")
        if "allow_nans" in z.attrs:
            return z.attrs["allow_nans"]

        if "variables_with_nans" in z.attrs:
            return z.attrs["variables_with_nans"]

        warnings.warn(f"Cannot find 'variables_with_nans' or 'allow_nans' in {self.path}.")
        return True
|
|
978
|
+
|
|
979
|
+
|
|
980
|
+
def chain(tasks):
    """Return an Actor class that runs each class in ``tasks`` in turn.

    Each task is built with the keyword arguments given to the chain and run
    before the next task is built.
    """

    class Chain(Actor):
        def __init__(self, **kwargs):
            # Actor.__init__ is not called; the kwargs are only forwarded to each task.
            self.kwargs = kwargs

        def run(self):
            for task_cls in tasks:
                task = task_cls(**self.kwargs)
                task.run()

    return Chain
|
|
991
|
+
|
|
992
|
+
|
|
993
|
+
def creator_factory(name, trace=None, **kwargs):
    """Instantiate the dataset-creation actor registered under ``name``.

    Parameters
    ----------
    name : str
        Step name such as ``"init"``, ``"load"``, ``"finalise"`` or ``"additions"``.
    trace : optional
        When truthy, passed to ``enable_trace`` before the actor is built.
    **kwargs
        Forwarded to the actor's constructor.

    Raises
    ------
    KeyError
        If ``name`` is not a known step.
    """
    if trace:
        from anemoi.datasets.create.trace import enable_trace

        enable_trace(trace)

    actors = {
        "init": Init,
        "load": Load,
        "size": Size,
        "patch": Patch,
        "statistics": Statistics,
        "finalise": chain([Statistics, Size, Cleanup]),
        "cleanup": Cleanup,
        "verify": Verify,
        "init_additions": InitAdditions,
        "load_additions": RunAdditions,
        "run_additions": RunAdditions,
        "finalise_additions": chain([FinaliseAdditions, Size]),
        "additions": chain([InitAdditions, RunAdditions, FinaliseAdditions, Size, Cleanup]),
    }
    selected = actors[name]
    LOG.debug(f"Creating {selected.__name__} with {kwargs}")
    return selected(**kwargs)
|