anemoi-datasets 0.4.4__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (48)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/cleanup.py +44 -0
  3. anemoi/datasets/commands/create.py +50 -20
  4. anemoi/datasets/commands/finalise-additions.py +45 -0
  5. anemoi/datasets/commands/finalise.py +39 -0
  6. anemoi/datasets/commands/init-additions.py +45 -0
  7. anemoi/datasets/commands/init.py +67 -0
  8. anemoi/datasets/commands/inspect.py +1 -1
  9. anemoi/datasets/commands/load-additions.py +47 -0
  10. anemoi/datasets/commands/load.py +47 -0
  11. anemoi/datasets/commands/patch.py +39 -0
  12. anemoi/datasets/create/__init__.py +961 -146
  13. anemoi/datasets/create/check.py +5 -3
  14. anemoi/datasets/create/config.py +53 -2
  15. anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
  16. anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
  17. anemoi/datasets/create/functions/sources/xarray/field.py +1 -1
  18. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
  19. anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
  20. anemoi/datasets/create/functions/sources/xarray/metadata.py +27 -29
  21. anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
  22. anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
  23. anemoi/datasets/create/input.py +23 -22
  24. anemoi/datasets/create/statistics/__init__.py +39 -23
  25. anemoi/datasets/create/utils.py +3 -2
  26. anemoi/datasets/data/__init__.py +1 -0
  27. anemoi/datasets/data/concat.py +46 -2
  28. anemoi/datasets/data/dataset.py +109 -34
  29. anemoi/datasets/data/forwards.py +17 -8
  30. anemoi/datasets/data/grids.py +17 -3
  31. anemoi/datasets/data/interpolate.py +133 -0
  32. anemoi/datasets/data/misc.py +56 -66
  33. anemoi/datasets/data/missing.py +240 -0
  34. anemoi/datasets/data/select.py +7 -1
  35. anemoi/datasets/data/stores.py +3 -3
  36. anemoi/datasets/data/subset.py +47 -5
  37. anemoi/datasets/data/unchecked.py +20 -22
  38. anemoi/datasets/data/xy.py +125 -0
  39. anemoi/datasets/dates/__init__.py +13 -66
  40. anemoi/datasets/dates/groups.py +2 -2
  41. anemoi/datasets/grids.py +66 -48
  42. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/METADATA +5 -5
  43. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/RECORD +47 -37
  44. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/WHEEL +1 -1
  45. anemoi/datasets/create/loaders.py +0 -936
  46. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/LICENSE +0 -0
  47. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/entry_points.txt +0 -0
  48. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/loaders.py (file deleted)
