anemoi-datasets 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (64)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/cleanup.py +44 -0
  3. anemoi/datasets/commands/create.py +52 -21
  4. anemoi/datasets/commands/finalise-additions.py +45 -0
  5. anemoi/datasets/commands/finalise.py +39 -0
  6. anemoi/datasets/commands/init-additions.py +45 -0
  7. anemoi/datasets/commands/init.py +67 -0
  8. anemoi/datasets/commands/inspect.py +1 -1
  9. anemoi/datasets/commands/load-additions.py +47 -0
  10. anemoi/datasets/commands/load.py +47 -0
  11. anemoi/datasets/commands/patch.py +39 -0
  12. anemoi/datasets/create/__init__.py +959 -146
  13. anemoi/datasets/create/check.py +5 -3
  14. anemoi/datasets/create/config.py +54 -2
  15. anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +57 -0
  16. anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +57 -0
  17. anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +54 -0
  18. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +59 -0
  19. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +115 -0
  20. anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +390 -0
  21. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +77 -0
  22. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +55 -0
  23. anemoi/datasets/create/functions/sources/grib.py +86 -1
  24. anemoi/datasets/create/functions/sources/hindcasts.py +14 -73
  25. anemoi/datasets/create/functions/sources/mars.py +9 -3
  26. anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
  27. anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
  28. anemoi/datasets/create/functions/sources/xarray/field.py +8 -2
  29. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
  30. anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
  31. anemoi/datasets/create/functions/sources/xarray/metadata.py +40 -40
  32. anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
  33. anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
  34. anemoi/datasets/create/input.py +62 -39
  35. anemoi/datasets/create/persistent.py +1 -1
  36. anemoi/datasets/create/statistics/__init__.py +39 -23
  37. anemoi/datasets/create/utils.py +6 -2
  38. anemoi/datasets/data/__init__.py +1 -0
  39. anemoi/datasets/data/concat.py +46 -2
  40. anemoi/datasets/data/dataset.py +119 -34
  41. anemoi/datasets/data/debug.py +5 -1
  42. anemoi/datasets/data/forwards.py +17 -8
  43. anemoi/datasets/data/grids.py +17 -3
  44. anemoi/datasets/data/interpolate.py +133 -0
  45. anemoi/datasets/data/masked.py +2 -2
  46. anemoi/datasets/data/misc.py +56 -66
  47. anemoi/datasets/data/missing.py +240 -0
  48. anemoi/datasets/data/rescale.py +147 -0
  49. anemoi/datasets/data/select.py +7 -1
  50. anemoi/datasets/data/stores.py +23 -10
  51. anemoi/datasets/data/subset.py +47 -5
  52. anemoi/datasets/data/unchecked.py +20 -22
  53. anemoi/datasets/data/xy.py +125 -0
  54. anemoi/datasets/dates/__init__.py +124 -95
  55. anemoi/datasets/dates/groups.py +85 -20
  56. anemoi/datasets/grids.py +66 -48
  57. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/METADATA +8 -17
  58. anemoi_datasets-0.5.0.dist-info/RECORD +105 -0
  59. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/WHEEL +1 -1
  60. anemoi/datasets/create/loaders.py +0 -936
  61. anemoi_datasets-0.4.4.dist-info/RECORD +0 -86
  62. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/LICENSE +0 -0
  63. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/entry_points.txt +0 -0
  64. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/top_level.txt +0 -0
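The new command modules (init, load, finalise, cleanup, init/load/finalise-additions, patch) and the rewritten anemoi/datasets/create/__init__.py indicate that dataset creation is now split into discrete steps dispatched through a creator_factory. The sketch below of chaining those steps from Python is illustrative only: the factory name and keyword arguments are taken from the constructor signatures visible in the diff further down and may not match the supported public API.

    # Hypothetical usage sketch, assembled from the creator classes shown in the
    # diff of anemoi/datasets/create/__init__.py below; not an excerpt from the package.
    from anemoi.datasets.create import creator_factory

    path = "my-dataset.zarr"

    # "init" builds the empty zarr with metadata, "load" fills the data array,
    # "finalise" chains Statistics, Size and Cleanup, and "additions" computes
    # tendency statistics for the requested deltas.
    creator_factory("init", path=path, config="recipe.yaml", overwrite=True).run()
    creator_factory("load", path=path).run()
    creator_factory("finalise", path=path).run()
    creator_factory("additions", path=path, delta=["1h", "3h", "6h"]).run()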
@@ -7,196 +7,1009 @@
7
7
  # nor does it submit to any jurisdiction.
8
8
  #
9
9
 
10
+ import datetime
11
+ import json
10
12
  import logging
11
13
  import os
14
+ import time
15
+ import uuid
16
+ import warnings
17
+ from functools import cached_property
18
+
19
+ import numpy as np
20
+ import tqdm
21
+ from anemoi.utils.config import DotDict as DotDict
22
+ from anemoi.utils.dates import as_datetime
23
+ from anemoi.utils.dates import frequency_to_string
24
+ from anemoi.utils.dates import frequency_to_timedelta
25
+ from anemoi.utils.humanize import compress_dates
26
+ from anemoi.utils.humanize import seconds_to_human
27
+
28
+ from anemoi.datasets import MissingDateError
29
+ from anemoi.datasets import open_dataset
30
+ from anemoi.datasets.create.persistent import build_storage
31
+ from anemoi.datasets.data.misc import as_first_date
32
+ from anemoi.datasets.data.misc import as_last_date
33
+ from anemoi.datasets.dates.groups import Groups
34
+
35
+ from .check import DatasetName
36
+ from .check import check_data_values
37
+ from .chunks import ChunkFilter
38
+ from .config import build_output
39
+ from .config import loader_config
40
+ from .input import build_input
41
+ from .statistics import Summary
42
+ from .statistics import TmpStatistics
43
+ from .statistics import check_variance
44
+ from .statistics import compute_statistics
45
+ from .statistics import default_statistics_dates
46
+ from .statistics import fix_variance
47
+ from .utils import normalize_and_check_dates
48
+ from .writer import ViewCacheArray
12
49
 
13
50
  LOG = logging.getLogger(__name__)
14
51
 
52
+ VERSION = "0.20"
53
+
54
+
55
+ def json_tidy(o):
56
+
57
+ if isinstance(o, datetime.datetime):
58
+ return o.isoformat()
59
+
60
+ if isinstance(o, datetime.datetime):
61
+ return o.isoformat()
62
+
63
+ if isinstance(o, datetime.timedelta):
64
+ return frequency_to_string(o)
65
+
66
+ raise TypeError(repr(o) + " is not JSON serializable")
67
+
68
+
69
+ def build_statistics_dates(dates, start, end):
70
+ """Compute the start and end dates for the statistics, based on :
71
+ - The start and end dates in the config
72
+ - The default statistics dates convention
73
+
74
+ Then adapt according to the actual dates in the dataset.
75
+ """
76
+ # if not specified, use the default statistics dates
77
+ default_start, default_end = default_statistics_dates(dates)
78
+ if start is None:
79
+ start = default_start
80
+ if end is None:
81
+ end = default_end
82
+
83
+ # in any case, adapt to the actual dates in the dataset
84
+ start = as_first_date(start, dates)
85
+ end = as_last_date(end, dates)
86
+
87
+ # and convert to datetime to isoformat
88
+ start = start.astype(datetime.datetime)
89
+ end = end.astype(datetime.datetime)
90
+ return (start.isoformat(), end.isoformat())
91
+
15
92
 
16
93
  def _ignore(*args, **kwargs):
17
94
  pass
18
95
 
19
96
 
20
- class Creator:
21
- def __init__(
22
- self,
23
- path,
24
- config=None,
25
- cache=None,
26
- use_threads=False,
27
- statistics_tmp=None,
28
- overwrite=False,
29
- test=None,
30
- progress=None,
31
- **kwargs,
32
- ):
33
- self.path = path # Output path
34
- self.config = config
97
+ def _path_readable(path):
98
+ import zarr
99
+
100
+ try:
101
+ zarr.open(path, "r")
102
+ return True
103
+ except zarr.errors.PathNotFoundError:
104
+ return False
105
+
106
+
107
+ class Dataset:
108
+ def __init__(self, path):
109
+ self.path = path
110
+
111
+ _, ext = os.path.splitext(self.path)
112
+ if ext != ".zarr":
113
+ raise ValueError(f"Unsupported extension={ext} for path={self.path}")
114
+
115
+ def add_dataset(self, mode="r+", **kwargs):
116
+ import zarr
117
+
118
+ z = zarr.open(self.path, mode=mode)
119
+ from .zarr import add_zarr_dataset
120
+
121
+ return add_zarr_dataset(zarr_root=z, **kwargs)
122
+
123
+ def update_metadata(self, **kwargs):
124
+ import zarr
125
+
126
+ LOG.debug(f"Updating metadata {kwargs}")
127
+ z = zarr.open(self.path, mode="w+")
128
+ for k, v in kwargs.items():
129
+ if isinstance(v, np.datetime64):
130
+ v = v.astype(datetime.datetime)
131
+ if isinstance(v, datetime.date):
132
+ v = v.isoformat()
133
+ z.attrs[k] = json.loads(json.dumps(v, default=json_tidy))
134
+
135
+ @cached_property
136
+ def anemoi_dataset(self):
137
+ return open_dataset(self.path)
138
+
139
+ @cached_property
140
+ def zarr_metadata(self):
141
+ import zarr
142
+
143
+ return dict(zarr.open(self.path, mode="r").attrs)
144
+
145
+ def print_info(self):
146
+ import zarr
147
+
148
+ z = zarr.open(self.path, mode="r")
149
+ try:
150
+ LOG.info(z["data"].info)
151
+ except Exception as e:
152
+ LOG.info(e)
153
+
154
+ def get_zarr_chunks(self):
155
+ import zarr
156
+
157
+ z = zarr.open(self.path, mode="r")
158
+ return z["data"].chunks
159
+
160
+ def check_name(self, resolution, dates, frequency, raise_exception=True, is_test=False):
161
+ basename, _ = os.path.splitext(os.path.basename(self.path))
162
+ try:
163
+ DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid()
164
+ except Exception as e:
165
+ if raise_exception and not is_test:
166
+ raise e
167
+ else:
168
+ LOG.warning(f"Dataset name error: {e}")
169
+
170
+ def get_main_config(self):
171
+ """Returns None if the config is not found."""
172
+ import zarr
173
+
174
+ z = zarr.open(self.path, mode="r")
175
+ return loader_config(z.attrs.get("_create_yaml_config"))
176
+
177
+
178
+ class WritableDataset(Dataset):
179
+ def __init__(self, path):
180
+ super().__init__(path)
181
+ self.path = path
182
+
183
+ import zarr
184
+
185
+ self.z = zarr.open(self.path, mode="r+")
186
+
187
+ @cached_property
188
+ def data_array(self):
189
+ import zarr
190
+
191
+ return zarr.open(self.path, mode="r+")["data"]
192
+
193
+
194
+ class NewDataset(Dataset):
195
+ def __init__(self, path, overwrite=False):
196
+ super().__init__(path)
197
+ self.path = path
198
+
199
+ import zarr
200
+
201
+ self.z = zarr.open(self.path, mode="w")
202
+ self.z.create_group("_build")
203
+
204
+
205
+ class Actor: # TODO: rename to Creator
206
+ dataset_class = WritableDataset
207
+
208
+ def __init__(self, path, cache=None):
209
+ # Catch all floating point errors, including overflow, sqrt(<0), etc
210
+ np.seterr(all="raise", under="warn")
211
+
212
+ self.path = path
35
213
  self.cache = cache
214
+ self.dataset = self.dataset_class(self.path)
215
+
216
+ def run(self):
217
+ # to be implemented in the sub-classes
218
+ raise NotImplementedError()
219
+
220
+ def update_metadata(self, **kwargs):
221
+ self.dataset.update_metadata(**kwargs)
222
+
223
+ def _cache_context(self):
224
+ from .utils import cache_context
225
+
226
+ return cache_context(self.cache)
227
+
228
+ def check_unkown_kwargs(self, kwargs):
229
+ # remove this latter
230
+ LOG.warning(f"💬 Unknown kwargs for {self.__class__.__name__}: {kwargs}")
231
+
232
+ def read_dataset_metadata(self, path):
233
+ ds = open_dataset(path)
234
+ self.dataset_shape = ds.shape
235
+ self.variables_names = ds.variables
236
+ assert len(self.variables_names) == ds.shape[1], self.dataset_shape
237
+ self.dates = ds.dates
238
+
239
+ self.missing_dates = sorted(list([self.dates[i] for i in ds.missing]))
240
+
241
+ def check_missing_dates(expected):
242
+ import zarr
243
+
244
+ z = zarr.open(path, "r")
245
+ missing_dates = z.attrs.get("missing_dates", [])
246
+ missing_dates = sorted([np.datetime64(d) for d in missing_dates])
247
+ if missing_dates != expected:
248
+ LOG.warning("Missing dates given in recipe do not match the actual missing dates in the dataset.")
249
+ LOG.warning(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}")
250
+ LOG.warning(f"Missing dates in dataset: {sorted(str(x) for x in expected)}")
251
+ raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.")
252
+
253
+ check_missing_dates(self.missing_dates)
254
+
255
+
256
+ class Patch(Actor):
257
+ def __init__(self, path, options=None, **kwargs):
258
+ self.path = path
259
+ self.options = options or {}
260
+
261
+ def run(self):
262
+ from .patch import apply_patch
263
+
264
+ apply_patch(self.path, **self.options)
265
+
266
+
267
+ class Size(Actor):
268
+ def __init__(self, path, **kwargs):
269
+ super().__init__(path)
270
+
271
+ def run(self):
272
+ from .size import compute_directory_sizes
273
+
274
+ metadata = compute_directory_sizes(self.path)
275
+ self.update_metadata(**metadata)
276
+
277
+
278
+ class HasRegistryMixin:
279
+ @cached_property
280
+ def registry(self):
281
+ from .zarr import ZarrBuiltRegistry
282
+
283
+ return ZarrBuiltRegistry(self.path, use_threads=self.use_threads)
284
+
285
+
286
+ class HasStatisticTempMixin:
287
+ @cached_property
288
+ def tmp_statistics(self):
289
+ directory = self.statistics_temp_dir or os.path.join(self.path + ".storage_for_statistics.tmp")
290
+ return TmpStatistics(directory)
291
+
292
+
293
+ class HasElementForDataMixin:
294
+ def create_elements(self, config):
295
+
296
+ assert self.registry
297
+ assert self.tmp_statistics
298
+
299
+ LOG.info(dict(config.dates))
300
+
301
+ self.groups = Groups(**config.dates)
302
+ LOG.info(self.groups)
303
+
304
+ self.output = build_output(config.output, parent=self)
305
+
306
+ self.input = build_input_(main_config=config, output_config=self.output)
307
+ LOG.info(self.input)
308
+
309
+
310
+ def build_input_(main_config, output_config):
311
+ from earthkit.data.core.order import build_remapping
312
+
313
+ builder = build_input(
314
+ main_config.input,
315
+ data_sources=main_config.get("data_sources", {}),
316
+ order_by=output_config.order_by,
317
+ flatten_grid=output_config.flatten_grid,
318
+ remapping=build_remapping(output_config.remapping),
319
+ use_grib_paramid=main_config.build.use_grib_paramid,
320
+ )
321
+ LOG.debug("✅ INPUT_BUILDER")
322
+ LOG.debug(builder)
323
+ return builder
324
+
325
+
326
+ class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin):
327
+ dataset_class = NewDataset
328
+ def __init__(self, path, config, check_name=False, overwrite=False, use_threads=False, statistics_temp_dir=None, progress=None, test=False, cache=None, **kwargs): # fmt: skip
329
+ if _path_readable(path) and not overwrite:
330
+ raise Exception(f"{path} already exists. Use overwrite=True to overwrite.")
331
+
332
+ super().__init__(path, cache=cache)
333
+ self.config = config
334
+ self.check_name = check_name
36
335
  self.use_threads = use_threads
37
- self.statistics_tmp = statistics_tmp
38
- self.overwrite = overwrite
336
+ self.statistics_temp_dir = statistics_temp_dir
337
+ self.progress = progress
39
338
  self.test = test
40
- self.progress = progress if progress is not None else _ignore
41
339
 
42
- def init(self, check_name=False):
43
- # check path
44
- _, ext = os.path.splitext(self.path)
45
- assert ext != "zarr", f"Unsupported extension={ext}"
46
- from .loaders import InitialiserLoader
340
+ self.main_config = loader_config(config, is_test=test)
341
+
342
+ # self.registry.delete() ??
343
+ self.tmp_statistics.delete()
344
+
345
+ assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by
346
+ self.create_elements(self.main_config)
47
347
 
48
- if self._path_readable() and not self.overwrite:
49
- raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.")
348
+ LOG.info(f"Groups: {self.groups}")
50
349
 
350
+ one_date = self.groups.one_date()
351
+ # assert False, (type(one_date), type(self.groups))
352
+ self.minimal_input = self.input.select(one_date)
353
+ LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}")
354
+ LOG.info(self.minimal_input)
355
+
356
+ def run(self):
51
357
  with self._cache_context():
52
- obj = InitialiserLoader.from_config(
53
- path=self.path,
54
- config=self.config,
55
- statistics_tmp=self.statistics_tmp,
56
- use_threads=self.use_threads,
57
- progress=self.progress,
58
- test=self.test,
358
+ return self._run()
359
+
360
+ def _run(self):
361
+ """Create an empty dataset of the right final shape
362
+
363
+ Read a small part of the data to get the shape of the data and the resolution and more metadata.
364
+ """
365
+
366
+ LOG.info("Config loaded ok:")
367
+ # LOG.info(self.main_config)
368
+
369
+ dates = self.groups.provider.values
370
+ frequency = self.groups.provider.frequency
371
+ missing = self.groups.provider.missing
372
+
373
+ assert isinstance(frequency, datetime.timedelta), frequency
374
+
375
+ LOG.info(f"Found {len(dates)} datetimes.")
376
+ LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ")
377
+ LOG.info(f"Missing dates: {len(missing)}")
378
+ lengths = tuple(len(g) for g in self.groups)
379
+
380
+ variables = self.minimal_input.variables
381
+ LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.")
382
+
383
+ variables_with_nans = self.main_config.statistics.get("allow_nans", [])
384
+
385
+ ensembles = self.minimal_input.ensembles
386
+ LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.")
387
+
388
+ grid_points = self.minimal_input.grid_points
389
+ LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}")
390
+
391
+ resolution = self.minimal_input.resolution
392
+ LOG.info(f"{resolution=}")
393
+
394
+ coords = self.minimal_input.coords
395
+ coords["dates"] = dates
396
+ total_shape = self.minimal_input.shape
397
+ total_shape[0] = len(dates)
398
+ LOG.info(f"total_shape = {total_shape}")
399
+
400
+ chunks = self.output.get_chunking(coords)
401
+ LOG.info(f"{chunks=}")
402
+ dtype = self.output.dtype
403
+
404
+ LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}")
405
+
406
+ metadata = {}
407
+ metadata["uuid"] = str(uuid.uuid4())
408
+
409
+ metadata.update(self.main_config.get("add_metadata", {}))
410
+
411
+ metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict()
412
+
413
+ metadata["description"] = self.main_config.description
414
+ metadata["licence"] = self.main_config["licence"]
415
+ metadata["attribution"] = self.main_config["attribution"]
416
+
417
+ metadata["remapping"] = self.output.remapping
418
+ metadata["order_by"] = self.output.order_by_as_list
419
+ metadata["flatten_grid"] = self.output.flatten_grid
420
+
421
+ metadata["ensemble_dimension"] = len(ensembles)
422
+ metadata["variables"] = variables
423
+ metadata["variables_with_nans"] = variables_with_nans
424
+ metadata["allow_nans"] = self.main_config.build.get("allow_nans", False)
425
+ metadata["resolution"] = resolution
426
+
427
+ metadata["data_request"] = self.minimal_input.data_request
428
+ metadata["field_shape"] = self.minimal_input.field_shape
429
+ metadata["proj_string"] = self.minimal_input.proj_string
430
+
431
+ metadata["start_date"] = dates[0].isoformat()
432
+ metadata["end_date"] = dates[-1].isoformat()
433
+ metadata["frequency"] = frequency
434
+ metadata["missing_dates"] = [_.isoformat() for _ in missing]
435
+
436
+ metadata["version"] = VERSION
437
+
438
+ self.dataset.check_name(
439
+ raise_exception=self.check_name,
440
+ is_test=self.test,
441
+ resolution=resolution,
442
+ dates=dates,
443
+ frequency=frequency,
444
+ )
445
+
446
+ if len(dates) != total_shape[0]:
447
+ raise ValueError(
448
+ f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) "
449
+ f"does not match data shape {total_shape[0]}. {total_shape=}"
59
450
  )
60
- return obj.initialise(check_name=check_name)
61
451
 
62
- def load(self, parts=None):
63
- from .loaders import ContentLoader
452
+ dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"])
453
+
454
+ metadata.update(self.main_config.get("force_metadata", {}))
455
+
456
+ ###############################################################
457
+ # write metadata
458
+ ###############################################################
459
+
460
+ self.update_metadata(**metadata)
461
+
462
+ self.dataset.add_dataset(
463
+ name="data",
464
+ chunks=chunks,
465
+ dtype=dtype,
466
+ shape=total_shape,
467
+ dimensions=("time", "variable", "ensemble", "cell"),
468
+ )
469
+ self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",))
470
+ self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",))
471
+ self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",))
472
+
473
+ self.registry.create(lengths=lengths)
474
+ self.tmp_statistics.create(exist_ok=False)
475
+ self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version)
476
+
477
+ statistics_start, statistics_end = build_statistics_dates(
478
+ dates,
479
+ self.main_config.statistics.get("start"),
480
+ self.main_config.statistics.get("end"),
481
+ )
482
+ self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end)
483
+ LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}")
484
+
485
+ self.registry.add_to_history("init finished")
486
+
487
+ assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks())
488
+
489
+ # Return the number of groups to process, so we can show a nice progress bar
490
+ return len(lengths)
491
+
64
492
 
493
+ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin):
494
+ def __init__(self, path, parts=None, use_threads=False, statistics_temp_dir=None, progress=None, cache=None, **kwargs): # fmt: skip
495
+ super().__init__(path, cache=cache)
496
+ self.use_threads = use_threads
497
+ self.statistics_temp_dir = statistics_temp_dir
498
+ self.progress = progress
499
+ self.parts = parts
500
+ self.dataset = WritableDataset(self.path)
501
+
502
+ self.main_config = self.dataset.get_main_config()
503
+ self.create_elements(self.main_config)
504
+ self.read_dataset_metadata(self.dataset.path)
505
+
506
+ total = len(self.registry.get_flags())
507
+ self.chunk_filter = ChunkFilter(parts=self.parts, total=total)
508
+
509
+ self.data_array = self.dataset.data_array
510
+ self.n_groups = len(self.groups)
511
+
512
+ def run(self):
65
513
  with self._cache_context():
66
- loader = ContentLoader.from_dataset_config(
67
- path=self.path,
68
- statistics_tmp=self.statistics_tmp,
69
- use_threads=self.use_threads,
70
- progress=self.progress,
71
- parts=parts,
514
+ self._run()
515
+
516
+ def _run(self):
517
+ for igroup, group in enumerate(self.groups):
518
+ if not self.chunk_filter(igroup):
519
+ continue
520
+ if self.registry.get_flag(igroup):
521
+ LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)")
522
+ continue
523
+
524
+ # assert isinstance(group[0], datetime.datetime), type(group[0])
525
+ LOG.debug(f"Building data for group {igroup}/{self.n_groups}")
526
+
527
+ result = self.input.select(dates=group)
528
+ assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group)
529
+
530
+ # There are several groups.
531
+ # There is one result to load for each group.
532
+ self.load_result(result)
533
+ self.registry.set_flag(igroup)
534
+
535
+ self.registry.add_provenance(name="provenance_load")
536
+ self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config)
537
+
538
+ self.dataset.print_info()
539
+
540
+ def load_result(self, result):
541
+ # There is one cube to load for each result.
542
+ dates = list(result.group_of_dates)
543
+
544
+ cube = result.get_cube()
545
+ shape = cube.extended_user_shape
546
+ dates_in_data = cube.user_coords["valid_datetime"]
547
+
548
+ LOG.debug(f"Loading {shape=} in {self.data_array.shape=}")
549
+
550
+ def check_shape(cube, dates, dates_in_data):
551
+ if cube.extended_user_shape[0] != len(dates):
552
+ print(
553
+ f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}"
554
+ )
555
+ print("Requested dates", compress_dates(dates))
556
+ print("Cube dates", compress_dates(dates_in_data))
557
+
558
+ a = set(as_datetime(_) for _ in dates)
559
+ b = set(as_datetime(_) for _ in dates_in_data)
560
+
561
+ print("Missing dates", compress_dates(a - b))
562
+ print("Extra dates", compress_dates(b - a))
563
+
564
+ raise ValueError(
565
+ f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}"
566
+ )
567
+
568
+ check_shape(cube, dates, dates_in_data)
569
+
570
+ def check_dates_in_data(lst, lst2):
571
+ lst2 = [np.datetime64(_) for _ in lst2]
572
+ lst = [np.datetime64(_) for _ in lst]
573
+ assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2)
574
+
575
+ check_dates_in_data(dates_in_data, dates)
576
+
577
+ def dates_to_indexes(dates, all_dates):
578
+ x = np.array(dates, dtype=np.datetime64)
579
+ y = np.array(all_dates, dtype=np.datetime64)
580
+ bitmap = np.isin(x, y)
581
+ return np.where(bitmap)[0]
582
+
583
+ indexes = dates_to_indexes(self.dates, dates_in_data)
584
+
585
+ array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes)
586
+ self.load_cube(cube, array)
587
+
588
+ stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans())
589
+ self.tmp_statistics.write(indexes, stats, dates=dates_in_data)
590
+
591
+ array.flush()
592
+
593
+ def _get_allow_nans(self):
594
+ config = self.main_config
595
+ if "allow_nans" in config.build:
596
+ return config.build.allow_nans
597
+
598
+ return config.statistics.get("allow_nans", [])
599
+
600
+ def load_cube(self, cube, array):
601
+ # There are several cubelets for each cube
602
+ start = time.time()
603
+ load = 0
604
+ save = 0
605
+
606
+ reading_chunks = None
607
+ total = cube.count(reading_chunks)
608
+ LOG.debug(f"Loading datacube: {cube}")
609
+
610
+ def position(x):
611
+ if isinstance(x, str) and "/" in x:
612
+ x = x.split("/")
613
+ return int(x[0])
614
+ return None
615
+
616
+ bar = tqdm.tqdm(
617
+ iterable=cube.iterate_cubelets(reading_chunks),
618
+ total=total,
619
+ desc=f"Loading datacube {cube}",
620
+ position=position(self.parts),
621
+ )
622
+ for i, cubelet in enumerate(bar):
623
+ bar.set_description(f"Loading {i}/{total}")
624
+
625
+ now = time.time()
626
+ data = cubelet.to_numpy()
627
+ local_indexes = cubelet.coords
628
+ load += time.time() - now
629
+
630
+ name = self.variables_names[local_indexes[1]]
631
+ check_data_values(
632
+ data[:],
633
+ name=name,
634
+ log=[i, data.shape, local_indexes],
635
+ allow_nans=self._get_allow_nans(),
72
636
  )
73
- loader.load()
74
-
75
- def statistics(self, force=False, output=None, start=None, end=None):
76
- from .loaders import StatisticsAdder
77
-
78
- loader = StatisticsAdder.from_dataset(
79
- path=self.path,
80
- use_threads=self.use_threads,
81
- progress=self.progress,
82
- statistics_tmp=self.statistics_tmp,
83
- statistics_output=output,
84
- recompute=False,
85
- statistics_start=start,
86
- statistics_end=end,
637
+
638
+ now = time.time()
639
+ array[local_indexes] = data
640
+ save += time.time() - now
641
+
642
+ now = time.time()
643
+ save += time.time() - now
644
+ LOG.debug(
645
+ f"Elapsed: {seconds_to_human(time.time() - start)}, "
646
+ f"load time: {seconds_to_human(load)}, "
647
+ f"write time: {seconds_to_human(save)}."
87
648
  )
88
- loader.run()
89
- assert loader.ready()
90
649
 
91
- def size(self):
92
- from .loaders import DatasetHandler
93
- from .size import compute_directory_sizes
94
650
 
95
- metadata = compute_directory_sizes(self.path)
96
- handle = DatasetHandler.from_dataset(path=self.path, use_threads=self.use_threads)
97
- handle.update_metadata(**metadata)
651
+ class Cleanup(Actor, HasRegistryMixin, HasStatisticTempMixin):
652
+ def __init__(self, path, statistics_temp_dir=None, delta=[], use_threads=False, **kwargs):
653
+ super().__init__(path)
654
+ self.use_threads = use_threads
655
+ self.statistics_temp_dir = statistics_temp_dir
656
+ self.additinon_temp_dir = statistics_temp_dir
657
+ self.actors = [
658
+ _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir)
659
+ for d in delta
660
+ ]
661
+
662
+ def run(self):
663
+ self.tmp_statistics.delete()
664
+ self.registry.clean()
665
+ for actor in self.actors:
666
+ actor.cleanup()
667
+
668
+
669
+ class Verify(Actor):
670
+ def __init__(self, path, **kwargs):
671
+ super().__init__(path)
672
+
673
+ def run(self):
674
+ LOG.info(f"Verifying dataset at {self.path}")
675
+ LOG.info(str(self.dataset.anemoi_dataset))
676
+
677
+
678
+ class AdditionsMixin:
679
+ def skip(self):
680
+ frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency)
681
+ if not self.delta.total_seconds() % frequency.total_seconds() == 0:
682
+ LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.")
683
+ return True
684
+ return False
685
+
686
+ @cached_property
687
+ def tmp_storage_path(self):
688
+ name = "storage_for_additions"
689
+ if self.delta:
690
+ name += frequency_to_string(self.delta)
691
+ return os.path.join(f"{self.path}.{name}.tmp")
692
+
693
+ def read_from_dataset(self):
694
+ self.variables = self.dataset.anemoi_dataset.variables
695
+ self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency)
696
+ start = self.dataset.zarr_metadata["statistics_start_date"]
697
+ end = self.dataset.zarr_metadata["statistics_end_date"]
698
+ self.start = datetime.datetime.fromisoformat(start)
699
+ self.end = datetime.datetime.fromisoformat(end)
700
+
701
+ ds = open_dataset(self.path, start=self.start, end=self.end)
702
+ self.dates = ds.dates
703
+ self.total = len(self.dates)
704
+
705
+ idelta = self.delta.total_seconds() // self.frequency.total_seconds()
706
+ assert int(idelta) == idelta, idelta
707
+ idelta = int(idelta)
708
+ self.ds = DeltaDataset(ds, idelta)
709
+
710
+
711
+ class DeltaDataset:
712
+ def __init__(self, ds, idelta):
713
+ self.ds = ds
714
+ self.idelta = idelta
715
+
716
+ def __getitem__(self, i):
717
+ j = i - self.idelta
718
+ if j < 0:
719
+ raise MissingDateError(f"Missing date {j}")
720
+ return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...]
721
+
722
+
723
+ class _InitAdditions(Actor, HasRegistryMixin, AdditionsMixin):
724
+ def __init__(self, path, delta, use_threads=False, progress=None, **kwargs):
725
+ super().__init__(path)
726
+ self.delta = frequency_to_timedelta(delta)
727
+ self.use_threads = use_threads
728
+ self.progress = progress
729
+
730
+ def run(self):
731
+ if self.skip():
732
+ LOG.info(f"Skipping delta={self.delta}")
733
+ return
734
+
735
+ self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True)
736
+ self.tmp_storage.delete()
737
+ self.tmp_storage.create()
738
+ LOG.info(f"Dataset {self.tmp_storage_path} additions initialized.")
98
739
 
99
740
  def cleanup(self):
100
- from .loaders import DatasetHandlerWithStatistics
741
+ self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False)
742
+ self.tmp_storage.delete()
743
+ LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}")
101
744
 
102
- cleaner = DatasetHandlerWithStatistics.from_dataset(
103
- path=self.path, use_threads=self.use_threads, progress=self.progress, statistics_tmp=self.statistics_tmp
104
- )
105
- cleaner.tmp_statistics.delete()
106
- cleaner.registry.clean()
107
745
 
108
- def patch(self, **kwargs):
109
- from .patch import apply_patch
746
+ class _RunAdditions(Actor, HasRegistryMixin, AdditionsMixin):
747
+ def __init__(self, path, delta, parts=None, use_threads=False, progress=None, **kwargs):
748
+ super().__init__(path)
749
+ self.delta = frequency_to_timedelta(delta)
750
+ self.use_threads = use_threads
751
+ self.progress = progress
752
+ self.parts = parts
110
753
 
111
- apply_patch(self.path, **kwargs)
754
+ self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False)
755
+ LOG.info(f"Writing in {self.tmp_storage_path}")
112
756
 
113
- def init_additions(self, delta=[1, 3, 6, 12, 24], statistics=True):
114
- from .loaders import StatisticsAddition
115
- from .loaders import TendenciesStatisticsAddition
116
- from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
757
+ def run(self):
758
+ if self.skip():
759
+ LOG.info(f"Skipping delta={self.delta}")
760
+ return
117
761
 
118
- if statistics:
119
- a = StatisticsAddition.from_dataset(path=self.path, use_threads=self.use_threads)
120
- a.initialise()
762
+ self.read_from_dataset()
121
763
 
122
- for d in delta:
764
+ chunk_filter = ChunkFilter(parts=self.parts, total=self.total)
765
+ for i in range(0, self.total):
766
+ if not chunk_filter(i):
767
+ continue
768
+ date = self.dates[i]
123
769
  try:
124
- a = TendenciesStatisticsAddition.from_dataset(
125
- path=self.path, use_threads=self.use_threads, progress=self.progress, delta=d
126
- )
127
- a.initialise()
128
- except TendenciesStatisticsDeltaNotMultipleOfFrequency:
129
- LOG.info(f"Skipping delta={d} as it is not a multiple of the frequency.")
770
+ arr = self.ds[i]
771
+ stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans)
772
+ self.tmp_storage.add([date, i, stats], key=date)
773
+ except MissingDateError:
774
+ self.tmp_storage.add([date, i, "missing"], key=date)
775
+ self.tmp_storage.flush()
776
+ LOG.debug(f"Dataset {self.path} additions run.")
777
+
778
+ def allow_nans(self):
779
+ if self.dataset.anemoi_dataset.metadata.get("allow_nans", False):
780
+ return True
130
781
 
131
- def run_additions(self, parts=None, delta=[1, 3, 6, 12, 24], statistics=True):
132
- from .loaders import StatisticsAddition
133
- from .loaders import TendenciesStatisticsAddition
134
- from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
782
+ variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None)
783
+ if variables_with_nans is not None:
784
+ return variables_with_nans
785
+ warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.")
786
+ return True
135
787
 
136
- if statistics:
137
- a = StatisticsAddition.from_dataset(path=self.path, use_threads=self.use_threads)
138
- a.run(parts)
139
788
 
140
- for d in delta:
141
- try:
142
- a = TendenciesStatisticsAddition.from_dataset(
143
- path=self.path, use_threads=self.use_threads, progress=self.progress, delta=d
144
- )
145
- a.run(parts)
146
- except TendenciesStatisticsDeltaNotMultipleOfFrequency:
147
- LOG.debug(f"Skipping delta={d} as it is not a multiple of the frequency.")
789
+ class _FinaliseAdditions(Actor, HasRegistryMixin, AdditionsMixin):
790
+ def __init__(self, path, delta, use_threads=False, progress=None, **kwargs):
791
+ super().__init__(path)
792
+ self.delta = frequency_to_timedelta(delta)
793
+ self.use_threads = use_threads
794
+ self.progress = progress
795
+
796
+ self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False)
797
+ LOG.info(f"Reading from {self.tmp_storage_path}.")
798
+
799
+ def run(self):
800
+ if self.skip():
801
+ LOG.info(f"Skipping delta={self.delta}.")
802
+ return
803
+
804
+ self.read_from_dataset()
805
+
806
+ shape = (len(self.dates), len(self.variables))
807
+ agg = dict(
808
+ minimum=np.full(shape, np.nan, dtype=np.float64),
809
+ maximum=np.full(shape, np.nan, dtype=np.float64),
810
+ sums=np.full(shape, np.nan, dtype=np.float64),
811
+ squares=np.full(shape, np.nan, dtype=np.float64),
812
+ count=np.full(shape, -1, dtype=np.int64),
813
+ has_nans=np.full(shape, False, dtype=np.bool_),
814
+ )
815
+ LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}")
816
+
817
+ found = set()
818
+ ifound = set()
819
+ missing = set()
820
+ for _date, (date, i, stats) in self.tmp_storage.items():
821
+ assert _date == date
822
+ if stats == "missing":
823
+ missing.add(date)
824
+ continue
825
+
826
+ assert date not in found, f"Duplicates found {date}"
827
+ found.add(date)
828
+ ifound.add(i)
829
+
830
+ for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]:
831
+ agg[k][i, ...] = stats[k]
832
+
833
+ assert len(found) + len(missing) == len(self.dates), (
834
+ len(found),
835
+ len(missing),
836
+ len(self.dates),
837
+ )
838
+ assert found.union(missing) == set(self.dates), (
839
+ found,
840
+ missing,
841
+ set(self.dates),
842
+ )
148
843
 
149
- def finalise_additions(self, delta=[1, 3, 6, 12, 24], statistics=True):
150
- from .loaders import StatisticsAddition
151
- from .loaders import TendenciesStatisticsAddition
152
- from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
844
+ if len(ifound) < 2:
845
+ LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.")
846
+ self.tmp_storage.delete()
847
+ return
153
848
 
154
- if statistics:
155
- a = StatisticsAddition.from_dataset(path=self.path, use_threads=self.use_threads)
156
- a.finalise()
849
+ mask = sorted(list(ifound))
850
+ for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]:
851
+ agg[k] = agg[k][mask, ...]
157
852
 
158
- for d in delta:
159
- try:
160
- a = TendenciesStatisticsAddition.from_dataset(
161
- path=self.path, use_threads=self.use_threads, progress=self.progress, delta=d
162
- )
163
- a.finalise()
164
- except TendenciesStatisticsDeltaNotMultipleOfFrequency:
165
- LOG.debug(f"Skipping delta={d} as it is not a multiple of the frequency.")
166
-
167
- def finalise(self, **kwargs):
168
- self.statistics(**kwargs)
169
- self.size()
170
-
171
- def create(self):
172
- self.init()
173
- self.load()
174
- self.finalise()
175
- self.additions()
176
- self.cleanup()
177
-
178
- def additions(self, delta=[1, 3, 6, 12, 24]):
179
- self.init_additions(delta=delta)
180
- self.run_additions(delta=delta)
181
- self.finalise_additions(delta=delta)
853
+ for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]:
854
+ assert agg[k].shape == agg["count"].shape, (
855
+ agg[k].shape,
856
+ agg["count"].shape,
857
+ )
182
858
 
183
- def _cache_context(self):
184
- from .utils import cache_context
859
+ minimum = np.nanmin(agg["minimum"], axis=0)
860
+ maximum = np.nanmax(agg["maximum"], axis=0)
861
+ sums = np.nansum(agg["sums"], axis=0)
862
+ squares = np.nansum(agg["squares"], axis=0)
863
+ count = np.nansum(agg["count"], axis=0)
864
+ has_nans = np.any(agg["has_nans"], axis=0)
865
+
866
+ assert sums.shape == count.shape
867
+ assert sums.shape == squares.shape
868
+ assert sums.shape == minimum.shape
869
+ assert sums.shape == maximum.shape
870
+ assert sums.shape == has_nans.shape
871
+
872
+ mean = sums / count
873
+ assert sums.shape == mean.shape
874
+
875
+ x = squares / count - mean * mean
876
+ # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0
877
+ # remove negative variance due to numerical errors
878
+ for i, name in enumerate(self.variables):
879
+ x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1])
880
+ check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares)
881
+
882
+ stdev = np.sqrt(x)
883
+ assert sums.shape == stdev.shape
884
+
885
+ self.summary = Summary(
886
+ minimum=minimum,
887
+ maximum=maximum,
888
+ mean=mean,
889
+ count=count,
890
+ sums=sums,
891
+ squares=squares,
892
+ stdev=stdev,
893
+ variables_names=self.variables,
894
+ has_nans=has_nans,
895
+ )
896
+ LOG.info(f"Dataset {self.path} additions finalised.")
897
+ # self.check_statistics()
898
+ self._write(self.summary)
899
+ self.tmp_storage.delete()
185
900
 
186
- return cache_context(self.cache)
901
+ def _write(self, summary):
902
+ for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
903
+ name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}"
904
+ self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",))
905
+ self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end")
906
+ LOG.debug(f"Wrote additions in {self.path}")
187
907
 
188
- def _path_readable(self):
189
- import zarr
190
908
 
191
- try:
192
- zarr.open(self.path, "r")
193
- return True
194
- except zarr.errors.PathNotFoundError:
195
- return False
909
+ def multi_addition(cls):
910
+ class MultiAdditions:
911
+ def __init__(self, *args, **kwargs):
912
+ self.actors = []
913
+
914
+ for k in kwargs.pop("delta", []):
915
+ self.actors.append(cls(*args, delta=k, **kwargs))
916
+
917
+ if not self.actors:
918
+ LOG.warning("No delta found in kwargs, no additions will be computed.")
919
+
920
+ def run(self):
921
+ for actor in self.actors:
922
+ actor.run()
923
+
924
+ return MultiAdditions
196
925
 
197
- def verify(self):
198
- from .loaders import DatasetVerifier
199
926
 
200
- handle = DatasetVerifier.from_dataset(path=self.path, use_threads=self.use_threads)
927
+ InitAdditions = multi_addition(_InitAdditions)
928
+ RunAdditions = multi_addition(_RunAdditions)
929
+ FinaliseAdditions = multi_addition(_FinaliseAdditions)
930
+
931
+
932
+ class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin):
933
+ def __init__(self, path, use_threads=False, statistics_temp_dir=None, progress=None, **kwargs):
934
+ super().__init__(path)
935
+ self.use_threads = use_threads
936
+ self.progress = progress
937
+ self.statistics_temp_dir = statistics_temp_dir
938
+
939
+ def run(self):
940
+ start, end = (
941
+ self.dataset.zarr_metadata["statistics_start_date"],
942
+ self.dataset.zarr_metadata["statistics_end_date"],
943
+ )
944
+ start, end = np.datetime64(start), np.datetime64(end)
945
+ dates = self.dataset.anemoi_dataset.dates
946
+
947
+ assert type(dates[0]) is type(start), (type(dates[0]), type(start))
948
+
949
+ dates = [d for d in dates if d >= start and d <= end]
950
+ dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing]
951
+ variables = self.dataset.anemoi_dataset.variables
952
+ stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans)
953
+
954
+ LOG.info(stats)
955
+
956
+ if not all(self.registry.get_flags(sync=False)):
957
+ raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.")
958
+
959
+ for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
960
+ self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",))
961
+
962
+ self.registry.add_to_history("compute_statistics_end")
963
+ LOG.info(f"Wrote statistics in {self.path}")
964
+
965
+ @cached_property
966
+ def allow_nans(self):
967
+ import zarr
201
968
 
202
- handle.verify()
969
+ z = zarr.open(self.path, mode="r")
970
+ if "allow_nans" in z.attrs:
971
+ return z.attrs["allow_nans"]
972
+
973
+ if "variables_with_nans" in z.attrs:
974
+ return z.attrs["variables_with_nans"]
975
+
976
+ warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.")
977
+ return True
978
+
979
+
980
+ def chain(tasks):
981
+ class Chain(Actor):
982
+ def __init__(self, **kwargs):
983
+ self.kwargs = kwargs
984
+
985
+ def run(self):
986
+ for cls in tasks:
987
+ t = cls(**self.kwargs)
988
+ t.run()
989
+
990
+ return Chain
991
+
992
+
993
+ def creator_factory(name, trace=None, **kwargs):
994
+ if trace:
995
+ from anemoi.datasets.create.trace import enable_trace
996
+
997
+ enable_trace(trace)
998
+
999
+ cls = dict(
1000
+ init=Init,
1001
+ load=Load,
1002
+ size=Size,
1003
+ patch=Patch,
1004
+ statistics=Statistics,
1005
+ finalise=chain([Statistics, Size, Cleanup]),
1006
+ cleanup=Cleanup,
1007
+ verify=Verify,
1008
+ init_additions=InitAdditions,
1009
+ load_additions=RunAdditions,
1010
+ run_additions=RunAdditions,
1011
+ finalise_additions=chain([FinaliseAdditions, Size]),
1012
+ additions=chain([InitAdditions, RunAdditions, FinaliseAdditions, Size, Cleanup]),
1013
+ )[name]
1014
+ LOG.debug(f"Creating {cls.__name__} with {kwargs}")
1015
+ return cls(**kwargs)
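For reference, the tendency-statistics path above (_RunAdditions writing per-date statistics to temporary storage, _FinaliseAdditions reducing them) collapses per-date minima, maxima, sums, squares and counts into one summary per variable. The following standalone sketch mirrors that reduction using the same field names, for illustration only; the package additionally repairs small negative variances with fix_variance and validates the result with check_variance.

    import numpy as np

    def reduce_additions(minimum, maximum, sums, squares, count, has_nans):
        # Collapse per-date statistics (axis 0) into per-variable summaries,
        # mirroring the reduction in _FinaliseAdditions above (illustrative only).
        total = np.nansum(count, axis=0)
        mean = np.nansum(sums, axis=0) / total
        variance = np.nansum(squares, axis=0) / total - mean * mean
        variance = np.maximum(variance, 0.0)  # guard against rounding-induced negatives
        return dict(
            minimum=np.nanmin(minimum, axis=0),
            maximum=np.nanmax(maximum, axis=0),
            mean=mean,
            stdev=np.sqrt(variance),
            count=total,
            has_nans=np.any(has_nans, axis=0),
        )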