anemoi-datasets 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
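Since a wheel is just a zip archive, a diff like the one shown further down can also be reproduced locally. The following is a minimal, illustrative Python sketch and is not part of the original diff page: the wheel filename follows the standard {name}-{version}-{tags}.whl convention and is an assumption, and the old wheel is expected to have been downloaded beforehand (e.g. with pip download anemoi-datasets==0.4.3 --no-deps).

# Minimal sketch (assumptions noted above): diff one file between two wheel versions.
import difflib
import zipfile

def read_member(wheel_path, member):
    # A wheel is a zip archive; read one file from it as a list of lines.
    with zipfile.ZipFile(wheel_path) as zf:
        return zf.read(member).decode("utf-8", errors="replace").splitlines(keepends=True)

old_wheel = "anemoi_datasets-0.4.3-py3-none-any.whl"  # assumed local filename
member = "anemoi/datasets/create/loaders.py"          # present in 0.4.3, removed in 0.4.5

old_lines = read_member(old_wheel, member)
new_lines = []  # the file is absent from the 0.4.5 wheel, so the "new" side is empty

for line in difflib.unified_diff(old_lines, new_lines,
                                 fromfile=f"0.4.3/{member}", tofile=f"0.4.5/{member}"):
    print(line, end="")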
Files changed (52)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/cleanup.py +44 -0
  3. anemoi/datasets/commands/create.py +50 -20
  4. anemoi/datasets/commands/finalise-additions.py +45 -0
  5. anemoi/datasets/commands/finalise.py +39 -0
  6. anemoi/datasets/commands/init-additions.py +45 -0
  7. anemoi/datasets/commands/init.py +67 -0
  8. anemoi/datasets/commands/inspect.py +1 -1
  9. anemoi/datasets/commands/load-additions.py +47 -0
  10. anemoi/datasets/commands/load.py +47 -0
  11. anemoi/datasets/commands/patch.py +39 -0
  12. anemoi/datasets/compute/recentre.py +1 -1
  13. anemoi/datasets/create/__init__.py +961 -146
  14. anemoi/datasets/create/check.py +5 -3
  15. anemoi/datasets/create/config.py +53 -2
  16. anemoi/datasets/create/functions/sources/accumulations.py +6 -22
  17. anemoi/datasets/create/functions/sources/hindcasts.py +27 -12
  18. anemoi/datasets/create/functions/sources/tendencies.py +1 -1
  19. anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
  20. anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
  21. anemoi/datasets/create/functions/sources/xarray/field.py +1 -1
  22. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
  23. anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
  24. anemoi/datasets/create/functions/sources/xarray/metadata.py +27 -29
  25. anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
  26. anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
  27. anemoi/datasets/create/input.py +62 -25
  28. anemoi/datasets/create/statistics/__init__.py +39 -23
  29. anemoi/datasets/create/utils.py +3 -2
  30. anemoi/datasets/data/__init__.py +1 -0
  31. anemoi/datasets/data/concat.py +46 -2
  32. anemoi/datasets/data/dataset.py +109 -34
  33. anemoi/datasets/data/forwards.py +17 -8
  34. anemoi/datasets/data/grids.py +17 -3
  35. anemoi/datasets/data/interpolate.py +133 -0
  36. anemoi/datasets/data/misc.py +56 -66
  37. anemoi/datasets/data/missing.py +240 -0
  38. anemoi/datasets/data/select.py +7 -1
  39. anemoi/datasets/data/stores.py +3 -3
  40. anemoi/datasets/data/subset.py +47 -5
  41. anemoi/datasets/data/unchecked.py +20 -22
  42. anemoi/datasets/data/xy.py +125 -0
  43. anemoi/datasets/dates/__init__.py +33 -20
  44. anemoi/datasets/dates/groups.py +2 -2
  45. anemoi/datasets/grids.py +66 -48
  46. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/METADATA +5 -5
  47. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/RECORD +51 -41
  48. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/WHEEL +1 -1
  49. anemoi/datasets/create/loaders.py +0 -924
  50. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/LICENSE +0 -0
  51. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/entry_points.txt +0 -0
  52. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/top_level.txt +0 -0
@@ -1,924 +0,0 @@
1
- # (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
2
- # This software is licensed under the terms of the Apache Licence Version 2.0
3
- # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
- # In applying this licence, ECMWF does not waive the privileges and immunities
5
- # granted to it by virtue of its status as an intergovernmental organisation
6
- # nor does it submit to any jurisdiction.
7
- import datetime
8
- import json
9
- import logging
10
- import os
11
- import time
12
- import uuid
13
- import warnings
14
- from functools import cached_property
15
-
16
- import numpy as np
17
- import tqdm
18
- import zarr
19
- from anemoi.utils.config import DotDict
20
- from anemoi.utils.humanize import seconds_to_human
21
-
22
- from anemoi.datasets import MissingDateError
23
- from anemoi.datasets import open_dataset
24
- from anemoi.datasets.create.persistent import build_storage
25
- from anemoi.datasets.data.misc import as_first_date
26
- from anemoi.datasets.data.misc import as_last_date
27
- from anemoi.datasets.dates.groups import Groups
28
-
29
- from .check import DatasetName
30
- from .check import check_data_values
31
- from .chunks import ChunkFilter
32
- from .config import build_output
33
- from .config import loader_config
34
- from .input import build_input
35
- from .statistics import Summary
36
- from .statistics import TmpStatistics
37
- from .statistics import check_variance
38
- from .statistics import compute_statistics
39
- from .statistics import default_statistics_dates
40
- from .utils import normalize_and_check_dates
41
- from .writer import ViewCacheArray
42
- from .zarr import ZarrBuiltRegistry
43
- from .zarr import add_zarr_dataset
44
-
45
- LOG = logging.getLogger(__name__)
46
-
47
- VERSION = "0.20"
48
-
49
-
50
- def set_to_test_mode(cfg):
51
- NUMBER_OF_DATES = 4
52
-
53
- dates = cfg.dates
54
- LOG.warn(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.")
55
- groups = Groups(**cfg.dates)
56
- dates = groups.dates
57
- cfg.dates = dict(
58
- start=dates[0],
59
- end=dates[NUMBER_OF_DATES - 1],
60
- frequency=dates.frequency,
61
- group_by=NUMBER_OF_DATES,
62
- )
63
-
64
- def set_element_to_test(obj):
65
- if isinstance(obj, (list, tuple)):
66
- for v in obj:
67
- set_element_to_test(v)
68
- return
69
- if isinstance(obj, (dict, DotDict)):
70
- if "grid" in obj:
71
- previous = obj["grid"]
72
- obj["grid"] = "20./20."
73
- LOG.warn(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}")
74
- if "number" in obj:
75
- if isinstance(obj["number"], (list, tuple)):
76
- previous = obj["number"]
77
- obj["number"] = previous[0:3]
78
- LOG.warn(f"Running in test mode. Setting number to {obj['number']} instead of {previous}")
79
- for k, v in obj.items():
80
- set_element_to_test(v)
81
- if "constants" in obj:
82
- constants = obj["constants"]
83
- if "param" in constants and isinstance(constants["param"], list):
84
- constants["param"] = ["cos_latitude"]
85
-
86
- set_element_to_test(cfg)
87
-
88
-
89
- class GenericDatasetHandler:
90
- def __init__(self, *, path, use_threads=False, **kwargs):
91
-
92
- # Catch all floating point errors, including overflow, sqrt(<0), etc
93
- np.seterr(all="raise", under="warn")
94
-
95
- assert isinstance(path, str), path
96
-
97
- self.path = path
98
- self.kwargs = kwargs
99
- self.use_threads = use_threads
100
- if "test" in kwargs:
101
- self.test = kwargs["test"]
102
-
103
- @classmethod
104
- def from_config(cls, *, config, path, use_threads=False, **kwargs):
105
- """Config is the path to the config file or a dict with the config"""
106
-
107
- assert isinstance(config, dict) or isinstance(config, str), config
108
- return cls(config=config, path=path, use_threads=use_threads, **kwargs)
109
-
110
- @classmethod
111
- def from_dataset_config(cls, *, path, use_threads=False, **kwargs):
112
- """Read the config saved inside the zarr dataset and instantiate the class for this config."""
113
-
114
- assert os.path.exists(path), f"Path {path} does not exist."
115
- z = zarr.open(path, mode="r")
116
- config = z.attrs["_create_yaml_config"]
117
- LOG.debug("Config loaded from zarr config:\n%s", json.dumps(config, indent=4, sort_keys=True, default=str))
118
- return cls.from_config(config=config, path=path, use_threads=use_threads, **kwargs)
119
-
120
- @classmethod
121
- def from_dataset(cls, *, path, use_threads=False, **kwargs):
122
- """Instanciate the class from the path to the zarr dataset, without config."""
123
-
124
- assert os.path.exists(path), f"Path {path} does not exist."
125
- return cls(path=path, use_threads=use_threads, **kwargs)
126
-
127
- def read_dataset_metadata(self):
128
- ds = open_dataset(self.path)
129
- self.dataset_shape = ds.shape
130
- self.variables_names = ds.variables
131
- assert len(self.variables_names) == ds.shape[1], self.dataset_shape
132
- self.dates = ds.dates
133
-
134
- self.missing_dates = sorted(list([self.dates[i] for i in ds.missing]))
135
-
136
- z = zarr.open(self.path, "r")
137
- missing_dates = z.attrs.get("missing_dates", [])
138
- missing_dates = sorted([np.datetime64(d) for d in missing_dates])
139
-
140
- if missing_dates != self.missing_dates:
141
- LOG.warn("Missing dates given in recipe do not match the actual missing dates in the dataset.")
142
- LOG.warn(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}")
143
- LOG.warn(f"Missing dates in dataset: {sorted(str(x) for x in self.missing_dates)}")
144
- raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.")
145
-
146
- @cached_property
147
- def registry(self):
148
- return ZarrBuiltRegistry(self.path, use_threads=self.use_threads)
149
-
150
- def ready(self):
151
- return all(self.registry.get_flags())
152
-
153
- def update_metadata(self, **kwargs):
154
- LOG.debug(f"Updating metadata {kwargs}")
155
- z = zarr.open(self.path, mode="w+")
156
- for k, v in kwargs.items():
157
- if isinstance(v, np.datetime64):
158
- v = v.astype(datetime.datetime)
159
- if isinstance(v, datetime.date):
160
- v = v.isoformat()
161
- z.attrs[k] = v
162
-
163
- def _add_dataset(self, mode="r+", **kwargs):
164
- z = zarr.open(self.path, mode=mode)
165
- return add_zarr_dataset(zarr_root=z, **kwargs)
166
-
167
- def get_zarr_chunks(self):
168
- z = zarr.open(self.path, mode="r")
169
- return z["data"].chunks
170
-
171
- def print_info(self):
172
- z = zarr.open(self.path, mode="r")
173
- try:
174
- LOG.info(z["data"].info)
175
- except Exception as e:
176
- LOG.info(e)
177
-
178
-
179
- class DatasetHandler(GenericDatasetHandler):
180
- pass
181
-
182
-
183
- class DatasetHandlerWithStatistics(GenericDatasetHandler):
184
- def __init__(self, statistics_tmp=None, **kwargs):
185
- super().__init__(**kwargs)
186
- statistics_tmp = kwargs.get("statistics_tmp") or os.path.join(self.path + ".storage_for_statistics.tmp")
187
- self.tmp_statistics = TmpStatistics(statistics_tmp)
188
-
189
-
190
- class Loader(DatasetHandlerWithStatistics):
191
- def build_input(self):
192
- from earthkit.data.core.order import build_remapping
193
-
194
- builder = build_input(
195
- self.main_config.input,
196
- data_sources=self.main_config.get("data_sources", {}),
197
- order_by=self.output.order_by,
198
- flatten_grid=self.output.flatten_grid,
199
- remapping=build_remapping(self.output.remapping),
200
- use_grib_paramid=self.main_config.build.use_grib_paramid,
201
- )
202
- LOG.debug("✅ INPUT_BUILDER")
203
- LOG.debug(builder)
204
- return builder
205
-
206
- @property
207
- def allow_nans(self):
208
- if "allow_nans" in self.main_config.build:
209
- return self.main_config.build.allow_nans
210
-
211
- return self.main_config.statistics.get("allow_nans", [])
212
-
213
-
214
- class InitialiserLoader(Loader):
215
- def __init__(self, config, **kwargs):
216
- super().__init__(**kwargs)
217
-
218
- self.main_config = loader_config(config)
219
- if self.test:
220
- set_to_test_mode(self.main_config)
221
-
222
- LOG.info(dict(self.main_config.dates))
223
-
224
- self.tmp_statistics.delete()
225
-
226
- self.groups = Groups(**self.main_config.dates)
227
- LOG.info(self.groups)
228
-
229
- self.output = build_output(self.main_config.output, parent=self)
230
- self.input = self.build_input()
231
- LOG.info(self.input)
232
-
233
- first_date = self.groups.dates[0]
234
- self.minimal_input = self.input.select([first_date])
235
- LOG.info("Minimal input (using only the first date) :")
236
- LOG.info(self.minimal_input)
237
-
238
- def build_statistics_dates(self, start, end):
239
- """Compute the start and end dates for the statistics, based on :
240
- - The start and end dates in the config
241
- - The default statistics dates convention
242
-
243
- Then adapt according to the actual dates in the dataset.
244
- """
245
-
246
- ds = open_dataset(self.path)
247
- dates = ds.dates
248
-
249
- # if not specified, use the default statistics dates
250
- default_start, default_end = default_statistics_dates(dates)
251
- if start is None:
252
- start = default_start
253
- if end is None:
254
- end = default_end
255
-
256
- # in any case, adapt to the actual dates in the dataset
257
- start = as_first_date(start, dates)
258
- end = as_last_date(end, dates)
259
-
260
- # and convert to datetime to isoformat
261
- start = start.astype(datetime.datetime)
262
- end = end.astype(datetime.datetime)
263
- return (start.isoformat(), end.isoformat())
264
-
265
- def initialise_dataset_backend(self):
266
- z = zarr.open(self.path, mode="w")
267
- z.create_group("_build")
268
-
269
- def initialise(self, check_name=True):
270
- """Create an empty dataset of the right final shape
271
-
272
- Read a small part of the data to get the shape of the data and the resolution and more metadata.
273
- """
274
-
275
- LOG.info("Config loaded ok:")
276
- # LOG.info(self.main_config)
277
-
278
- dates = self.groups.dates
279
- frequency = dates.frequency
280
- assert isinstance(frequency, int), frequency
281
-
282
- LOG.info(f"Found {len(dates)} datetimes.")
283
- LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ")
284
- LOG.info(f"Missing dates: {len(dates.missing)}")
285
- lengths = tuple(len(g) for g in self.groups)
286
-
287
- variables = self.minimal_input.variables
288
- LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.")
289
-
290
- variables_with_nans = self.main_config.statistics.get("allow_nans", [])
291
-
292
- ensembles = self.minimal_input.ensembles
293
- LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.")
294
-
295
- grid_points = self.minimal_input.grid_points
296
- LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}")
297
-
298
- resolution = self.minimal_input.resolution
299
- LOG.info(f"{resolution=}")
300
-
301
- coords = self.minimal_input.coords
302
- coords["dates"] = dates
303
- total_shape = self.minimal_input.shape
304
- total_shape[0] = len(dates)
305
- LOG.info(f"total_shape = {total_shape}")
306
-
307
- chunks = self.output.get_chunking(coords)
308
- LOG.info(f"{chunks=}")
309
- dtype = self.output.dtype
310
-
311
- LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}")
312
-
313
- metadata = {}
314
- metadata["uuid"] = str(uuid.uuid4())
315
-
316
- metadata.update(self.main_config.get("add_metadata", {}))
317
-
318
- metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict()
319
-
320
- metadata["description"] = self.main_config.description
321
- metadata["licence"] = self.main_config["licence"]
322
- metadata["attribution"] = self.main_config["attribution"]
323
-
324
- metadata["remapping"] = self.output.remapping
325
- metadata["order_by"] = self.output.order_by_as_list
326
- metadata["flatten_grid"] = self.output.flatten_grid
327
-
328
- metadata["ensemble_dimension"] = len(ensembles)
329
- metadata["variables"] = variables
330
- metadata["variables_with_nans"] = variables_with_nans
331
- metadata["allow_nans"] = self.main_config.build.get("allow_nans", False)
332
- metadata["resolution"] = resolution
333
-
334
- metadata["data_request"] = self.minimal_input.data_request
335
- metadata["field_shape"] = self.minimal_input.field_shape
336
- metadata["proj_string"] = self.minimal_input.proj_string
337
-
338
- metadata["start_date"] = dates[0].isoformat()
339
- metadata["end_date"] = dates[-1].isoformat()
340
- metadata["frequency"] = frequency
341
- metadata["missing_dates"] = [_.isoformat() for _ in dates.missing]
342
-
343
- metadata["version"] = VERSION
344
-
345
- if check_name:
346
- basename, ext = os.path.splitext(os.path.basename(self.path)) # noqa: F841
347
- ds_name = DatasetName(basename, resolution, dates[0], dates[-1], frequency)
348
- ds_name.raise_if_not_valid()
349
-
350
- if len(dates) != total_shape[0]:
351
- raise ValueError(
352
- f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) "
353
- f"does not match data shape {total_shape[0]}. {total_shape=}"
354
- )
355
-
356
- dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"])
357
-
358
- metadata.update(self.main_config.get("force_metadata", {}))
359
-
360
- ###############################################################
361
- # write metadata
362
- ###############################################################
363
-
364
- self.initialise_dataset_backend()
365
-
366
- self.update_metadata(**metadata)
367
-
368
- self._add_dataset(
369
- name="data",
370
- chunks=chunks,
371
- dtype=dtype,
372
- shape=total_shape,
373
- dimensions=("time", "variable", "ensemble", "cell"),
374
- )
375
- self._add_dataset(name="dates", array=dates, dimensions=("time",))
376
- self._add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",))
377
- self._add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",))
378
-
379
- self.registry.create(lengths=lengths)
380
- self.tmp_statistics.create(exist_ok=False)
381
- self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version)
382
-
383
- statistics_start, statistics_end = self.build_statistics_dates(
384
- self.main_config.statistics.get("start"),
385
- self.main_config.statistics.get("end"),
386
- )
387
- self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end)
388
- LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}")
389
-
390
- self.registry.add_to_history("init finished")
391
-
392
- assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks())
393
-
394
- # Return the number of groups to process, so we can show a nice progress bar
395
- return len(lengths)
396
-
397
-
398
- class ContentLoader(Loader):
399
- def __init__(self, config, parts, **kwargs):
400
- super().__init__(**kwargs)
401
- self.main_config = loader_config(config)
402
-
403
- self.groups = Groups(**self.main_config.dates)
404
- self.output = build_output(self.main_config.output, parent=self)
405
- self.input = self.build_input()
406
- self.read_dataset_metadata()
407
-
408
- self.parts = parts
409
- total = len(self.registry.get_flags())
410
- self.chunk_filter = ChunkFilter(parts=self.parts, total=total)
411
-
412
- self.data_array = zarr.open(self.path, mode="r+")["data"]
413
- self.n_groups = len(self.groups)
414
-
415
- def load(self):
416
- for igroup, group in enumerate(self.groups):
417
- if not self.chunk_filter(igroup):
418
- continue
419
- if self.registry.get_flag(igroup):
420
- LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)")
421
- continue
422
-
423
- assert isinstance(group[0], datetime.datetime), group
424
-
425
- result = self.input.select(dates=group)
426
- assert result.dates == group, (len(result.dates), len(group))
427
-
428
- LOG.debug(f"Building data for group {igroup}/{self.n_groups}")
429
-
430
- # There are several groups.
431
- # There is one result to load for each group.
432
- self.load_result(result)
433
- self.registry.set_flag(igroup)
434
-
435
- self.registry.add_provenance(name="provenance_load")
436
- self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config)
437
-
438
- # self.print_info()
439
-
440
- def load_result(self, result):
441
- # There is one cube to load for each result.
442
- dates = result.dates
443
-
444
- cube = result.get_cube()
445
- assert cube.extended_user_shape[0] == len(dates), (
446
- cube.extended_user_shape[0],
447
- len(dates),
448
- )
449
-
450
- shape = cube.extended_user_shape
451
- dates_in_data = cube.user_coords["valid_datetime"]
452
-
453
- LOG.debug(f"Loading {shape=} in {self.data_array.shape=}")
454
-
455
- def check_dates_in_data(lst, lst2):
456
- lst2 = [np.datetime64(_) for _ in lst2]
457
- lst = [np.datetime64(_) for _ in lst]
458
- assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2)
459
-
460
- check_dates_in_data(dates_in_data, dates)
461
-
462
- def dates_to_indexes(dates, all_dates):
463
- x = np.array(dates, dtype=np.datetime64)
464
- y = np.array(all_dates, dtype=np.datetime64)
465
- bitmap = np.isin(x, y)
466
- return np.where(bitmap)[0]
467
-
468
- indexes = dates_to_indexes(self.dates, dates_in_data)
469
-
470
- array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes)
471
- self.load_cube(cube, array)
472
-
473
- stats = compute_statistics(array.cache, self.variables_names, allow_nans=self.allow_nans)
474
- self.tmp_statistics.write(indexes, stats, dates=dates_in_data)
475
-
476
- array.flush()
477
-
478
- def load_cube(self, cube, array):
479
- # There are several cubelets for each cube
480
- start = time.time()
481
- load = 0
482
- save = 0
483
-
484
- reading_chunks = None
485
- total = cube.count(reading_chunks)
486
- LOG.debug(f"Loading datacube: {cube}")
487
-
488
- def position(x):
489
- if isinstance(x, str) and "/" in x:
490
- x = x.split("/")
491
- return int(x[0])
492
- return None
493
-
494
- bar = tqdm.tqdm(
495
- iterable=cube.iterate_cubelets(reading_chunks),
496
- total=total,
497
- desc=f"Loading datacube {cube}",
498
- position=position(self.parts),
499
- )
500
- for i, cubelet in enumerate(bar):
501
- bar.set_description(f"Loading {i}/{total}")
502
-
503
- now = time.time()
504
- data = cubelet.to_numpy()
505
- local_indexes = cubelet.coords
506
- load += time.time() - now
507
-
508
- name = self.variables_names[local_indexes[1]]
509
- check_data_values(
510
- data[:],
511
- name=name,
512
- log=[i, data.shape, local_indexes],
513
- allow_nans=self.allow_nans,
514
- )
515
-
516
- now = time.time()
517
- array[local_indexes] = data
518
- save += time.time() - now
519
-
520
- now = time.time()
521
- save += time.time() - now
522
- LOG.debug(
523
- f"Elapsed: {seconds_to_human(time.time() - start)}, "
524
- f"load time: {seconds_to_human(load)}, "
525
- f"write time: {seconds_to_human(save)}."
526
- )
527
-
528
-
529
- class StatisticsAdder(DatasetHandlerWithStatistics):
530
- def __init__(
531
- self,
532
- statistics_output=None,
533
- statistics_start=None,
534
- statistics_end=None,
535
- **kwargs,
536
- ):
537
- super().__init__(**kwargs)
538
- self.user_statistics_start = statistics_start
539
- self.user_statistics_end = statistics_end
540
-
541
- self.statistics_output = statistics_output
542
-
543
- self.output_writer = {
544
- None: self.write_stats_to_dataset,
545
- "-": self.write_stats_to_stdout,
546
- }.get(self.statistics_output, self.write_stats_to_file)
547
-
548
- self.read_dataset_metadata()
549
-
550
- @cached_property
551
- def allow_nans(self):
552
- z = zarr.open(self.path, mode="r")
553
- if "allow_nans" in z.attrs:
554
- return z.attrs["allow_nans"]
555
-
556
- if "variables_with_nans" in z.attrs:
557
- return z.attrs["variables_with_nans"]
558
-
559
- warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.")
560
- return True
561
-
562
- def _get_statistics_dates(self):
563
- dates = self.dates
564
- dtype = type(dates[0])
565
-
566
- def assert_dtype(d):
567
- assert type(d) is dtype, (type(d), dtype)
568
-
569
- # remove missing dates
570
- if self.missing_dates:
571
- assert_dtype(self.missing_dates[0])
572
- dates = [d for d in dates if d not in self.missing_dates]
573
-
574
- # filter dates according the the start and end dates in the metadata
575
- z = zarr.open(self.path, mode="r")
576
- start, end = z.attrs.get("statistics_start_date"), z.attrs.get("statistics_end_date")
577
- start, end = np.datetime64(start), np.datetime64(end)
578
- assert_dtype(start)
579
- dates = [d for d in dates if d >= start and d <= end]
580
-
581
- # filter dates according the the user specified start and end dates
582
- if self.user_statistics_start:
583
- limit = as_first_date(self.user_statistics_start, dates)
584
- limit = np.datetime64(limit)
585
- assert_dtype(limit)
586
- dates = [d for d in dates if d >= limit]
587
-
588
- if self.user_statistics_end:
589
- limit = as_last_date(self.user_statistics_start, dates)
590
- limit = np.datetime64(limit)
591
- assert_dtype(limit)
592
- dates = [d for d in dates if d <= limit]
593
-
594
- return dates
595
-
596
- def run(self):
597
- dates = self._get_statistics_dates()
598
- stats = self.tmp_statistics.get_aggregated(dates, self.variables_names, self.allow_nans)
599
- self.output_writer(stats)
600
-
601
- def write_stats_to_file(self, stats):
602
- stats.save(self.statistics_output)
603
- LOG.info(f"✅ Statistics written in {self.statistics_output}")
604
-
605
- def write_stats_to_dataset(self, stats):
606
- if self.user_statistics_start or self.user_statistics_end:
607
- raise ValueError(
608
- (
609
- "Cannot write statistics in dataset with user specified dates. "
610
- "This would be conflicting with the dataset metadata."
611
- )
612
- )
613
-
614
- if not all(self.registry.get_flags(sync=False)):
615
- raise Exception(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.")
616
-
617
- for k in [
618
- "mean",
619
- "stdev",
620
- "minimum",
621
- "maximum",
622
- "sums",
623
- "squares",
624
- "count",
625
- "has_nans",
626
- ]:
627
- self._add_dataset(name=k, array=stats[k], dimensions=("variable",))
628
-
629
- self.registry.add_to_history("compute_statistics_end")
630
- LOG.info(f"Wrote statistics in {self.path}")
631
-
632
- def write_stats_to_stdout(self, stats):
633
- LOG.info(stats)
634
-
635
-
636
- class GenericAdditions(GenericDatasetHandler):
637
- def __init__(self, **kwargs):
638
- super().__init__(**kwargs)
639
- self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True)
640
-
641
- @property
642
- def tmp_storage_path(self):
643
- """This should be implemented in the subclass."""
644
- raise NotImplementedError()
645
-
646
- @property
647
- def final_storage_path(self):
648
- """This should be implemented in the subclass."""
649
- raise NotImplementedError()
650
-
651
- def initialise(self):
652
- self.tmp_storage.delete()
653
- self.tmp_storage.create()
654
- LOG.info(f"Dataset {self.path} additions initialized.")
655
-
656
- def run(self, parts):
657
- """This should be implemented in the subclass."""
658
- raise NotImplementedError()
659
-
660
- def finalise(self):
661
-
662
- shape = (len(self.dates), len(self.variables))
663
- agg = dict(
664
- minimum=np.full(shape, np.nan, dtype=np.float64),
665
- maximum=np.full(shape, np.nan, dtype=np.float64),
666
- sums=np.full(shape, np.nan, dtype=np.float64),
667
- squares=np.full(shape, np.nan, dtype=np.float64),
668
- count=np.full(shape, -1, dtype=np.int64),
669
- has_nans=np.full(shape, False, dtype=np.bool_),
670
- )
671
- LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}")
672
-
673
- found = set()
674
- ifound = set()
675
- missing = set()
676
- for _date, (date, i, stats) in self.tmp_storage.items():
677
- assert _date == date
678
- if stats == "missing":
679
- missing.add(date)
680
- continue
681
-
682
- assert date not in found, f"Duplicates found {date}"
683
- found.add(date)
684
- ifound.add(i)
685
-
686
- for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]:
687
- agg[k][i, ...] = stats[k]
688
-
689
- assert len(found) + len(missing) == len(self.dates), (
690
- len(found),
691
- len(missing),
692
- len(self.dates),
693
- )
694
- assert found.union(missing) == set(self.dates), (
695
- found,
696
- missing,
697
- set(self.dates),
698
- )
699
-
700
- if len(ifound) < 2:
701
- LOG.warn(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.")
702
- self.tmp_storage.delete()
703
- return
704
-
705
- mask = sorted(list(ifound))
706
- for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]:
707
- agg[k] = agg[k][mask, ...]
708
-
709
- for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]:
710
- assert agg[k].shape == agg["count"].shape, (
711
- agg[k].shape,
712
- agg["count"].shape,
713
- )
714
-
715
- minimum = np.nanmin(agg["minimum"], axis=0)
716
- maximum = np.nanmax(agg["maximum"], axis=0)
717
- sums = np.nansum(agg["sums"], axis=0)
718
- squares = np.nansum(agg["squares"], axis=0)
719
- count = np.nansum(agg["count"], axis=0)
720
- has_nans = np.any(agg["has_nans"], axis=0)
721
-
722
- assert sums.shape == count.shape
723
- assert sums.shape == squares.shape
724
- assert sums.shape == minimum.shape
725
- assert sums.shape == maximum.shape
726
- assert sums.shape == has_nans.shape
727
-
728
- mean = sums / count
729
- assert sums.shape == mean.shape
730
-
731
- x = squares / count - mean * mean
732
- # remove negative variance due to numerical errors
733
- # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0
734
- check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares)
735
-
736
- stdev = np.sqrt(x)
737
- assert sums.shape == stdev.shape
738
-
739
- self.summary = Summary(
740
- minimum=minimum,
741
- maximum=maximum,
742
- mean=mean,
743
- count=count,
744
- sums=sums,
745
- squares=squares,
746
- stdev=stdev,
747
- variables_names=self.variables,
748
- has_nans=has_nans,
749
- )
750
- LOG.info(f"Dataset {self.path} additions finalised.")
751
- self.check_statistics()
752
- self._write(self.summary)
753
- self.tmp_storage.delete()
754
-
755
- def _write(self, summary):
756
- for k in [
757
- "mean",
758
- "stdev",
759
- "minimum",
760
- "maximum",
761
- "sums",
762
- "squares",
763
- "count",
764
- "has_nans",
765
- ]:
766
- name = self.final_storage_name(k)
767
- self._add_dataset(name=name, array=summary[k], dimensions=("variable",))
768
- self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end")
769
- LOG.debug(f"Wrote additions in {self.path} ({self.final_storage_name('*')})")
770
-
771
- def check_statistics(self):
772
- pass
773
-
774
- @cached_property
775
- def _variables_with_nans(self):
776
- z = zarr.open(self.path, mode="r")
777
- if "variables_with_nans" in z.attrs:
778
- return z.attrs["variables_with_nans"]
779
- return None
780
-
781
- @cached_property
782
- def _allow_nans(self):
783
- z = zarr.open(self.path, mode="r")
784
- return z.attrs.get("allow_nans", False)
785
-
786
- def allow_nans(self):
787
-
788
- if self._allow_nans:
789
- return True
790
-
791
- if self._variables_with_nans is not None:
792
- return self._variables_with_nans
793
- warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.")
794
- return True
795
-
796
-
797
- class StatisticsAddition(GenericAdditions):
798
- def __init__(self, **kwargs):
799
- super().__init__(**kwargs)
800
-
801
- z = zarr.open(self.path, mode="r")
802
- start = z.attrs["statistics_start_date"]
803
- end = z.attrs["statistics_end_date"]
804
- self.ds = open_dataset(self.path, start=start, end=end)
805
-
806
- self.variables = self.ds.variables
807
- self.dates = self.ds.dates
808
-
809
- assert len(self.variables) == self.ds.shape[1], self.ds.shape
810
- self.total = len(self.dates)
811
-
812
- @property
813
- def tmp_storage_path(self):
814
- return f"{self.path}.storage_statistics.tmp"
815
-
816
- def final_storage_name(self, k):
817
- return k
818
-
819
- def run(self, parts):
820
- chunk_filter = ChunkFilter(parts=parts, total=self.total)
821
- for i in range(0, self.total):
822
- if not chunk_filter(i):
823
- continue
824
- date = self.dates[i]
825
- try:
826
- arr = self.ds[i : i + 1, ...]
827
- stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans)
828
- self.tmp_storage.add([date, i, stats], key=date)
829
- except MissingDateError:
830
- self.tmp_storage.add([date, i, "missing"], key=date)
831
- self.tmp_storage.flush()
832
- LOG.debug(f"Dataset {self.path} additions run.")
833
-
834
- def check_statistics(self):
835
- ds = open_dataset(self.path)
836
- ref = ds.statistics
837
- for k in ds.statistics:
838
- assert np.all(np.isclose(ref[k], self.summary[k], rtol=1e-4, atol=1e-4)), (
839
- k,
840
- ref[k],
841
- self.summary[k],
842
- )
843
-
844
-
845
- class DeltaDataset:
846
- def __init__(self, ds, idelta):
847
- self.ds = ds
848
- self.idelta = idelta
849
-
850
- def __getitem__(self, i):
851
- j = i - self.idelta
852
- if j < 0:
853
- raise MissingDateError(f"Missing date {j}")
854
- return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...]
855
-
856
-
857
- class TendenciesStatisticsDeltaNotMultipleOfFrequency(ValueError):
858
- pass
859
-
860
-
861
- class TendenciesStatisticsAddition(GenericAdditions):
862
- def __init__(self, path, delta=None, **kwargs):
863
- full_ds = open_dataset(path)
864
- self.variables = full_ds.variables
865
-
866
- frequency = full_ds.frequency
867
- if delta is None:
868
- delta = frequency
869
- assert isinstance(delta, int), delta
870
- if not delta % frequency == 0:
871
- raise TendenciesStatisticsDeltaNotMultipleOfFrequency(
872
- f"Delta {delta} is not a multiple of frequency {frequency}"
873
- )
874
- self.delta = delta
875
- idelta = delta // frequency
876
-
877
- super().__init__(path=path, **kwargs)
878
-
879
- z = zarr.open(self.path, mode="r")
880
- start = z.attrs["statistics_start_date"]
881
- end = z.attrs["statistics_end_date"]
882
- start = datetime.datetime.fromisoformat(start)
883
- ds = open_dataset(self.path, start=start, end=end)
884
- self.dates = ds.dates
885
- self.total = len(self.dates)
886
-
887
- ds = open_dataset(self.path, start=start, end=end)
888
- self.ds = DeltaDataset(ds, idelta)
889
-
890
- @property
891
- def tmp_storage_path(self):
892
- return f"{self.path}.storage_statistics_{self.delta}h.tmp"
893
-
894
- def final_storage_name(self, k):
895
- return self.final_storage_name_from_delta(k, delta=self.delta)
896
-
897
- @classmethod
898
- def final_storage_name_from_delta(_, k, delta):
899
- if isinstance(delta, int):
900
- delta = str(delta)
901
- if not delta.endswith("h"):
902
- delta = delta + "h"
903
- return f"statistics_tendencies_{delta}_{k}"
904
-
905
- def run(self, parts):
906
- chunk_filter = ChunkFilter(parts=parts, total=self.total)
907
- for i in range(0, self.total):
908
- if not chunk_filter(i):
909
- continue
910
- date = self.dates[i]
911
- try:
912
- arr = self.ds[i]
913
- stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans)
914
- self.tmp_storage.add([date, i, stats], key=date)
915
- except MissingDateError:
916
- self.tmp_storage.add([date, i, "missing"], key=date)
917
- self.tmp_storage.flush()
918
- LOG.debug(f"Dataset {self.path} additions run.")
919
-
920
-
921
- class DatasetVerifier(GenericDatasetHandler):
922
-
923
- def verify(self):
924
- pass