@@ -1,936 +0,0 @@
- # (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
- # This software is licensed under the terms of the Apache Licence Version 2.0
- # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
- # In applying this licence, ECMWF does not waive the privileges and immunities
- # granted to it by virtue of its status as an intergovernmental organisation
- # nor does it submit to any jurisdiction.
- import datetime
- import json
- import logging
- import os
- import time
- import uuid
- import warnings
- from functools import cached_property
-
- import numpy as np
- import tqdm
- import zarr
- from anemoi.utils.config import DotDict
- from anemoi.utils.dates import as_datetime
- from anemoi.utils.humanize import seconds_to_human
-
- from anemoi.datasets import MissingDateError
- from anemoi.datasets import open_dataset
- from anemoi.datasets.create.persistent import build_storage
- from anemoi.datasets.data.misc import as_first_date
- from anemoi.datasets.data.misc import as_last_date
- from anemoi.datasets.dates import compress_dates
- from anemoi.datasets.dates.groups import Groups
-
- from .check import DatasetName
- from .check import check_data_values
- from .chunks import ChunkFilter
- from .config import build_output
- from .config import loader_config
- from .input import build_input
- from .statistics import Summary
- from .statistics import TmpStatistics
- from .statistics import check_variance
- from .statistics import compute_statistics
- from .statistics import default_statistics_dates
- from .utils import normalize_and_check_dates
- from .writer import ViewCacheArray
- from .zarr import ZarrBuiltRegistry
- from .zarr import add_zarr_dataset
-
- LOG = logging.getLogger(__name__)
-
- VERSION = "0.20"
-
-
- def set_to_test_mode(cfg):
-     NUMBER_OF_DATES = 4
-
-     dates = cfg.dates
-     LOG.warn(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.")
-     groups = Groups(**cfg.dates)
-     dates = groups.dates
-     cfg.dates = dict(
-         start=dates[0],
-         end=dates[NUMBER_OF_DATES - 1],
-         frequency=dates.frequency,
-         group_by=NUMBER_OF_DATES,
-     )
-
-     def set_element_to_test(obj):
-         if isinstance(obj, (list, tuple)):
-             for v in obj:
-                 set_element_to_test(v)
-             return
-         if isinstance(obj, (dict, DotDict)):
-             if "grid" in obj:
-                 previous = obj["grid"]
-                 obj["grid"] = "20./20."
-                 LOG.warn(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}")
-             if "number" in obj:
-                 if isinstance(obj["number"], (list, tuple)):
-                     previous = obj["number"]
-                     obj["number"] = previous[0:3]
-                     LOG.warn(f"Running in test mode. Setting number to {obj['number']} instead of {previous}")
-             for k, v in obj.items():
-                 set_element_to_test(v)
-             if "constants" in obj:
-                 constants = obj["constants"]
-                 if "param" in constants and isinstance(constants["param"], list):
-                     constants["param"] = ["cos_latitude"]
-
-     set_element_to_test(cfg)
-
-
- class GenericDatasetHandler:
-     def __init__(self, *, path, use_threads=False, **kwargs):
-
-         # Catch all floating point errors, including overflow, sqrt(<0), etc
-         np.seterr(all="raise", under="warn")
-
-         assert isinstance(path, str), path
-
-         self.path = path
-         self.kwargs = kwargs
-         self.use_threads = use_threads
-         if "test" in kwargs:
-             self.test = kwargs["test"]
-
-     @classmethod
-     def from_config(cls, *, config, path, use_threads=False, **kwargs):
-         """Config is the path to the config file or a dict with the config"""
-
-         assert isinstance(config, dict) or isinstance(config, str), config
-         return cls(config=config, path=path, use_threads=use_threads, **kwargs)
-
-     @classmethod
-     def from_dataset_config(cls, *, path, use_threads=False, **kwargs):
-         """Read the config saved inside the zarr dataset and instantiate the class for this config."""
-
-         assert os.path.exists(path), f"Path {path} does not exist."
-         z = zarr.open(path, mode="r")
-         config = z.attrs["_create_yaml_config"]
-         LOG.debug("Config loaded from zarr config:\n%s", json.dumps(config, indent=4, sort_keys=True, default=str))
-         return cls.from_config(config=config, path=path, use_threads=use_threads, **kwargs)
-
-     @classmethod
-     def from_dataset(cls, *, path, use_threads=False, **kwargs):
-         """Instanciate the class from the path to the zarr dataset, without config."""
-
-         assert os.path.exists(path), f"Path {path} does not exist."
-         return cls(path=path, use_threads=use_threads, **kwargs)
-
-     def read_dataset_metadata(self):
-         ds = open_dataset(self.path)
-         self.dataset_shape = ds.shape
-         self.variables_names = ds.variables
-         assert len(self.variables_names) == ds.shape[1], self.dataset_shape
-         self.dates = ds.dates
-
-         self.missing_dates = sorted(list([self.dates[i] for i in ds.missing]))
-
-         z = zarr.open(self.path, "r")
-         missing_dates = z.attrs.get("missing_dates", [])
-         missing_dates = sorted([np.datetime64(d) for d in missing_dates])
-
-         if missing_dates != self.missing_dates:
-             LOG.warn("Missing dates given in recipe do not match the actual missing dates in the dataset.")
-             LOG.warn(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}")
-             LOG.warn(f"Missing dates in dataset: {sorted(str(x) for x in self.missing_dates)}")
-             raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.")
-
-     @cached_property
-     def registry(self):
-         return ZarrBuiltRegistry(self.path, use_threads=self.use_threads)
-
-     def ready(self):
-         return all(self.registry.get_flags())
-
-     def update_metadata(self, **kwargs):
-         LOG.debug(f"Updating metadata {kwargs}")
-         z = zarr.open(self.path, mode="w+")
-         for k, v in kwargs.items():
-             if isinstance(v, np.datetime64):
-                 v = v.astype(datetime.datetime)
-             if isinstance(v, datetime.date):
-                 v = v.isoformat()
-             z.attrs[k] = v
-
-     def _add_dataset(self, mode="r+", **kwargs):
-         z = zarr.open(self.path, mode=mode)
-         return add_zarr_dataset(zarr_root=z, **kwargs)
-
-     def get_zarr_chunks(self):
-         z = zarr.open(self.path, mode="r")
-         return z["data"].chunks
-
-     def print_info(self):
-         z = zarr.open(self.path, mode="r")
-         try:
-             LOG.info(z["data"].info)
-         except Exception as e:
-             LOG.info(e)
-
-
- class DatasetHandler(GenericDatasetHandler):
-     pass
-
-
- class DatasetHandlerWithStatistics(GenericDatasetHandler):
-     def __init__(self, statistics_tmp=None, **kwargs):
-         super().__init__(**kwargs)
-         statistics_tmp = kwargs.get("statistics_tmp") or os.path.join(self.path + ".storage_for_statistics.tmp")
-         self.tmp_statistics = TmpStatistics(statistics_tmp)
-
-
- class Loader(DatasetHandlerWithStatistics):
-     def build_input(self):
-         from earthkit.data.core.order import build_remapping
-
-         builder = build_input(
-             self.main_config.input,
-             data_sources=self.main_config.get("data_sources", {}),
-             order_by=self.output.order_by,
-             flatten_grid=self.output.flatten_grid,
-             remapping=build_remapping(self.output.remapping),
-             use_grib_paramid=self.main_config.build.use_grib_paramid,
-         )
-         LOG.debug("✅ INPUT_BUILDER")
-         LOG.debug(builder)
-         return builder
-
-     @property
-     def allow_nans(self):
-         if "allow_nans" in self.main_config.build:
-             return self.main_config.build.allow_nans
-
-         return self.main_config.statistics.get("allow_nans", [])
-
-
- class InitialiserLoader(Loader):
-     def __init__(self, config, **kwargs):
-         super().__init__(**kwargs)
-
-         self.main_config = loader_config(config)
-         if self.test:
-             set_to_test_mode(self.main_config)
-
-         LOG.info(dict(self.main_config.dates))
-
-         self.tmp_statistics.delete()
-
-         self.groups = Groups(**self.main_config.dates)
-         LOG.info(self.groups)
-
-         self.output = build_output(self.main_config.output, parent=self)
-         self.input = self.build_input()
-         LOG.info(self.input)
-
-         first_date = self.groups.dates[0]
-         self.minimal_input = self.input.select([first_date])
-         LOG.info("Minimal input (using only the first date) :")
-         LOG.info(self.minimal_input)
-
-     def build_statistics_dates(self, start, end):
-         """Compute the start and end dates for the statistics, based on :
-         - The start and end dates in the config
-         - The default statistics dates convention
-
-         Then adapt according to the actual dates in the dataset.
-         """
-
-         ds = open_dataset(self.path)
-         dates = ds.dates
-
-         # if not specified, use the default statistics dates
-         default_start, default_end = default_statistics_dates(dates)
-         if start is None:
-             start = default_start
-         if end is None:
-             end = default_end
-
-         # in any case, adapt to the actual dates in the dataset
-         start = as_first_date(start, dates)
-         end = as_last_date(end, dates)
-
-         # and convert to datetime to isoformat
-         start = start.astype(datetime.datetime)
-         end = end.astype(datetime.datetime)
-         return (start.isoformat(), end.isoformat())
-
-     def initialise_dataset_backend(self):
-         z = zarr.open(self.path, mode="w")
-         z.create_group("_build")
-
-     def initialise(self, check_name=True):
-         """Create an empty dataset of the right final shape
-
-         Read a small part of the data to get the shape of the data and the resolution and more metadata.
-         """
-
-         LOG.info("Config loaded ok:")
-         # LOG.info(self.main_config)
-
-         dates = self.groups.dates
-         frequency = dates.frequency
-         assert isinstance(frequency, int), frequency
-
-         LOG.info(f"Found {len(dates)} datetimes.")
-         LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ")
-         LOG.info(f"Missing dates: {len(dates.missing)}")
-         lengths = tuple(len(g) for g in self.groups)
-
-         variables = self.minimal_input.variables
-         LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.")
-
-         variables_with_nans = self.main_config.statistics.get("allow_nans", [])
-
-         ensembles = self.minimal_input.ensembles
-         LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.")
-
-         grid_points = self.minimal_input.grid_points
-         LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}")
-
-         resolution = self.minimal_input.resolution
-         LOG.info(f"{resolution=}")
-
-         coords = self.minimal_input.coords
-         coords["dates"] = dates
-         total_shape = self.minimal_input.shape
-         total_shape[0] = len(dates)
-         LOG.info(f"total_shape = {total_shape}")
-
-         chunks = self.output.get_chunking(coords)
-         LOG.info(f"{chunks=}")
-         dtype = self.output.dtype
-
-         LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}")
-
-         metadata = {}
-         metadata["uuid"] = str(uuid.uuid4())
-
-         metadata.update(self.main_config.get("add_metadata", {}))
-
-         metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict()
-
-         metadata["description"] = self.main_config.description
-         metadata["licence"] = self.main_config["licence"]
-         metadata["attribution"] = self.main_config["attribution"]
-
-         metadata["remapping"] = self.output.remapping
-         metadata["order_by"] = self.output.order_by_as_list
-         metadata["flatten_grid"] = self.output.flatten_grid
-
-         metadata["ensemble_dimension"] = len(ensembles)
-         metadata["variables"] = variables
-         metadata["variables_with_nans"] = variables_with_nans
-         metadata["allow_nans"] = self.main_config.build.get("allow_nans", False)
-         metadata["resolution"] = resolution
-
-         metadata["data_request"] = self.minimal_input.data_request
-         metadata["field_shape"] = self.minimal_input.field_shape
-         metadata["proj_string"] = self.minimal_input.proj_string
-
-         metadata["start_date"] = dates[0].isoformat()
-         metadata["end_date"] = dates[-1].isoformat()
-         metadata["frequency"] = frequency
-         metadata["missing_dates"] = [_.isoformat() for _ in dates.missing]
-
-         metadata["version"] = VERSION
-
-         if check_name:
-             basename, ext = os.path.splitext(os.path.basename(self.path))  # noqa: F841
-             ds_name = DatasetName(basename, resolution, dates[0], dates[-1], frequency)
-             ds_name.raise_if_not_valid()
-
-         if len(dates) != total_shape[0]:
-             raise ValueError(
-                 f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) "
-                 f"does not match data shape {total_shape[0]}. {total_shape=}"
-             )
-
-         dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"])
-
-         metadata.update(self.main_config.get("force_metadata", {}))
-
-         ###############################################################
-         # write metadata
-         ###############################################################
-
-         self.initialise_dataset_backend()
-
-         self.update_metadata(**metadata)
-
-         self._add_dataset(
-             name="data",
-             chunks=chunks,
-             dtype=dtype,
-             shape=total_shape,
-             dimensions=("time", "variable", "ensemble", "cell"),
-         )
-         self._add_dataset(name="dates", array=dates, dimensions=("time",))
-         self._add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",))
-         self._add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",))
-
-         self.registry.create(lengths=lengths)
-         self.tmp_statistics.create(exist_ok=False)
-         self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version)
-
-         statistics_start, statistics_end = self.build_statistics_dates(
-             self.main_config.statistics.get("start"),
-             self.main_config.statistics.get("end"),
-         )
-         self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end)
-         LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}")
-
-         self.registry.add_to_history("init finished")
-
-         assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks())
-
-         # Return the number of groups to process, so we can show a nice progress bar
-         return len(lengths)
-
-
- class ContentLoader(Loader):
-     def __init__(self, config, parts, **kwargs):
-         super().__init__(**kwargs)
-         self.main_config = loader_config(config)
-
-         self.groups = Groups(**self.main_config.dates)
-         self.output = build_output(self.main_config.output, parent=self)
-         self.input = self.build_input()
-         self.read_dataset_metadata()
-
-         self.parts = parts
-         total = len(self.registry.get_flags())
-         self.chunk_filter = ChunkFilter(parts=self.parts, total=total)
-
-         self.data_array = zarr.open(self.path, mode="r+")["data"]
-         self.n_groups = len(self.groups)
-
-     def load(self):
-         for igroup, group in enumerate(self.groups):
-             if not self.chunk_filter(igroup):
-                 continue
-             if self.registry.get_flag(igroup):
-                 LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)")
-                 continue
-
-             assert isinstance(group[0], datetime.datetime), group
-
-             result = self.input.select(dates=group)
-             assert result.dates == group, (len(result.dates), len(group))
-
-             LOG.debug(f"Building data for group {igroup}/{self.n_groups}")
-
-             # There are several groups.
-             # There is one result to load for each group.
-             self.load_result(result)
-             self.registry.set_flag(igroup)
-
-         self.registry.add_provenance(name="provenance_load")
-         self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config)
-
-         # self.print_info()
-
-     def load_result(self, result):
-         # There is one cube to load for each result.
-         dates = result.dates
-
-         cube = result.get_cube()
-         shape = cube.extended_user_shape
-         dates_in_data = cube.user_coords["valid_datetime"]
-
-         if cube.extended_user_shape[0] != len(dates):
-             print(f"Cube shape does not match the number of dates {cube.extended_user_shape[0]}, {len(dates)}")
-             print("Requested dates", compress_dates(dates))
-             print("Cube dates", compress_dates(dates_in_data))
-
-             a = set(as_datetime(_) for _ in dates)
-             b = set(as_datetime(_) for _ in dates_in_data)
-
-             print("Missing dates", compress_dates(a - b))
-             print("Extra dates", compress_dates(b - a))
-
-             raise ValueError(
-                 f"Cube shape does not match the number of dates {cube.extended_user_shape[0]}, {len(dates)}"
-             )
-
-         LOG.debug(f"Loading {shape=} in {self.data_array.shape=}")
-
-         def check_dates_in_data(lst, lst2):
-             lst2 = [np.datetime64(_) for _ in lst2]
-             lst = [np.datetime64(_) for _ in lst]
-             assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2)
-
-         check_dates_in_data(dates_in_data, dates)
-
-         def dates_to_indexes(dates, all_dates):
-             x = np.array(dates, dtype=np.datetime64)
-             y = np.array(all_dates, dtype=np.datetime64)
-             bitmap = np.isin(x, y)
-             return np.where(bitmap)[0]
-
-         indexes = dates_to_indexes(self.dates, dates_in_data)
-
-         array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes)
-         self.load_cube(cube, array)
-
-         stats = compute_statistics(array.cache, self.variables_names, allow_nans=self.allow_nans)
-         self.tmp_statistics.write(indexes, stats, dates=dates_in_data)
-
-         array.flush()
-
-     def load_cube(self, cube, array):
-         # There are several cubelets for each cube
-         start = time.time()
-         load = 0
-         save = 0
-
-         reading_chunks = None
-         total = cube.count(reading_chunks)
-         LOG.debug(f"Loading datacube: {cube}")
-
-         def position(x):
-             if isinstance(x, str) and "/" in x:
-                 x = x.split("/")
-                 return int(x[0])
-             return None
-
-         bar = tqdm.tqdm(
-             iterable=cube.iterate_cubelets(reading_chunks),
-             total=total,
-             desc=f"Loading datacube {cube}",
-             position=position(self.parts),
-         )
-         for i, cubelet in enumerate(bar):
-             bar.set_description(f"Loading {i}/{total}")
-
-             now = time.time()
-             data = cubelet.to_numpy()
-             local_indexes = cubelet.coords
-             load += time.time() - now
-
-             name = self.variables_names[local_indexes[1]]
-             check_data_values(
-                 data[:],
-                 name=name,
-                 log=[i, data.shape, local_indexes],
-                 allow_nans=self.allow_nans,
-             )
-
-             now = time.time()
-             array[local_indexes] = data
-             save += time.time() - now
-
-         now = time.time()
-         save += time.time() - now
-         LOG.debug(
-             f"Elapsed: {seconds_to_human(time.time() - start)}, "
-             f"load time: {seconds_to_human(load)}, "
-             f"write time: {seconds_to_human(save)}."
-         )
-
-
- class StatisticsAdder(DatasetHandlerWithStatistics):
-     def __init__(
-         self,
-         statistics_output=None,
-         statistics_start=None,
-         statistics_end=None,
-         **kwargs,
-     ):
-         super().__init__(**kwargs)
-         self.user_statistics_start = statistics_start
-         self.user_statistics_end = statistics_end
-
-         self.statistics_output = statistics_output
-
-         self.output_writer = {
-             None: self.write_stats_to_dataset,
-             "-": self.write_stats_to_stdout,
-         }.get(self.statistics_output, self.write_stats_to_file)
-
-         self.read_dataset_metadata()
-
-     @cached_property
-     def allow_nans(self):
-         z = zarr.open(self.path, mode="r")
-         if "allow_nans" in z.attrs:
-             return z.attrs["allow_nans"]
-
-         if "variables_with_nans" in z.attrs:
-             return z.attrs["variables_with_nans"]
-
-         warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.")
-         return True
-
-     def _get_statistics_dates(self):
-         dates = self.dates
-         dtype = type(dates[0])
-
-         def assert_dtype(d):
-             assert type(d) is dtype, (type(d), dtype)
-
-         # remove missing dates
-         if self.missing_dates:
-             assert_dtype(self.missing_dates[0])
-             dates = [d for d in dates if d not in self.missing_dates]
-
-         # filter dates according the the start and end dates in the metadata
-         z = zarr.open(self.path, mode="r")
-         start, end = z.attrs.get("statistics_start_date"), z.attrs.get("statistics_end_date")
-         start, end = np.datetime64(start), np.datetime64(end)
-         assert_dtype(start)
-         dates = [d for d in dates if d >= start and d <= end]
-
-         # filter dates according the the user specified start and end dates
-         if self.user_statistics_start:
-             limit = as_first_date(self.user_statistics_start, dates)
-             limit = np.datetime64(limit)
-             assert_dtype(limit)
-             dates = [d for d in dates if d >= limit]
-
-         if self.user_statistics_end:
-             limit = as_last_date(self.user_statistics_start, dates)
-             limit = np.datetime64(limit)
-             assert_dtype(limit)
-             dates = [d for d in dates if d <= limit]
-
-         return dates
-
-     def run(self):
-         dates = self._get_statistics_dates()
-         stats = self.tmp_statistics.get_aggregated(dates, self.variables_names, self.allow_nans)
-         self.output_writer(stats)
-
-     def write_stats_to_file(self, stats):
-         stats.save(self.statistics_output)
-         LOG.info(f"✅ Statistics written in {self.statistics_output}")
-
-     def write_stats_to_dataset(self, stats):
-         if self.user_statistics_start or self.user_statistics_end:
-             raise ValueError(
-                 (
-                     "Cannot write statistics in dataset with user specified dates. "
-                     "This would be conflicting with the dataset metadata."
-                 )
-             )
-
-         if not all(self.registry.get_flags(sync=False)):
-             raise Exception(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.")
-
-         for k in [
-             "mean",
-             "stdev",
-             "minimum",
-             "maximum",
-             "sums",
-             "squares",
-             "count",
-             "has_nans",
-         ]:
-             self._add_dataset(name=k, array=stats[k], dimensions=("variable",))
-
-         self.registry.add_to_history("compute_statistics_end")
-         LOG.info(f"Wrote statistics in {self.path}")
-
-     def write_stats_to_stdout(self, stats):
-         LOG.info(stats)
-
-
- class GenericAdditions(GenericDatasetHandler):
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs)
-         self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True)
-
-     @property
-     def tmp_storage_path(self):
-         """This should be implemented in the subclass."""
-         raise NotImplementedError()
-
-     @property
-     def final_storage_path(self):
-         """This should be implemented in the subclass."""
-         raise NotImplementedError()
-
-     def initialise(self):
-         self.tmp_storage.delete()
-         self.tmp_storage.create()
-         LOG.info(f"Dataset {self.path} additions initialized.")
-
-     def run(self, parts):
-         """This should be implemented in the subclass."""
-         raise NotImplementedError()
-
-     def finalise(self):
-
-         shape = (len(self.dates), len(self.variables))
-         agg = dict(
-             minimum=np.full(shape, np.nan, dtype=np.float64),
-             maximum=np.full(shape, np.nan, dtype=np.float64),
-             sums=np.full(shape, np.nan, dtype=np.float64),
-             squares=np.full(shape, np.nan, dtype=np.float64),
-             count=np.full(shape, -1, dtype=np.int64),
-             has_nans=np.full(shape, False, dtype=np.bool_),
-         )
-         LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}")
-
-         found = set()
-         ifound = set()
-         missing = set()
-         for _date, (date, i, stats) in self.tmp_storage.items():
-             assert _date == date
-             if stats == "missing":
-                 missing.add(date)
-                 continue
-
-             assert date not in found, f"Duplicates found {date}"
-             found.add(date)
-             ifound.add(i)
-
-             for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]:
-                 agg[k][i, ...] = stats[k]
-
-         assert len(found) + len(missing) == len(self.dates), (
-             len(found),
-             len(missing),
-             len(self.dates),
-         )
-         assert found.union(missing) == set(self.dates), (
-             found,
-             missing,
-             set(self.dates),
-         )
-
-         if len(ifound) < 2:
-             LOG.warn(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.")
-             self.tmp_storage.delete()
-             return
-
-         mask = sorted(list(ifound))
-         for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]:
-             agg[k] = agg[k][mask, ...]
-
-         for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]:
-             assert agg[k].shape == agg["count"].shape, (
-                 agg[k].shape,
-                 agg["count"].shape,
-             )
-
-         minimum = np.nanmin(agg["minimum"], axis=0)
-         maximum = np.nanmax(agg["maximum"], axis=0)
-         sums = np.nansum(agg["sums"], axis=0)
-         squares = np.nansum(agg["squares"], axis=0)
-         count = np.nansum(agg["count"], axis=0)
-         has_nans = np.any(agg["has_nans"], axis=0)
-
-         assert sums.shape == count.shape
-         assert sums.shape == squares.shape
-         assert sums.shape == minimum.shape
-         assert sums.shape == maximum.shape
-         assert sums.shape == has_nans.shape
-
-         mean = sums / count
-         assert sums.shape == mean.shape
-
-         x = squares / count - mean * mean
-         # remove negative variance due to numerical errors
-         # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0
-         check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares)
-
-         stdev = np.sqrt(x)
-         assert sums.shape == stdev.shape
-
-         self.summary = Summary(
-             minimum=minimum,
-             maximum=maximum,
-             mean=mean,
-             count=count,
-             sums=sums,
-             squares=squares,
-             stdev=stdev,
-             variables_names=self.variables,
-             has_nans=has_nans,
-         )
-         LOG.info(f"Dataset {self.path} additions finalised.")
-         self.check_statistics()
-         self._write(self.summary)
-         self.tmp_storage.delete()
-
-     def _write(self, summary):
-         for k in [
-             "mean",
-             "stdev",
-             "minimum",
-             "maximum",
-             "sums",
-             "squares",
-             "count",
-             "has_nans",
-         ]:
-             name = self.final_storage_name(k)
-             self._add_dataset(name=name, array=summary[k], dimensions=("variable",))
-         self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end")
-         LOG.debug(f"Wrote additions in {self.path} ({self.final_storage_name('*')})")
-
-     def check_statistics(self):
-         pass
-
-     @cached_property
-     def _variables_with_nans(self):
-         z = zarr.open(self.path, mode="r")
-         if "variables_with_nans" in z.attrs:
-             return z.attrs["variables_with_nans"]
-         return None
-
-     @cached_property
-     def _allow_nans(self):
-         z = zarr.open(self.path, mode="r")
-         return z.attrs.get("allow_nans", False)
-
-     def allow_nans(self):
-
-         if self._allow_nans:
-             return True
-
-         if self._variables_with_nans is not None:
-             return self._variables_with_nans
-         warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.")
-         return True
-
-
- class StatisticsAddition(GenericAdditions):
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs)
-
-         z = zarr.open(self.path, mode="r")
-         start = z.attrs["statistics_start_date"]
-         end = z.attrs["statistics_end_date"]
-         self.ds = open_dataset(self.path, start=start, end=end)
-
-         self.variables = self.ds.variables
-         self.dates = self.ds.dates
-
-         assert len(self.variables) == self.ds.shape[1], self.ds.shape
-         self.total = len(self.dates)
-
-     @property
-     def tmp_storage_path(self):
-         return f"{self.path}.storage_statistics.tmp"
-
-     def final_storage_name(self, k):
-         return k
-
-     def run(self, parts):
-         chunk_filter = ChunkFilter(parts=parts, total=self.total)
-         for i in range(0, self.total):
-             if not chunk_filter(i):
-                 continue
-             date = self.dates[i]
-             try:
-                 arr = self.ds[i : i + 1, ...]
-                 stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans)
-                 self.tmp_storage.add([date, i, stats], key=date)
-             except MissingDateError:
-                 self.tmp_storage.add([date, i, "missing"], key=date)
-         self.tmp_storage.flush()
-         LOG.debug(f"Dataset {self.path} additions run.")
-
-     def check_statistics(self):
-         ds = open_dataset(self.path)
-         ref = ds.statistics
-         for k in ds.statistics:
-             assert np.all(np.isclose(ref[k], self.summary[k], rtol=1e-4, atol=1e-4)), (
-                 k,
-                 ref[k],
-                 self.summary[k],
-             )
-
-
- class DeltaDataset:
-     def __init__(self, ds, idelta):
-         self.ds = ds
-         self.idelta = idelta
-
-     def __getitem__(self, i):
-         j = i - self.idelta
-         if j < 0:
-             raise MissingDateError(f"Missing date {j}")
-         return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...]
-
-
- class TendenciesStatisticsDeltaNotMultipleOfFrequency(ValueError):
-     pass
-
-
- class TendenciesStatisticsAddition(GenericAdditions):
-     def __init__(self, path, delta=None, **kwargs):
-         full_ds = open_dataset(path)
-         self.variables = full_ds.variables
-
-         frequency = full_ds.frequency
-         if delta is None:
-             delta = frequency
-         assert isinstance(delta, int), delta
-         if not delta % frequency == 0:
-             raise TendenciesStatisticsDeltaNotMultipleOfFrequency(
-                 f"Delta {delta} is not a multiple of frequency {frequency}"
-             )
-         self.delta = delta
-         idelta = delta // frequency
-
-         super().__init__(path=path, **kwargs)
-
-         z = zarr.open(self.path, mode="r")
-         start = z.attrs["statistics_start_date"]
-         end = z.attrs["statistics_end_date"]
-         start = datetime.datetime.fromisoformat(start)
-         ds = open_dataset(self.path, start=start, end=end)
-         self.dates = ds.dates
-         self.total = len(self.dates)
-
-         ds = open_dataset(self.path, start=start, end=end)
-         self.ds = DeltaDataset(ds, idelta)
-
-     @property
-     def tmp_storage_path(self):
-         return f"{self.path}.storage_statistics_{self.delta}h.tmp"
-
-     def final_storage_name(self, k):
-         return self.final_storage_name_from_delta(k, delta=self.delta)
-
-     @classmethod
-     def final_storage_name_from_delta(_, k, delta):
-         if isinstance(delta, int):
-             delta = str(delta)
-         if not delta.endswith("h"):
-             delta = delta + "h"
-         return f"statistics_tendencies_{delta}_{k}"
-
-     def run(self, parts):
-         chunk_filter = ChunkFilter(parts=parts, total=self.total)
-         for i in range(0, self.total):
-             if not chunk_filter(i):
-                 continue
-             date = self.dates[i]
-             try:
-                 arr = self.ds[i]
-                 stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans)
-                 self.tmp_storage.add([date, i, stats], key=date)
-             except MissingDateError:
-                 self.tmp_storage.add([date, i, "missing"], key=date)
-         self.tmp_storage.flush()
-         LOG.debug(f"Dataset {self.path} additions run.")
-
-
- class DatasetVerifier(GenericDatasetHandler):
-
-     def verify(self):
-         pass