anemoi-datasets 0.3.10__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/compare.py +59 -0
  3. anemoi/datasets/commands/create.py +84 -3
  4. anemoi/datasets/commands/inspect.py +9 -9
  5. anemoi/datasets/commands/scan.py +4 -4
  6. anemoi/datasets/compute/recentre.py +14 -9
  7. anemoi/datasets/create/__init__.py +44 -17
  8. anemoi/datasets/create/check.py +6 -5
  9. anemoi/datasets/create/chunks.py +1 -1
  10. anemoi/datasets/create/config.py +6 -27
  11. anemoi/datasets/create/functions/__init__.py +3 -3
  12. anemoi/datasets/create/functions/filters/empty.py +4 -4
  13. anemoi/datasets/create/functions/filters/rename.py +14 -6
  14. anemoi/datasets/create/functions/filters/rotate_winds.py +16 -60
  15. anemoi/datasets/create/functions/filters/unrotate_winds.py +14 -64
  16. anemoi/datasets/create/functions/sources/__init__.py +39 -0
  17. anemoi/datasets/create/functions/sources/accumulations.py +38 -56
  18. anemoi/datasets/create/functions/sources/constants.py +11 -4
  19. anemoi/datasets/create/functions/sources/empty.py +2 -2
  20. anemoi/datasets/create/functions/sources/forcings.py +3 -3
  21. anemoi/datasets/create/functions/sources/grib.py +8 -4
  22. anemoi/datasets/create/functions/sources/hindcasts.py +32 -364
  23. anemoi/datasets/create/functions/sources/mars.py +57 -26
  24. anemoi/datasets/create/functions/sources/netcdf.py +2 -60
  25. anemoi/datasets/create/functions/sources/opendap.py +3 -2
  26. anemoi/datasets/create/functions/sources/source.py +3 -3
  27. anemoi/datasets/create/functions/sources/tendencies.py +7 -7
  28. anemoi/datasets/create/functions/sources/xarray/__init__.py +73 -0
  29. anemoi/datasets/create/functions/sources/xarray/coordinates.py +234 -0
  30. anemoi/datasets/create/functions/sources/xarray/field.py +109 -0
  31. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +171 -0
  32. anemoi/datasets/create/functions/sources/xarray/flavour.py +330 -0
  33. anemoi/datasets/create/functions/sources/xarray/grid.py +46 -0
  34. anemoi/datasets/create/functions/sources/xarray/metadata.py +161 -0
  35. anemoi/datasets/create/functions/sources/xarray/time.py +98 -0
  36. anemoi/datasets/create/functions/sources/xarray/variable.py +198 -0
  37. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +42 -0
  38. anemoi/datasets/create/functions/sources/xarray_zarr.py +15 -0
  39. anemoi/datasets/create/functions/sources/zenodo.py +40 -0
  40. anemoi/datasets/create/input.py +309 -191
  41. anemoi/datasets/create/loaders.py +155 -77
  42. anemoi/datasets/create/patch.py +17 -14
  43. anemoi/datasets/create/persistent.py +1 -1
  44. anemoi/datasets/create/size.py +4 -5
  45. anemoi/datasets/create/statistics/__init__.py +51 -17
  46. anemoi/datasets/create/template.py +11 -61
  47. anemoi/datasets/create/trace.py +91 -0
  48. anemoi/datasets/create/utils.py +5 -52
  49. anemoi/datasets/create/zarr.py +24 -10
  50. anemoi/datasets/data/dataset.py +4 -4
  51. anemoi/datasets/data/misc.py +9 -37
  52. anemoi/datasets/data/stores.py +37 -14
  53. anemoi/datasets/dates/__init__.py +7 -1
  54. anemoi/datasets/dates/groups.py +3 -0
  55. {anemoi_datasets-0.3.10.dist-info → anemoi_datasets-0.4.2.dist-info}/METADATA +24 -8
  56. anemoi_datasets-0.4.2.dist-info/RECORD +86 -0
  57. {anemoi_datasets-0.3.10.dist-info → anemoi_datasets-0.4.2.dist-info}/WHEEL +1 -1
  58. anemoi_datasets-0.3.10.dist-info/RECORD +0 -73
  59. {anemoi_datasets-0.3.10.dist-info → anemoi_datasets-0.4.2.dist-info}/LICENSE +0 -0
  60. {anemoi_datasets-0.3.10.dist-info → anemoi_datasets-0.4.2.dist-info}/entry_points.txt +0 -0
  61. {anemoi_datasets-0.3.10.dist-info → anemoi_datasets-0.4.2.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/size.py

@@ -10,9 +10,8 @@
  import logging
  import os
  
- from anemoi.utils.humanize import bytes
- 
- from anemoi.datasets.create.utils import progress_bar
+ import tqdm
+ from anemoi.utils.humanize import bytes_to_human
  
  LOG = logging.getLogger(__name__)
  
@@ -22,14 +21,14 @@ def compute_directory_sizes(path):
          return None
  
      size, n = 0, 0
-     bar = progress_bar(iterable=os.walk(path), desc=f"Computing size of {path}")
+     bar = tqdm.tqdm(iterable=os.walk(path), desc=f"Computing size of {path}")
      for dirpath, _, filenames in bar:
          for filename in filenames:
              file_path = os.path.join(dirpath, filename)
              size += os.path.getsize(file_path)
              n += 1
  
-     LOG.info(f"Total size: {bytes(size)}")
+     LOG.info(f"Total size: {bytes_to_human(size)}")
      LOG.info(f"Total number of files: {n}")
  
      return dict(total_size=size, total_number_of_files=n)
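The rename also stops shadowing Python's builtin `bytes` (the old local helper is deleted from `create/utils.py` later in this diff). A quick sanity check of the replacement, with the expected outputs taken from the removed helper's docstring and assuming `bytes_to_human` formats the same way:

from anemoi.utils.humanize import bytes_to_human

print(bytes_to_human(4096))  # '4 KiB'
print(bytes_to_human(4000))  # '3.9 KiB'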
anemoi/datasets/create/statistics/__init__.py

@@ -71,7 +71,7 @@ def to_datetime(date):
      if isinstance(date, str):
          return np.datetime64(date)
      if isinstance(date, datetime.datetime):
-         return np.datetime64(date)
+         return np.datetime64(date, "s")
      return date
  
  
@@ -89,20 +89,23 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
              continue
          print("---")
          print(f"❗ Negative variance for {name=}, variance={y}")
-         print(f" max={maximum[i]} min={minimum[i]} mean={mean[i]} count={count[i]} sum={sums[i]} square={squares[i]}")
+         print(f" min={minimum[i]} max={maximum[i]} mean={mean[i]} count={count[i]} sums={sums[i]} squares={squares[i]}")
          print(f" -> sums: min={np.min(sums[i])}, max={np.max(sums[i])}, argmin={np.argmin(sums[i])}")
          print(f" -> squares: min={np.min(squares[i])}, max={np.max(squares[i])}, argmin={np.argmin(squares[i])}")
          print(f" -> count: min={np.min(count[i])}, max={np.max(count[i])}, argmin={np.argmin(count[i])}")
+         print(
+             f" squares / count - mean * mean = {squares[i] / count[i]} - {mean[i] * mean[i]} = {squares[i] / count[i] - mean[i] * mean[i]}"
+         )
  
      raise ValueError("Negative variance")
  
  
- def compute_statistics(array, check_variables_names=None, allow_nan=False):
+ def compute_statistics(array, check_variables_names=None, allow_nans=False):
      """Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary."""
  
      nvars = array.shape[1]
  
-     LOG.info(f"Stats {nvars}, {array.shape}, {check_variables_names}")
+     LOG.debug(f"Stats {nvars}, {array.shape}, {check_variables_names}")
      if check_variables_names:
          assert nvars == len(check_variables_names), (nvars, check_variables_names)
      stats_shape = (array.shape[0], nvars)
@@ -118,7 +121,7 @@ def compute_statistics(array, check_variables_names=None, allow_nan=False):
          values = chunk.reshape((nvars, -1))
  
          for j, name in enumerate(check_variables_names):
-             check_data_values(values[j, :], name=name, allow_nan=allow_nan)
+             check_data_values(values[j, :], name=name, allow_nans=allow_nans)
              if np.isnan(values[j, :]).all():
                  # LOG.warning(f"All NaN values for {name} ({j}) for date {i}")
                  raise ValueError(f"All NaN values for {name} ({j}) for date {i}")
@@ -179,12 +182,12 @@ class TmpStatistics:
              pickle.dump((key, dates, data), f)
          shutil.move(tmp_path, path)
  
-         LOG.info(f"Written statistics data for {len(dates)} dates in {path} ({dates})")
+         LOG.debug(f"Written statistics data for {len(dates)} dates in {path} ({dates})")
  
      def _gather_data(self):
          # use glob to read all pickles
          files = glob.glob(self.dirname + "/*.npz")
-         LOG.info(f"Reading stats data, found {len(files)} files in {self.dirname}")
+         LOG.debug(f"Reading stats data, found {len(files)} files in {self.dirname}")
          assert len(files) > 0, f"No files found in {self.dirname}"
          for f in files:
              with open(f, "rb") as f:
@@ -211,17 +214,17 @@ def normalise_dates(dates):
  class StatAggregator:
      NAMES = ["minimum", "maximum", "sums", "squares", "count", "has_nans"]
  
-     def __init__(self, owner, dates, variables_names, allow_nan):
+     def __init__(self, owner, dates, variables_names, allow_nans):
          dates = sorted(dates)
          dates = to_datetimes(dates)
          assert dates, "No dates selected"
          self.owner = owner
          self.dates = dates
          self.variables_names = variables_names
-         self.allow_nan = allow_nan
+         self.allow_nans = allow_nans
  
          self.shape = (len(self.dates), len(self.variables_names))
-         LOG.info(f"Aggregating statistics on shape={self.shape}. Variables : {self.variables_names}")
+         LOG.debug(f"Aggregating statistics on shape={self.shape}. Variables : {self.variables_names}")
  
          self.minimum = np.full(self.shape, np.nan, dtype=np.float64)
          self.maximum = np.full(self.shape, np.nan, dtype=np.float64)
@@ -242,6 +245,7 @@ class StatAggregator:
  
          found = set()
          offset = 0
+ 
          for _, _dates, stats in self.owner._gather_data():
              assert isinstance(stats, dict), stats
              assert stats["minimum"].shape[0] == len(_dates), (stats["minimum"].shape, len(_dates))
@@ -283,7 +287,7 @@ class StatAggregator:
              assert d in found, f"Statistics for date {d} not precomputed."
          assert len(self.dates) == len(found), "Not all dates found in precomputed statistics"
          assert len(self.dates) == offset, "Not all dates found in precomputed statistics."
-         LOG.info(f"Statistics for {len(found)} dates found.")
+         LOG.debug(f"Statistics for {len(found)} dates found.")
  
      def aggregate(self):
          minimum = np.nanmin(self.minimum, axis=0)
@@ -297,13 +301,43 @@ class StatAggregator:
          assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape
  
          x = squares / count - mean * mean
-         # remove negative variance due to numerical errors
-         # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0
-         check_variance(x, self.variables_names, minimum, maximum, mean, count, sums, squares)
-         stdev = np.sqrt(x)
  
-         for j, name in enumerate(self.variables_names):
-             check_data_values(np.array([mean[j]]), name=name, allow_nan=False)
+         # def fix_variance(x, name, minimum, maximum, mean, count, sums, squares):
+         #     assert x.shape == minimum.shape == maximum.shape == mean.shape == count.shape == sums.shape == squares.shape
+         #     assert x.shape == (1,)
+         #     x, minimum, maximum, mean, count, sums, squares = x[0], minimum[0], maximum[0], mean[0], count[0], sums[0], squares[0]
+         #     if x >= 0:
+         #         return x
+         #
+         #     order = np.sqrt((squares / count + mean * mean)/2)
+         #     range = maximum - minimum
+         #     LOG.warning(f"Negative variance for {name=}, variance={x}")
+         #     LOG.warning(f"square / count - mean * mean = {squares / count} - {mean * mean} = {squares / count - mean * mean}")
+         #     LOG.warning(f"Variable order of magnitude is {order}.")
+         #     LOG.warning(f"Range is {range} ({maximum=} - {minimum=}).")
+         #     LOG.warning(f"Count is {count}.")
+         #     if abs(x) < order * 1e-6 and abs(x) < range * 1e-6:
+         #         LOG.warning(f"Variance is negative but very small, setting to 0.")
+         #         return x*0
+         #     return x
+ 
+         for i, name in enumerate(self.variables_names):
+             # remove negative variance due to numerical errors
+             # Not needed for now, fix_variance is disabled
+             # x[i] = fix_variance(x[i:i+1], name, minimum[i:i+1], maximum[i:i+1], mean[i:i+1], count[i:i+1], sums[i:i+1], squares[i:i+1])
+             check_variance(
+                 x[i : i + 1],
+                 [name],
+                 minimum[i : i + 1],
+                 maximum[i : i + 1],
+                 mean[i : i + 1],
+                 count[i : i + 1],
+                 sums[i : i + 1],
+                 squares[i : i + 1],
+             )
+             check_data_values(np.array([mean[i]]), name=name, allow_nans=False)
+ 
+         stdev = np.sqrt(x)
  
          return Summary(
              minimum=minimum,
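For context on why `check_variance` is now applied variable by variable: the one-pass formula var = E[x²] − E[x]² cancels catastrophically when the mean is large relative to the spread, which is exactly how a slightly negative variance arises. A self-contained illustration (not from the package):

import numpy as np

rng = np.random.default_rng(0)
x = 1e8 + rng.uniform(-1e-3, 1e-3, 1_000_000)  # large mean, tiny spread

count = x.size
sums = x.sum()
squares = (x * x).sum()
mean = sums / count

one_pass = squares / count - mean * mean  # can come out slightly negative
two_pass = np.var(x)                      # numerically stable reference
print(one_pass, two_pass)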
anemoi/datasets/create/template.py

@@ -8,72 +8,16 @@
  #
  
  import logging
- import os
  import re
  import textwrap
  from functools import wraps
  
- LOG = logging.getLogger(__name__)
- 
- TRACE_INDENT = 0
- 
- 
- def step(action_path):
-     return f"[{'.'.join(action_path)}]"
+ from anemoi.utils.humanize import plural
  
+ from .trace import step
+ from .trace import trace
  
- def trace(emoji, *args):
-     if os.environ.get("ANEMOI_DATASET_TRACE_CREATE") is None:
-         return
-     print(emoji, " " * TRACE_INDENT, *args)
- 
- 
- def trace_datasource(method):
-     @wraps(method)
-     def wrapper(self, *args, **kwargs):
-         global TRACE_INDENT
-         trace(
-             "🌍",
-             "=>",
-             step(self.action_path),
-             self._trace_datasource(*args, **kwargs),
-         )
-         TRACE_INDENT += 1
-         result = method(self, *args, **kwargs)
-         TRACE_INDENT -= 1
-         trace(
-             "🍎",
-             "<=",
-             step(self.action_path),
-             textwrap.shorten(repr(result), 256),
-         )
-         return result
- 
-     return wrapper
- 
- 
- def trace_select(method):
-     @wraps(method)
-     def wrapper(self, *args, **kwargs):
-         global TRACE_INDENT
-         trace(
-             "👓",
-             "=>",
-             ".".join(self.action_path),
-             self._trace_select(*args, **kwargs),
-         )
-         TRACE_INDENT += 1
-         result = method(self, *args, **kwargs)
-         TRACE_INDENT -= 1
-         trace(
-             "🍍",
-             "<=",
-             ".".join(self.action_path),
-             textwrap.shorten(repr(result), 256),
-         )
-         return result
- 
-     return wrapper
+ LOG = logging.getLogger(__name__)
  
  
  def notify_result(method):
@@ -99,7 +43,13 @@ class Context:
          self.used_references.add(key)
  
      def notify_result(self, key, result):
-         trace("🎯", step(key), "notify result", result)
+         trace(
+             "🎯",
+             step(key),
+             "notify result",
+             textwrap.shorten(repr(result).replace(",", ", "), width=40),
+             plural(len(result), "field"),
+         )
          assert isinstance(key, (list, tuple)), key
          key = tuple(key)
          if key in self.used_references:
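The reworked trace line prints a shortened repr plus a field count instead of dumping the whole result. Roughly what it produces, with a made-up field list (`plural` is the anemoi-utils helper imported above):

import textwrap

from anemoi.utils.humanize import plural

result = ["2t", "10u", "10v", "msl", "sp", "tp"]
print(textwrap.shorten(repr(result).replace(",", ", "), width=40))
print(plural(len(result), "field"))  # e.g. '6 fields'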
anemoi/datasets/create/trace.py (new file)

@@ -0,0 +1,91 @@
+ # (C) Copyright 2024 ECMWF.
+ #
+ # This software is licensed under the terms of the Apache Licence Version 2.0
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ # In applying this licence, ECMWF does not waive the privileges and immunities
+ # granted to it by virtue of its status as an intergovernmental organisation
+ # nor does it submit to any jurisdiction.
+ #
+ 
+ import logging
+ import textwrap
+ import threading
+ from functools import wraps
+ 
+ LOG = logging.getLogger(__name__)
+ 
+ 
+ thread_local = threading.local()
+ TRACE = 0
+ 
+ 
+ def enable_trace(on_off):
+     global TRACE
+     TRACE = on_off
+ 
+ 
+ def step(action_path):
+     return f"[{'.'.join(action_path)}]"
+ 
+ 
+ def trace(emoji, *args):
+ 
+     if not TRACE:
+         return
+ 
+     if not hasattr(thread_local, "TRACE_INDENT"):
+         thread_local.TRACE_INDENT = 0
+ 
+     print(emoji, " " * thread_local.TRACE_INDENT, *args)
+ 
+ 
+ def trace_datasource(method):
+     @wraps(method)
+     def wrapper(self, *args, **kwargs):
+ 
+         if not hasattr(thread_local, "TRACE_INDENT"):
+             thread_local.TRACE_INDENT = 0
+ 
+         trace(
+             "🌍",
+             "=>",
+             step(self.action_path),
+             self._trace_datasource(*args, **kwargs),
+         )
+         thread_local.TRACE_INDENT += 1
+         result = method(self, *args, **kwargs)
+         thread_local.TRACE_INDENT -= 1
+         trace(
+             "🍎",
+             "<=",
+             step(self.action_path),
+             textwrap.shorten(repr(result), 256),
+         )
+         return result
+ 
+     return wrapper
+ 
+ 
+ def trace_select(method):
+     @wraps(method)
+     def wrapper(self, *args, **kwargs):
+         if not hasattr(thread_local, "TRACE_INDENT"):
+             thread_local.TRACE_INDENT = 0
+         trace(
+             "👓",
+             "=>",
+             ".".join(self.action_path),
+             self._trace_select(*args, **kwargs),
+         )
+         thread_local.TRACE_INDENT += 1
+         result = method(self, *args, **kwargs)
+         thread_local.TRACE_INDENT -= 1
+         trace(
+             "🍍",
+             "<=",
+             ".".join(self.action_path),
+             textwrap.shorten(repr(result), 256),
+         )
+         return result
+ 
+     return wrapper
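Tracing is now switched on programmatically via `enable_trace()` (the old `ANEMOI_DATASET_TRACE_CREATE` environment check is gone) and the indent level is thread-local, so concurrent workers no longer corrupt each other's indentation. A minimal sketch of how the decorators are used; the class here is illustrative, not from the package:

from anemoi.datasets.create.trace import enable_trace, trace_datasource


class DemoAction:
    # trace_datasource expects these two attributes on the decorated object
    action_path = ["input", "mars"]

    def _trace_datasource(self, *args, **kwargs):
        return "loading fields"

    @trace_datasource
    def datasource(self):
        return ["field_1", "field_2"]


enable_trace(1)
DemoAction().datasource()  # prints the 🌍 => ... / 🍎 <= ... lines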
anemoi/datasets/create/utils.py

@@ -7,15 +7,11 @@
  # nor does it submit to any jurisdiction.
  #
  
- import json
  import os
  from contextlib import contextmanager
  
  import numpy as np
- import yaml
- from climetlab import settings
- from climetlab.utils.humanize import seconds  # noqa: F401
- from tqdm.auto import tqdm
+ from earthkit.data import settings
  
  
  def cache_context(dirname):
@@ -27,50 +23,22 @@ def cache_context(dirname):
          return no_cache_context()
  
      os.makedirs(dirname, exist_ok=True)
-     return settings.temporary("cache-directory", dirname)
- 
- 
- def bytes(n):
-     """>>> bytes(4096)
-     '4 KiB'
-     >>> bytes(4000)
-     '3.9 KiB'
-     """
-     if n < 0:
-         sign = "-"
-         n -= 0
-     else:
-         sign = ""
- 
-     u = ["", " KiB", " MiB", " GiB", " TiB", " PiB", " EiB", " ZiB", " YiB"]
-     i = 0
-     while n >= 1024:
-         n /= 1024.0
-         i += 1
-     return "%s%g%s" % (sign, int(n * 10 + 0.5) / 10.0, u[i])
+     # return settings.temporary("cache-directory", dirname)
+     return settings.temporary({"cache-policy": "user", "user-cache-directory": dirname})
  
  
  def to_datetime_list(*args, **kwargs):
-     from climetlab.utils.dates import to_datetime_list as to_datetime_list_
+     from earthkit.data.utils.dates import to_datetime_list as to_datetime_list_
  
      return to_datetime_list_(*args, **kwargs)
  
  
  def to_datetime(*args, **kwargs):
-     from climetlab.utils.dates import to_datetime as to_datetime_
+     from earthkit.data.utils.dates import to_datetime as to_datetime_
  
      return to_datetime_(*args, **kwargs)
  
  
- def load_json_or_yaml(path):
-     with open(path, "r") as f:
-         if path.endswith(".json"):
-             return json.load(f)
-         if path.endswith(".yaml") or path.endswith(".yml"):
-             return yaml.safe_load(f)
-     raise ValueError(f"Cannot read file {path}. Need json or yaml with appropriate extension.")
- 
- 
  def make_list_int(value):
      if isinstance(value, str):
          if "/" not in value:
@@ -117,18 +85,3 @@ def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"):
          assert d1 == d2, (i, d1, d2)
  
      return dates_
- 
- 
- def progress_bar(*, iterable=None, total=None, initial=0, desc=None):
-     return tqdm(
-         iterable=iterable,
-         total=total,
-         initial=initial,
-         unit_scale=True,
-         unit_divisor=1024,
-         unit="B",
-         disable=False,
-         leave=False,
-         desc=desc,
-         # dynamic_ncols=True,  # make this the default?
-     )
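The cache context moves from climetlab's single `cache-directory` setting to earthkit-data's two-part cache configuration (a "user" cache policy plus a directory). Assuming earthkit-data is installed, usage is unchanged for callers; a falsy dirname presumably still falls through to `no_cache_context()`:

from anemoi.datasets.create.utils import cache_context

with cache_context("/tmp/anemoi-cache"):
    ...  # earthkit.data downloads inside this block cache into /tmp/anemoi-cache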
anemoi/datasets/create/zarr.py

@@ -24,8 +24,12 @@ def add_zarr_dataset(
      shape=None,
      array=None,
      overwrite=True,
+     dimensions=None,
      **kwargs,
  ):
+     assert dimensions is not None, "Please pass dimensions to add_zarr_dataset."
+     assert isinstance(dimensions, (tuple, list))
+ 
      if dtype is None:
          assert array is not None, (name, shape, array, dtype, zarr_root)
          dtype = array.dtype
@@ -44,6 +48,7 @@ def add_zarr_dataset(
              **kwargs,
          )
          a[...] = array
+         a.attrs["_ARRAY_DIMENSIONS"] = dimensions
          return a
  
      if "fill_value" not in kwargs:
@@ -69,6 +74,7 @@ def add_zarr_dataset(
          overwrite=overwrite,
          **kwargs,
      )
+     a.attrs["_ARRAY_DIMENSIONS"] = dimensions
      return a
  
  
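`_ARRAY_DIMENSIONS` is the attribute xarray's zarr backend uses to attach dimension names, so arrays written by `add_zarr_dataset` become openable as labelled xarray datasets. A standalone sketch of the convention (zarr v2 API, illustrative names and path):

import numpy as np
import xarray as xr
import zarr

root = zarr.open_group("example.zarr", mode="w")
a = root.create_dataset("data", shape=(4, 3), dtype="f8")
a[...] = np.zeros((4, 3))
a.attrs["_ARRAY_DIMENSIONS"] = ["time", "cell"]  # what the new code writes

print(xr.open_zarr("example.zarr"))  # dimensions show up as time, cell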
@@ -79,22 +85,27 @@ class ZarrBuiltRegistry:
      flags = None
      z = None
  
-     def __init__(self, path, synchronizer_path=None):
+     def __init__(self, path, synchronizer_path=None, use_threads=False):
          import zarr
  
          assert isinstance(path, str), path
          self.zarr_path = path
  
-         if synchronizer_path is None:
-             synchronizer_path = self.zarr_path + ".sync"
-         self.synchronizer_path = synchronizer_path
-         self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path)
+         if use_threads:
+             self.synchronizer = zarr.ThreadSynchronizer()
+             self.synchronizer_path = None
+         else:
+             if synchronizer_path is None:
+                 synchronizer_path = self.zarr_path + ".sync"
+             self.synchronizer_path = synchronizer_path
+             self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path)
  
      def clean(self):
-         try:
-             shutil.rmtree(self.synchronizer_path)
-         except FileNotFoundError:
-             pass
+         if self.synchronizer_path is not None:
+             try:
+                 shutil.rmtree(self.synchronizer_path)
+             except FileNotFoundError:
+                 pass
  
      def _open_write(self):
          import zarr
@@ -112,7 +123,7 @@ class ZarrBuiltRegistry:
      def new_dataset(self, *args, **kwargs):
          z = self._open_write()
          zarr_root = z["_build"]
-         add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, **kwargs)
+         add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs)
  
      def add_to_history(self, action, **kwargs):
          new = dict(
@@ -143,6 +154,9 @@ class ZarrBuiltRegistry:
          z.attrs["latest_write_timestamp"] = datetime.datetime.utcnow().isoformat()
          z["_build"][self.name_flags][i] = value
  
+     def ready(self):
+         return all(self.get_flags())
+ 
      def create(self, lengths, overwrite=False):
          self.new_dataset(name=self.name_lengths, array=np.array(lengths, dtype="i4"))
          self.new_dataset(name=self.name_flags, array=np.array([False] * len(lengths), dtype=bool))
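The new `use_threads` flag swaps the file-based `ProcessSynchronizer` (which needs the on-disk `.sync` directory, hence the guarded `clean()`) for zarr's in-memory `ThreadSynchronizer` when the build runs with threads instead of processes. The two zarr v2 primitives side by side, with illustrative paths:

import zarr

thread_sync = zarr.ThreadSynchronizer()                 # in-process locks, nothing on disk
process_sync = zarr.ProcessSynchronizer("x.zarr.sync")  # lock files shared across processes

# either one can be handed to zarr when opening an array
z = zarr.open_array("x.zarr", mode="a", shape=(10,), chunks=(5,), synchronizer=thread_sync)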
anemoi/datasets/data/dataset.py

@@ -187,8 +187,8 @@ class Dataset:
                  specific=self.metadata_specific(),
                  frequency=self.frequency,
                  variables=self.variables,
-                 start_date=self.dates[0],
-                 end_date=self.dates[-1],
+                 start_date=self.dates[0].astype(str),
+                 end_date=self.dates[-1].astype(str),
              )
          )
  
@@ -200,8 +200,8 @@ class Dataset:
              variables=self.variables,
              shape=self.shape,
              frequency=self.frequency,
-             start_date=self.dates[0],
-             end_date=self.dates[-1],
+             start_date=self.dates[0].astype(str),
+             end_date=self.dates[-1].astype(str),
              **kwargs,
          )
  
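`.astype(str)` turns the `numpy.datetime64` endpoints into ISO strings, presumably so the metadata stays serializable. A quick illustration:

import numpy as np

d = np.datetime64("2020-01-01T00:00:00", "s")
print(d.astype(str))  # '2020-01-01T00:00:00'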
anemoi/datasets/data/misc.py

@@ -8,63 +8,35 @@
  import calendar
  import datetime
  import logging
- import os
  import re
  from pathlib import PurePath
  
  import numpy as np
  import zarr
+ from anemoi.utils.config import load_config as load_settings
  
  from .dataset import Dataset
  
  LOG = logging.getLogger(__name__)
  
- CONFIG = None
  
- try:
-     import tomllib  # Only available since 3.11
- except ImportError:
-     import tomli as tomllib
+ def load_config():
+     return load_settings(defaults={"datasets": {"named": {}, "path": []}})
  
  
  def add_named_dataset(name, path, **kwargs):
-     load_config()
-     if name in CONFIG["datasets"]["named"]:
+     config = load_config()
+     if name["datasets"]["named"]:
          raise ValueError(f"Dataset {name} already exists")
  
-     CONFIG["datasets"]["named"][name] = path
+     config["datasets"]["named"][name] = path
  
  
  def add_dataset_path(path):
-     load_config()
- 
-     if path not in CONFIG["datasets"]["path"]:
-         CONFIG["datasets"]["path"].append(path)
- 
-     # save_config()
- 
- 
- def load_config():
-     global CONFIG
-     if CONFIG is not None:
-         return CONFIG
- 
-     conf = os.path.expanduser("~/.config/anemoi/settings.toml")
-     if not os.path.exists(conf):
-         conf = os.path.expanduser("~/.anemoi.toml")
- 
-     if os.path.exists(conf):
- 
-         with open(conf, "rb") as f:
-             CONFIG = tomllib.load(f)
-     else:
-         CONFIG = {}
- 
-     CONFIG.setdefault("datasets", {})
-     CONFIG["datasets"].setdefault("path", [])
-     CONFIG["datasets"].setdefault("named", {})
+     config = load_config()
  
-     return CONFIG
+     if path not in config["datasets"]["path"]:
+         config["datasets"]["path"].append(path)
  
  
  def _frequency_to_hours(frequency):
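The hand-rolled TOML lookup is gone; configuration now goes through `anemoi.utils.config.load_config`, with the `defaults` argument supplying the shape the callers expect. A hedged usage sketch (the merge semantics with the user's anemoi config are assumed, not shown in this diff):

from anemoi.utils.config import load_config

config = load_config(defaults={"datasets": {"named": {}, "path": []}})
print(config["datasets"]["path"])  # [] unless the user's anemoi config adds entries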