anemoi-datasets 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/compare.py +59 -0
  3. anemoi/datasets/commands/create.py +84 -3
  4. anemoi/datasets/commands/inspect.py +3 -3
  5. anemoi/datasets/create/__init__.py +44 -17
  6. anemoi/datasets/create/check.py +6 -5
  7. anemoi/datasets/create/chunks.py +1 -1
  8. anemoi/datasets/create/config.py +5 -26
  9. anemoi/datasets/create/functions/filters/rename.py +9 -1
  10. anemoi/datasets/create/functions/filters/rotate_winds.py +10 -1
  11. anemoi/datasets/create/functions/sources/__init__.py +39 -0
  12. anemoi/datasets/create/functions/sources/accumulations.py +11 -41
  13. anemoi/datasets/create/functions/sources/constants.py +3 -0
  14. anemoi/datasets/create/functions/sources/grib.py +4 -0
  15. anemoi/datasets/create/functions/sources/hindcasts.py +32 -377
  16. anemoi/datasets/create/functions/sources/mars.py +53 -22
  17. anemoi/datasets/create/functions/sources/netcdf.py +2 -60
  18. anemoi/datasets/create/functions/sources/opendap.py +3 -2
  19. anemoi/datasets/create/functions/sources/xarray/__init__.py +73 -0
  20. anemoi/datasets/create/functions/sources/xarray/coordinates.py +234 -0
  21. anemoi/datasets/create/functions/sources/xarray/field.py +109 -0
  22. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +171 -0
  23. anemoi/datasets/create/functions/sources/xarray/flavour.py +330 -0
  24. anemoi/datasets/create/functions/sources/xarray/grid.py +46 -0
  25. anemoi/datasets/create/functions/sources/xarray/metadata.py +161 -0
  26. anemoi/datasets/create/functions/sources/xarray/time.py +98 -0
  27. anemoi/datasets/create/functions/sources/xarray/variable.py +198 -0
  28. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +42 -0
  29. anemoi/datasets/create/functions/sources/xarray_zarr.py +15 -0
  30. anemoi/datasets/create/functions/sources/zenodo.py +40 -0
  31. anemoi/datasets/create/input.py +290 -172
  32. anemoi/datasets/create/loaders.py +120 -71
  33. anemoi/datasets/create/patch.py +17 -14
  34. anemoi/datasets/create/persistent.py +1 -1
  35. anemoi/datasets/create/size.py +4 -5
  36. anemoi/datasets/create/statistics/__init__.py +49 -16
  37. anemoi/datasets/create/template.py +11 -61
  38. anemoi/datasets/create/trace.py +91 -0
  39. anemoi/datasets/create/utils.py +0 -48
  40. anemoi/datasets/create/zarr.py +24 -10
  41. anemoi/datasets/data/misc.py +9 -37
  42. anemoi/datasets/data/stores.py +29 -14
  43. anemoi/datasets/dates/__init__.py +7 -1
  44. anemoi/datasets/dates/groups.py +3 -0
  45. {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/METADATA +18 -3
  46. anemoi_datasets-0.4.2.dist-info/RECORD +86 -0
  47. {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/WHEEL +1 -1
  48. anemoi_datasets-0.4.0.dist-info/RECORD +0 -73
  49. {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/LICENSE +0 -0
  50. {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/entry_points.txt +0 -0
  51. {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/top_level.txt +0 -0
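
Most of the changes below come from the dataset-creation pipeline (anemoi/datasets/create/loaders.py): the injected print=print reporting hook is removed in favour of standard logging and tqdm progress bars, a use_threads flag is threaded through the factory methods down to ZarrBuiltRegistry, and the per-variable allow_nan(name) checks become a single allow_nans setting. A minimal sketch of the resulting call pattern, assuming the module path from the file list above and a placeholder dataset path:

    from anemoi.datasets.create.loaders import DatasetHandler

    # The factory methods now accept use_threads instead of a print callable.
    handler = DatasetHandler.from_dataset(path="/path/to/dataset.zarr", use_threads=False)

    # New helper added in this release: True once every group is flagged as built in the registry.
    print(handler.ready())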
anemoi/datasets/create/loaders.py

@@ -5,6 +5,7 @@
  # granted to it by virtue of its status as an intergovernmental organisation
  # nor does it submit to any jurisdiction.
  import datetime
+ import json
  import logging
  import os
  import time
@@ -13,7 +14,10 @@ import warnings
  from functools import cached_property

  import numpy as np
+ import tqdm
  import zarr
+ from anemoi.utils.config import DotDict
+ from anemoi.utils.humanize import seconds_to_human

  from anemoi.datasets import MissingDateError
  from anemoi.datasets import open_dataset
@@ -25,7 +29,6 @@ from anemoi.datasets.dates.groups import Groups
  from .check import DatasetName
  from .check import check_data_values
  from .chunks import ChunkFilter
- from .config import DictObj
  from .config import build_output
  from .config import loader_config
  from .input import build_input
@@ -35,8 +38,6 @@ from .statistics import check_variance
  from .statistics import compute_statistics
  from .statistics import default_statistics_dates
  from .utils import normalize_and_check_dates
- from .utils import progress_bar
- from .utils import seconds
  from .writer import ViewCacheArray
  from .zarr import ZarrBuiltRegistry
  from .zarr import add_zarr_dataset
@@ -65,7 +66,7 @@ def set_to_test_mode(cfg):
  for v in obj:
  set_element_to_test(v)
  return
- if isinstance(obj, (dict, DictObj)):
+ if isinstance(obj, (dict, DotDict)):
  if "grid" in obj:
  previous = obj["grid"]
  obj["grid"] = "20./20."
@@ -77,12 +78,16 @@ def set_to_test_mode(cfg):
  LOG.warn(f"Running in test mode. Setting number to {obj['number']} instead of {previous}")
  for k, v in obj.items():
  set_element_to_test(v)
+ if "constants" in obj:
+ constants = obj["constants"]
+ if "param" in constants and isinstance(constants["param"], list):
+ constants["param"] = ["cos_latitude"]

  set_element_to_test(cfg)


  class GenericDatasetHandler:
- def __init__(self, *, path, print=print, **kwargs):
+ def __init__(self, *, path, use_threads=False, **kwargs):

  # Catch all floating point errors, including overflow, sqrt(<0), etc
  np.seterr(all="raise", under="warn")
@@ -91,33 +96,33 @@ class GenericDatasetHandler:

  self.path = path
  self.kwargs = kwargs
- self.print = print
+ self.use_threads = use_threads
  if "test" in kwargs:
  self.test = kwargs["test"]

  @classmethod
- def from_config(cls, *, config, path, print=print, **kwargs):
+ def from_config(cls, *, config, path, use_threads=False, **kwargs):
  """Config is the path to the config file or a dict with the config"""

  assert isinstance(config, dict) or isinstance(config, str), config
- return cls(config=config, path=path, print=print, **kwargs)
+ return cls(config=config, path=path, use_threads=use_threads, **kwargs)

  @classmethod
- def from_dataset_config(cls, *, path, print=print, **kwargs):
+ def from_dataset_config(cls, *, path, use_threads=False, **kwargs):
  """Read the config saved inside the zarr dataset and instantiate the class for this config."""

  assert os.path.exists(path), f"Path {path} does not exist."
  z = zarr.open(path, mode="r")
  config = z.attrs["_create_yaml_config"]
- LOG.info(f"Config loaded from zarr config: {config}")
- return cls.from_config(config=config, path=path, print=print, **kwargs)
+ LOG.debug("Config loaded from zarr config:\n%s", json.dumps(config, indent=4, sort_keys=True, default=str))
+ return cls.from_config(config=config, path=path, use_threads=use_threads, **kwargs)

  @classmethod
- def from_dataset(cls, *, path, **kwargs):
+ def from_dataset(cls, *, path, use_threads=False, **kwargs):
  """Instanciate the class from the path to the zarr dataset, without config."""

  assert os.path.exists(path), f"Path {path} does not exist."
- return cls(path=path, **kwargs)
+ return cls(path=path, use_threads=use_threads, **kwargs)

  def read_dataset_metadata(self):
  ds = open_dataset(self.path)
@@ -131,14 +136,22 @@
  z = zarr.open(self.path, "r")
  missing_dates = z.attrs.get("missing_dates", [])
  missing_dates = sorted([np.datetime64(d) for d in missing_dates])
- assert missing_dates == self.missing_dates, (missing_dates, self.missing_dates)
+
+ if missing_dates != self.missing_dates:
+ LOG.warn("Missing dates given in recipe do not match the actual missing dates in the dataset.")
+ LOG.warn(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}")
+ LOG.warn(f"Missing dates in dataset: {sorted(str(x) for x in self.missing_dates)}")
+ raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.")

  @cached_property
  def registry(self):
- return ZarrBuiltRegistry(self.path)
+ return ZarrBuiltRegistry(self.path, use_threads=self.use_threads)
+
+ def ready(self):
+ return all(self.registry.get_flags())

  def update_metadata(self, **kwargs):
- LOG.info(f"Updating metadata {kwargs}")
+ LOG.debug(f"Updating metadata {kwargs}")
  z = zarr.open(self.path, mode="w+")
  for k, v in kwargs.items():
  if isinstance(v, np.datetime64):
@@ -170,7 +183,7 @@ class DatasetHandler(GenericDatasetHandler):
  class DatasetHandlerWithStatistics(GenericDatasetHandler):
  def __init__(self, statistics_tmp=None, **kwargs):
  super().__init__(**kwargs)
- statistics_tmp = kwargs.get("statistics_tmp") or os.path.join(self.path + ".tmp_data", "statistics")
+ statistics_tmp = kwargs.get("statistics_tmp") or os.path.join(self.path + ".storage_for_statistics.tmp")
  self.tmp_statistics = TmpStatistics(statistics_tmp)


@@ -186,12 +199,16 @@ class Loader(DatasetHandlerWithStatistics):
  remapping=build_remapping(self.output.remapping),
  use_grib_paramid=self.main_config.build.use_grib_paramid,
  )
- LOG.info("✅ INPUT_BUILDER")
- LOG.info(builder)
+ LOG.debug("✅ INPUT_BUILDER")
+ LOG.debug(builder)
  return builder

- def allow_nan(self, name):
- return name in self.main_config.statistics.get("allow_nans", [])
+ @property
+ def allow_nans(self):
+ if "allow_nans" in self.main_config.build:
+ return self.main_config.build.allow_nans
+
+ return self.main_config.statistics.get("allow_nans", [])


  class InitialiserLoader(Loader):
@@ -202,7 +219,7 @@ class InitialiserLoader(Loader):
  if self.test:
  set_to_test_mode(self.main_config)

- LOG.info(self.main_config.dates)
+ LOG.info(dict(self.main_config.dates))

  self.tmp_statistics.delete()

@@ -255,26 +272,25 @@
  Read a small part of the data to get the shape of the data and the resolution and more metadata.
  """

- self.print("Config loaded ok:")
- LOG.info(self.main_config)
+ LOG.info("Config loaded ok:")
+ # LOG.info(self.main_config)

  dates = self.groups.dates
  frequency = dates.frequency
  assert isinstance(frequency, int), frequency

- self.print(f"Found {len(dates)} datetimes.")
+ LOG.info(f"Found {len(dates)} datetimes.")
  LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ")
  LOG.info(f"Missing dates: {len(dates.missing)}")
- lengths = [len(g) for g in self.groups]
- self.print(f"Found {len(dates)} datetimes {'+'.join([str(_) for _ in lengths])}.")
+ lengths = tuple(len(g) for g in self.groups)

  variables = self.minimal_input.variables
- self.print(f"Found {len(variables)} variables : {','.join(variables)}.")
+ LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.")

  variables_with_nans = self.main_config.statistics.get("allow_nans", [])

  ensembles = self.minimal_input.ensembles
- self.print(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.")
+ LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.")

  grid_points = self.minimal_input.grid_points
  LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}")
@@ -286,13 +302,13 @@
  coords["dates"] = dates
  total_shape = self.minimal_input.shape
  total_shape[0] = len(dates)
- self.print(f"total_shape = {total_shape}")
+ LOG.info(f"total_shape = {total_shape}")

  chunks = self.output.get_chunking(coords)
  LOG.info(f"{chunks=}")
  dtype = self.output.dtype

- self.print(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}")
+ LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}")

  metadata = {}
  metadata["uuid"] = str(uuid.uuid4())
@@ -312,6 +328,7 @@ class InitialiserLoader(Loader):
  metadata["ensemble_dimension"] = len(ensembles)
  metadata["variables"] = variables
  metadata["variables_with_nans"] = variables_with_nans
+ metadata["allow_nans"] = self.main_config.build.get("allow_nans", False)
  metadata["resolution"] = resolution

  metadata["data_request"] = self.minimal_input.data_request
@@ -328,7 +345,7 @@ class InitialiserLoader(Loader):
  if check_name:
  basename, ext = os.path.splitext(os.path.basename(self.path)) # noqa: F841
  ds_name = DatasetName(basename, resolution, dates[0], dates[-1], frequency)
- ds_name.raise_if_not_valid(print=self.print)
+ ds_name.raise_if_not_valid()

  if len(dates) != total_shape[0]:
  raise ValueError(
@@ -348,10 +365,16 @@ class InitialiserLoader(Loader):

  self.update_metadata(**metadata)

- self._add_dataset(name="data", chunks=chunks, dtype=dtype, shape=total_shape)
- self._add_dataset(name="dates", array=dates)
- self._add_dataset(name="latitudes", array=grid_points[0])
- self._add_dataset(name="longitudes", array=grid_points[1])
+ self._add_dataset(
+ name="data",
+ chunks=chunks,
+ dtype=dtype,
+ shape=total_shape,
+ dimensions=("time", "variable", "ensemble", "cell"),
+ )
+ self._add_dataset(name="dates", array=dates, dimensions=("time",))
+ self._add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",))
+ self._add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",))

  self.registry.create(lengths=lengths)
  self.tmp_statistics.create(exist_ok=False)
@@ -368,6 +391,9 @@

  assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks())

+ # Return the number of groups to process, so we can show a nice progress bar
+ return len(lengths)
+

  class ContentLoader(Loader):
  def __init__(self, config, parts, **kwargs):
@@ -387,35 +413,29 @@ class ContentLoader(Loader):
  self.n_groups = len(self.groups)

  def load(self):
- self.registry.add_to_history("loading_data_start", parts=self.parts)
-
  for igroup, group in enumerate(self.groups):
  if not self.chunk_filter(igroup):
  continue
  if self.registry.get_flag(igroup):
  LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)")
  continue
- # self.print(f" -> Processing {igroup} total={len(self.groups)}")
- # print("========", group)
+
  assert isinstance(group[0], datetime.datetime), group

  result = self.input.select(dates=group)
  assert result.dates == group, (len(result.dates), len(group))

- msg = f"Building data for group {igroup}/{self.n_groups}"
- LOG.info(msg)
- self.print(msg)
+ LOG.debug(f"Building data for group {igroup}/{self.n_groups}")

  # There are several groups.
  # There is one result to load for each group.
  self.load_result(result)
  self.registry.set_flag(igroup)

- self.registry.add_to_history("loading_data_end", parts=self.parts)
  self.registry.add_provenance(name="provenance_load")
  self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config)

- self.print_info()
+ # self.print_info()

  def load_result(self, result):
  # There is one cube to load for each result.
@@ -430,7 +450,7 @@
  shape = cube.extended_user_shape
  dates_in_data = cube.user_coords["valid_datetime"]

- LOG.info(f"Loading {shape=} in {self.data_array.shape=}")
+ LOG.debug(f"Loading {shape=} in {self.data_array.shape=}")

  def check_dates_in_data(lst, lst2):
  lst2 = [np.datetime64(_) for _ in lst2]
@@ -450,7 +470,7 @@
  array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes)
  self.load_cube(cube, array)

- stats = compute_statistics(array.cache, self.variables_names, allow_nan=self.allow_nan)
+ stats = compute_statistics(array.cache, self.variables_names, allow_nans=self.allow_nans)
  self.tmp_statistics.write(indexes, stats, dates=dates_in_data)

  array.flush()
@@ -463,11 +483,19 @@

  reading_chunks = None
  total = cube.count(reading_chunks)
- self.print(f"Loading datacube: {cube}")
- bar = progress_bar(
+ LOG.debug(f"Loading datacube: {cube}")
+
+ def position(x):
+ if isinstance(x, str) and "/" in x:
+ x = x.split("/")
+ return int(x[0])
+ return None
+
+ bar = tqdm.tqdm(
  iterable=cube.iterate_cubelets(reading_chunks),
  total=total,
  desc=f"Loading datacube {cube}",
+ position=position(self.parts),
  )
  for i, cubelet in enumerate(bar):
  bar.set_description(f"Loading {i}/{total}")
@@ -482,7 +510,7 @@
  data[:],
  name=name,
  log=[i, data.shape, local_indexes],
- allow_nan=self.allow_nan,
+ allow_nans=self.allow_nans,
  )

  now = time.time()
@@ -491,10 +519,11 @@

  now = time.time()
  save += time.time() - now
- LOG.info("Written.")
- msg = f"Elapsed: {seconds(time.time() - start)}, load time: {seconds(load)}, write time: {seconds(save)}."
- self.print(msg)
- LOG.info(msg)
+ LOG.debug(
+ f"Elapsed: {seconds_to_human(time.time() - start)}, "
+ f"load time: {seconds_to_human(load)}, "
+ f"write time: {seconds_to_human(save)}."
+ )


  class StatisticsAdder(DatasetHandlerWithStatistics):
@@ -518,12 +547,16 @@ class StatisticsAdder(DatasetHandlerWithStatistics):

  self.read_dataset_metadata()

- def allow_nan(self, name):
+ @cached_property
+ def allow_nans(self):
  z = zarr.open(self.path, mode="r")
+ if "allow_nans" in z.attrs:
+ return z.attrs["allow_nans"]
+
  if "variables_with_nans" in z.attrs:
- return name in z.attrs["variables_with_nans"]
+ return z.attrs["variables_with_nans"]

- warnings.warn(f"Cannot find 'variables_with_nans' in {self.path}. Assuming nans allowed for {name}.")
+ warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.")
  return True

  def _get_statistics_dates(self):
@@ -562,7 +595,7 @@ class StatisticsAdder(DatasetHandlerWithStatistics):

  def run(self):
  dates = self._get_statistics_dates()
- stats = self.tmp_statistics.get_aggregated(dates, self.variables_names, self.allow_nan)
+ stats = self.tmp_statistics.get_aggregated(dates, self.variables_names, self.allow_nans)
  self.output_writer(stats)

  def write_stats_to_file(self, stats):
@@ -591,7 +624,7 @@ class StatisticsAdder(DatasetHandlerWithStatistics):
  "count",
  "has_nans",
  ]:
- self._add_dataset(name=k, array=stats[k])
+ self._add_dataset(name=k, array=stats[k], dimensions=("variable",))

  self.registry.add_to_history("compute_statistics_end")
  LOG.info(f"Wrote statistics in {self.path}")
@@ -625,6 +658,7 @@ class GenericAdditions(GenericDatasetHandler):
  raise NotImplementedError()

  def finalise(self):
+
  shape = (len(self.dates), len(self.variables))
  agg = dict(
  minimum=np.full(shape, np.nan, dtype=np.float64),
@@ -634,7 +668,7 @@
  count=np.full(shape, -1, dtype=np.int64),
  has_nans=np.full(shape, False, dtype=np.bool_),
  )
- LOG.info(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}")
+ LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}")

  found = set()
  ifound = set()
@@ -730,9 +764,9 @@ class GenericAdditions(GenericDatasetHandler):
  "has_nans",
  ]:
  name = self.final_storage_name(k)
- self._add_dataset(name=name, array=summary[k])
+ self._add_dataset(name=name, array=summary[k], dimensions=("variable",))
  self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end")
- LOG.info(f"Wrote additions in {self.path} ({self.final_storage_name('*')})")
+ LOG.debug(f"Wrote additions in {self.path} ({self.final_storage_name('*')})")

  def check_statistics(self):
  pass
@@ -744,10 +778,19 @@
  return z.attrs["variables_with_nans"]
  return None

- def allow_nan(self, name):
+ @cached_property
+ def _allow_nans(self):
+ z = zarr.open(self.path, mode="r")
+ return z.attrs.get("allow_nans", False)
+
+ def allow_nans(self):
+
+ if self._allow_nans:
+ return True
+
  if self._variables_with_nans is not None:
- return name in self._variables_with_nans
- warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, Assuming nans allowed for {name}.")
+ return self._variables_with_nans
+ warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.")
  return True


@@ -768,7 +811,7 @@ class StatisticsAddition(GenericAdditions):

  @property
  def tmp_storage_path(self):
- return f"{self.path}.tmp_storage_statistics"
+ return f"{self.path}.storage_statistics.tmp"

  def final_storage_name(self, k):
  return k
@@ -781,12 +824,12 @@
  date = self.dates[i]
  try:
  arr = self.ds[i : i + 1, ...]
- stats = compute_statistics(arr, self.variables, allow_nan=self.allow_nan)
+ stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans)
  self.tmp_storage.add([date, i, stats], key=date)
  except MissingDateError:
  self.tmp_storage.add([date, i, "missing"], key=date)
  self.tmp_storage.flush()
- LOG.info(f"Dataset {self.path} additions run.")
+ LOG.debug(f"Dataset {self.path} additions run.")

  def check_statistics(self):
  ds = open_dataset(self.path)
@@ -846,7 +889,7 @@ class TendenciesStatisticsAddition(GenericAdditions):

  @property
  def tmp_storage_path(self):
- return f"{self.path}.tmp_storage_statistics_{self.delta}h"
+ return f"{self.path}.storage_statistics_{self.delta}h.tmp"

  def final_storage_name(self, k):
  return self.final_storage_name_from_delta(k, delta=self.delta)
@@ -867,9 +910,15 @@
  date = self.dates[i]
  try:
  arr = self.ds[i]
- stats = compute_statistics(arr, self.variables, allow_nan=self.allow_nan)
+ stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans)
  self.tmp_storage.add([date, i, stats], key=date)
  except MissingDateError:
  self.tmp_storage.add([date, i, "missing"], key=date)
  self.tmp_storage.flush()
- LOG.info(f"Dataset {self.path} additions run.")
+ LOG.debug(f"Dataset {self.path} additions run.")
+
+
+ class DatasetVerifier(GenericDatasetHandler):
+
+ def verify(self):
+ pass
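
One behavioural change worth calling out in the hunks above: the per-variable allow_nan(name) check is replaced by allow_nans, which may now be either a boolean (build.allow_nans from the recipe, also stored in the dataset's allow_nans attribute) or a list of variable names (statistics.allow_nans / variables_with_nans), and compute_statistics is now called with allow_nans= accordingly. An illustrative helper, not part of the package, showing how a value of either kind reduces to a per-variable decision:

    def nans_allowed(allow_nans, name):
        # allow_nans is either a bool (allow NaNs everywhere) or a list of variable names.
        if isinstance(allow_nans, bool):
            return allow_nans
        return name in allow_nans

    assert nans_allowed(True, "2t") is True
    assert nans_allowed(["sst", "ci"], "sst") is True
    assert nans_allowed(["sst", "ci"], "2t") is False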
anemoi/datasets/create/patch.py

@@ -1,9 +1,12 @@
  #!/usr/bin/env python3
  import json
+ import logging
  import os

  import zarr

+ LOG = logging.getLogger(__name__)
+

  def fix_order_by(order_by):
  if isinstance(order_by, list):
@@ -48,7 +51,7 @@ def fix_provenance(provenance):
  provenance["module_versions"][k] = os.path.join("...", os.path.basename(v))

  for k, v in list(provenance["git_versions"].items()):
- print(k, v)
+ LOG.debug(k, v)
  modified_files = v["git"].get("modified_files", [])
  untracked_files = v["git"].get("untracked_files", [])
  if not isinstance(modified_files, int):
@@ -63,21 +66,21 @@ def fix_provenance(provenance):
  }
  )

- print(json.dumps(provenance, indent=2))
+ LOG.debug(json.dumps(provenance, indent=2))
  # assert False
  return provenance


  def apply_patch(path, verbose=True, dry_run=False):
- print("====================")
- print(f"Patching {path}")
- print("====================")
+ LOG.debug("====================")
+ LOG.debug(f"Patching {path}")
+ LOG.debug("====================")

  try:
  attrs = zarr.open(path, mode="r").attrs.asdict()
  except zarr.errors.PathNotFoundError as e:
- print(f"Failed to open {path}")
- print(e)
+ LOG.error(f"Failed to open {path}")
+ LOG.error(e)
  exit(0)

  FIXES = {
@@ -94,23 +97,23 @@ def apply_patch(path, verbose=True, dry_run=False):
  for k, v in attrs.items():
  v = attrs[k]
  if k in REMOVE:
- print(f"✅ Remove {k}")
+ LOG.info(f"✅ Remove {k}")
  continue

  if k not in FIXES:
  assert not k.startswith("provenance"), f"[{k}]"
- print(f"✅ Don't fix {k}")
+ LOG.debug(f"✅ Don't fix {k}")
  fixed_attrs[k] = v
  continue

  new_v = FIXES[k](v)
  if json.dumps(new_v, sort_keys=True) != json.dumps(v, sort_keys=True):
- print(f"✅ Fix {k}")
+ LOG.info(f"✅ Fix {k}")
  if verbose:
- print(f" Before : {k}= {v}")
- print(f" After : {k}= {new_v}")
+ LOG.info(f" Before : {k}= {v}")
+ LOG.info(f" After : {k}= {new_v}")
  else:
- print(f"✅ Unchanged {k}")
+ LOG.debug(f"✅ Unchanged {k}")
  fixed_attrs[k] = new_v

  if dry_run:
@@ -125,6 +128,6 @@ def apply_patch(path, verbose=True, dry_run=False):

  after = json.dumps(z.attrs.asdict(), sort_keys=True)
  if before != after:
- print("CHANGED")
+ LOG.info("Dataset changed by patch")

  assert json.dumps(z.attrs.asdict(), sort_keys=True) == json.dumps(fixed_attrs, sort_keys=True)
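
A small detail in the patched provenance code above: print(k, v) became LOG.debug(k, v), but the standard logging API treats extra positional arguments as %-formatting arguments for the message rather than as additional values to print, so the second value is not rendered the way it was with print. A sketch of the conventional equivalent (the variable values are placeholders):

    import logging

    logging.basicConfig(level=logging.DEBUG)
    LOG = logging.getLogger(__name__)

    name, versions = "anemoi-datasets", {"modified_files": 0}

    # print(name, versions) translates to an explicit format string:
    LOG.debug("%s %s", name, versions)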
anemoi/datasets/create/persistent.py

@@ -49,7 +49,7 @@ class PersistentDict:
  def items(self):
  # use glob to read all pickles
  files = glob.glob(self.dirname + "/*.pickle")
- LOG.info(f"Reading {self.name} data, found {len(files)} files in {self.dirname}")
+ LOG.debug(f"Reading {self.name} data, found {len(files)} files in {self.dirname}")
  assert len(files) > 0, f"No files found in {self.dirname}"
  for f in files:
  with open(f, "rb") as f:
anemoi/datasets/create/size.py

@@ -10,9 +10,8 @@
  import logging
  import os

- from anemoi.utils.humanize import bytes
-
- from anemoi.datasets.create.utils import progress_bar
+ import tqdm
+ from anemoi.utils.humanize import bytes_to_human

  LOG = logging.getLogger(__name__)

@@ -22,14 +21,14 @@ def compute_directory_sizes(path):
  return None

  size, n = 0, 0
- bar = progress_bar(iterable=os.walk(path), desc=f"Computing size of {path}")
+ bar = tqdm.tqdm(iterable=os.walk(path), desc=f"Computing size of {path}")
  for dirpath, _, filenames in bar:
  for filename in filenames:
  file_path = os.path.join(dirpath, filename)
  size += os.path.getsize(file_path)
  n += 1

- LOG.info(f"Total size: {bytes(size)}")
+ LOG.info(f"Total size: {bytes_to_human(size)}")
  LOG.info(f"Total number of files: {n}")

  return dict(total_size=size, total_number_of_files=n)
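
For reference, a minimal usage sketch of the function changed above, assuming the import paths implied by the file list (the dataset path is a placeholder):

    from anemoi.datasets.create.size import compute_directory_sizes
    from anemoi.utils.humanize import bytes_to_human

    sizes = compute_directory_sizes("/path/to/dataset.zarr")
    if sizes is not None:
        # Keys returned by the function: total_size (bytes) and total_number_of_files.
        print(bytes_to_human(sizes["total_size"]), sizes["total_number_of_files"])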