anemoi-datasets 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/compare.py +59 -0
  3. anemoi/datasets/commands/create.py +84 -3
  4. anemoi/datasets/commands/inspect.py +3 -3
  5. anemoi/datasets/create/__init__.py +44 -17
  6. anemoi/datasets/create/check.py +6 -5
  7. anemoi/datasets/create/chunks.py +1 -1
  8. anemoi/datasets/create/config.py +5 -26
  9. anemoi/datasets/create/functions/filters/rename.py +9 -1
  10. anemoi/datasets/create/functions/filters/rotate_winds.py +10 -1
  11. anemoi/datasets/create/functions/sources/__init__.py +39 -0
  12. anemoi/datasets/create/functions/sources/accumulations.py +11 -41
  13. anemoi/datasets/create/functions/sources/constants.py +3 -0
  14. anemoi/datasets/create/functions/sources/grib.py +4 -0
  15. anemoi/datasets/create/functions/sources/hindcasts.py +32 -377
  16. anemoi/datasets/create/functions/sources/mars.py +53 -22
  17. anemoi/datasets/create/functions/sources/netcdf.py +2 -60
  18. anemoi/datasets/create/functions/sources/opendap.py +3 -2
  19. anemoi/datasets/create/functions/sources/xarray/__init__.py +73 -0
  20. anemoi/datasets/create/functions/sources/xarray/coordinates.py +234 -0
  21. anemoi/datasets/create/functions/sources/xarray/field.py +109 -0
  22. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +171 -0
  23. anemoi/datasets/create/functions/sources/xarray/flavour.py +330 -0
  24. anemoi/datasets/create/functions/sources/xarray/grid.py +46 -0
  25. anemoi/datasets/create/functions/sources/xarray/metadata.py +161 -0
  26. anemoi/datasets/create/functions/sources/xarray/time.py +98 -0
  27. anemoi/datasets/create/functions/sources/xarray/variable.py +198 -0
  28. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +42 -0
  29. anemoi/datasets/create/functions/sources/xarray_zarr.py +15 -0
  30. anemoi/datasets/create/functions/sources/zenodo.py +40 -0
  31. anemoi/datasets/create/input.py +290 -172
  32. anemoi/datasets/create/loaders.py +120 -71
  33. anemoi/datasets/create/patch.py +17 -14
  34. anemoi/datasets/create/persistent.py +1 -1
  35. anemoi/datasets/create/size.py +4 -5
  36. anemoi/datasets/create/statistics/__init__.py +49 -16
  37. anemoi/datasets/create/template.py +11 -61
  38. anemoi/datasets/create/trace.py +91 -0
  39. anemoi/datasets/create/utils.py +0 -48
  40. anemoi/datasets/create/zarr.py +24 -10
  41. anemoi/datasets/data/misc.py +9 -37
  42. anemoi/datasets/data/stores.py +29 -14
  43. anemoi/datasets/dates/__init__.py +7 -1
  44. anemoi/datasets/dates/groups.py +3 -0
  45. {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/METADATA +18 -3
  46. anemoi_datasets-0.4.2.dist-info/RECORD +86 -0
  47. {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/WHEEL +1 -1
  48. anemoi_datasets-0.4.0.dist-info/RECORD +0 -73
  49. {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/LICENSE +0 -0
  50. {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/entry_points.txt +0 -0
  51. {anemoi_datasets-0.4.0.dist-info → anemoi_datasets-0.4.2.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/statistics/__init__.py

@@ -89,20 +89,23 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
             continue
         print("---")
         print(f"❗ Negative variance for {name=}, variance={y}")
-        print(f" max={maximum[i]} min={minimum[i]} mean={mean[i]} count={count[i]} sum={sums[i]} square={squares[i]}")
+        print(f" min={minimum[i]} max={maximum[i]} mean={mean[i]} count={count[i]} sums={sums[i]} squares={squares[i]}")
         print(f" -> sums: min={np.min(sums[i])}, max={np.max(sums[i])}, argmin={np.argmin(sums[i])}")
         print(f" -> squares: min={np.min(squares[i])}, max={np.max(squares[i])}, argmin={np.argmin(squares[i])}")
         print(f" -> count: min={np.min(count[i])}, max={np.max(count[i])}, argmin={np.argmin(count[i])}")
+        print(
+            f" squares / count - mean * mean = {squares[i] / count[i]} - {mean[i] * mean[i]} = {squares[i] / count[i] - mean[i] * mean[i]}"
+        )

     raise ValueError("Negative variance")


-def compute_statistics(array, check_variables_names=None, allow_nan=False):
+def compute_statistics(array, check_variables_names=None, allow_nans=False):
     """Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary."""

     nvars = array.shape[1]

-    LOG.info(f"Stats {nvars}, {array.shape}, {check_variables_names}")
+    LOG.debug(f"Stats {nvars}, {array.shape}, {check_variables_names}")
     if check_variables_names:
         assert nvars == len(check_variables_names), (nvars, check_variables_names)
     stats_shape = (array.shape[0], nvars)

@@ -118,7 +121,7 @@ def compute_statistics(array, check_variables_names=None, allow_nan=False):
         values = chunk.reshape((nvars, -1))

         for j, name in enumerate(check_variables_names):
-            check_data_values(values[j, :], name=name, allow_nan=allow_nan)
+            check_data_values(values[j, :], name=name, allow_nans=allow_nans)
             if np.isnan(values[j, :]).all():
                 # LOG.warning(f"All NaN values for {name} ({j}) for date {i}")
                 raise ValueError(f"All NaN values for {name} ({j}) for date {i}")

@@ -179,12 +182,12 @@ class TmpStatistics:
             pickle.dump((key, dates, data), f)
         shutil.move(tmp_path, path)

-        LOG.info(f"Written statistics data for {len(dates)} dates in {path} ({dates})")
+        LOG.debug(f"Written statistics data for {len(dates)} dates in {path} ({dates})")

     def _gather_data(self):
         # use glob to read all pickles
         files = glob.glob(self.dirname + "/*.npz")
-        LOG.info(f"Reading stats data, found {len(files)} files in {self.dirname}")
+        LOG.debug(f"Reading stats data, found {len(files)} files in {self.dirname}")
         assert len(files) > 0, f"No files found in {self.dirname}"
         for f in files:
             with open(f, "rb") as f:

@@ -211,17 +214,17 @@ def normalise_dates(dates):
 class StatAggregator:
     NAMES = ["minimum", "maximum", "sums", "squares", "count", "has_nans"]

-    def __init__(self, owner, dates, variables_names, allow_nan):
+    def __init__(self, owner, dates, variables_names, allow_nans):
         dates = sorted(dates)
         dates = to_datetimes(dates)
         assert dates, "No dates selected"
         self.owner = owner
         self.dates = dates
         self.variables_names = variables_names
-        self.allow_nan = allow_nan
+        self.allow_nans = allow_nans

         self.shape = (len(self.dates), len(self.variables_names))
-        LOG.info(f"Aggregating statistics on shape={self.shape}. Variables : {self.variables_names}")
+        LOG.debug(f"Aggregating statistics on shape={self.shape}. Variables : {self.variables_names}")

         self.minimum = np.full(self.shape, np.nan, dtype=np.float64)
         self.maximum = np.full(self.shape, np.nan, dtype=np.float64)

@@ -284,7 +287,7 @@ class StatAggregator:
             assert d in found, f"Statistics for date {d} not precomputed."
         assert len(self.dates) == len(found), "Not all dates found in precomputed statistics"
         assert len(self.dates) == offset, "Not all dates found in precomputed statistics."
-        LOG.info(f"Statistics for {len(found)} dates found.")
+        LOG.debug(f"Statistics for {len(found)} dates found.")

     def aggregate(self):
         minimum = np.nanmin(self.minimum, axis=0)

@@ -298,13 +301,43 @@ class StatAggregator:
         assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape

         x = squares / count - mean * mean
-        # remove negative variance due to numerical errors
-        # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0
-        check_variance(x, self.variables_names, minimum, maximum, mean, count, sums, squares)
-        stdev = np.sqrt(x)

-        for j, name in enumerate(self.variables_names):
-            check_data_values(np.array([mean[j]]), name=name, allow_nan=False)
+        # def fix_variance(x, name, minimum, maximum, mean, count, sums, squares):
+        #     assert x.shape == minimum.shape == maximum.shape == mean.shape == count.shape == sums.shape == squares.shape
+        #     assert x.shape == (1,)
+        #     x, minimum, maximum, mean, count, sums, squares = x[0], minimum[0], maximum[0], mean[0], count[0], sums[0], squares[0]
+        #     if x >= 0:
+        #         return x
+        #
+        #     order = np.sqrt((squares / count + mean * mean) / 2)
+        #     range = maximum - minimum
+        #     LOG.warning(f"Negative variance for {name=}, variance={x}")
+        #     LOG.warning(f"square / count - mean * mean = {squares / count} - {mean * mean} = {squares / count - mean * mean}")
+        #     LOG.warning(f"Variable order of magnitude is {order}.")
+        #     LOG.warning(f"Range is {range} ({maximum=} - {minimum=}).")
+        #     LOG.warning(f"Count is {count}.")
+        #     if abs(x) < order * 1e-6 and abs(x) < range * 1e-6:
+        #         LOG.warning(f"Variance is negative but very small, setting to 0.")
+        #         return x * 0
+        #     return x
+
+        for i, name in enumerate(self.variables_names):
+            # remove negative variance due to numerical errors
+            # Not needed for now, fix_variance is disabled
+            # x[i] = fix_variance(x[i:i+1], name, minimum[i:i+1], maximum[i:i+1], mean[i:i+1], count[i:i+1], sums[i:i+1], squares[i:i+1])
+            check_variance(
+                x[i : i + 1],
+                [name],
+                minimum[i : i + 1],
+                maximum[i : i + 1],
+                mean[i : i + 1],
+                count[i : i + 1],
+                sums[i : i + 1],
+                squares[i : i + 1],
+            )
+            check_data_values(np.array([mean[i]]), name=name, allow_nans=False)
+
+        stdev = np.sqrt(x)

         return Summary(
             minimum=minimum,
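
A note on the per-variable variance check above: `squares / count - mean * mean` is the one-pass variance formula, and it can come out slightly negative through floating-point cancellation when a variable's spread is tiny compared to its magnitude. A minimal sketch of the failure mode (illustrative numbers, not package code):

    import numpy as np

    # Values near 1e8 with a spread of ~1e-4: squares/count and mean*mean
    # agree to ~15 significant digits, so their float64 difference is
    # dominated by rounding error and can dip below zero.
    values = np.full(1000, 1e8) + np.random.default_rng(0).normal(0.0, 1e-4, 1000)
    count = values.size
    sums = values.sum()
    squares = (values * values).sum()
    mean = sums / count

    print(squares / count - mean * mean)   # may be negative; true variance is ~1e-8
    print(((values - mean) ** 2).mean())   # the two-pass formula stays non-negative

This is why check_variance prints the full min/max/mean/count/sums/squares breakdown before raising, and why the (disabled) fix_variance sketch compares the negative value against the variable's order of magnitude.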
anemoi/datasets/create/template.py

@@ -8,72 +8,16 @@
 #

 import logging
-import os
 import re
 import textwrap
 from functools import wraps

-LOG = logging.getLogger(__name__)
-
-TRACE_INDENT = 0
-
-
-def step(action_path):
-    return f"[{'.'.join(action_path)}]"
+from anemoi.utils.humanize import plural

+from .trace import step
+from .trace import trace

-def trace(emoji, *args):
-    if os.environ.get("ANEMOI_DATASET_TRACE_CREATE") is None:
-        return
-    print(emoji, " " * TRACE_INDENT, *args)
-
-
-def trace_datasource(method):
-    @wraps(method)
-    def wrapper(self, *args, **kwargs):
-        global TRACE_INDENT
-        trace(
-            "🌍",
-            "=>",
-            step(self.action_path),
-            self._trace_datasource(*args, **kwargs),
-        )
-        TRACE_INDENT += 1
-        result = method(self, *args, **kwargs)
-        TRACE_INDENT -= 1
-        trace(
-            "🍎",
-            "<=",
-            step(self.action_path),
-            textwrap.shorten(repr(result), 256),
-        )
-        return result
-
-    return wrapper
-
-
-def trace_select(method):
-    @wraps(method)
-    def wrapper(self, *args, **kwargs):
-        global TRACE_INDENT
-        trace(
-            "👓",
-            "=>",
-            ".".join(self.action_path),
-            self._trace_select(*args, **kwargs),
-        )
-        TRACE_INDENT += 1
-        result = method(self, *args, **kwargs)
-        TRACE_INDENT -= 1
-        trace(
-            "🍍",
-            "<=",
-            ".".join(self.action_path),
-            textwrap.shorten(repr(result), 256),
-        )
-        return result
-
-    return wrapper
+LOG = logging.getLogger(__name__)


 def notify_result(method):

@@ -99,7 +43,13 @@ class Context:
         self.used_references.add(key)

     def notify_result(self, key, result):
-        trace("🎯", step(key), "notify result", result)
+        trace(
+            "🎯",
+            step(key),
+            "notify result",
+            textwrap.shorten(repr(result).replace(",", ", "), width=40),
+            plural(len(result), "field"),
+        )
         assert isinstance(key, (list, tuple)), key
         key = tuple(key)
         if key in self.used_references:
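
The reworked notify_result trace keeps log lines bounded: the result's repr is clipped to 40 characters and a human-readable field count is appended. A small sketch of the formatting (the result list is a stand-in for a list of fields, and the exact rendering of plural is assumed from its name and usage here):

    import textwrap

    from anemoi.utils.humanize import plural  # as imported by template.py

    result = list(range(120))
    print(textwrap.shorten(repr(result).replace(",", ", "), width=40))
    # something like: [0, 1, 2, 3, 4, 5, 6, 7, 8, [...]
    print(plural(len(result), "field"))
    # expected to render as: 120 fields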
anemoi/datasets/create/trace.py (new file)

@@ -0,0 +1,91 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+import logging
+import textwrap
+import threading
+from functools import wraps
+
+LOG = logging.getLogger(__name__)
+
+
+thread_local = threading.local()
+TRACE = 0
+
+
+def enable_trace(on_off):
+    global TRACE
+    TRACE = on_off
+
+
+def step(action_path):
+    return f"[{'.'.join(action_path)}]"
+
+
+def trace(emoji, *args):
+
+    if not TRACE:
+        return
+
+    if not hasattr(thread_local, "TRACE_INDENT"):
+        thread_local.TRACE_INDENT = 0
+
+    print(emoji, " " * thread_local.TRACE_INDENT, *args)
+
+
+def trace_datasource(method):
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+
+        if not hasattr(thread_local, "TRACE_INDENT"):
+            thread_local.TRACE_INDENT = 0
+
+        trace(
+            "🌍",
+            "=>",
+            step(self.action_path),
+            self._trace_datasource(*args, **kwargs),
+        )
+        thread_local.TRACE_INDENT += 1
+        result = method(self, *args, **kwargs)
+        thread_local.TRACE_INDENT -= 1
+        trace(
+            "🍎",
+            "<=",
+            step(self.action_path),
+            textwrap.shorten(repr(result), 256),
+        )
+        return result
+
+    return wrapper
+
+
+def trace_select(method):
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        if not hasattr(thread_local, "TRACE_INDENT"):
+            thread_local.TRACE_INDENT = 0
+        trace(
+            "👓",
+            "=>",
+            ".".join(self.action_path),
+            self._trace_select(*args, **kwargs),
+        )
+        thread_local.TRACE_INDENT += 1
+        result = method(self, *args, **kwargs)
+        thread_local.TRACE_INDENT -= 1
+        trace(
+            "🍍",
+            "<=",
+            ".".join(self.action_path),
+            textwrap.shorten(repr(result), 256),
+        )
+        return result
+
+    return wrapper
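
The new module replaces the old module-global TRACE_INDENT with a thread-local counter, so loaders running in several threads each keep a coherent indentation tree, and tracing is switched on with enable_trace() instead of an environment variable. A usage sketch (the Demo class, its action_path and return values are invented for illustration):

    from anemoi.datasets.create.trace import enable_trace, trace_datasource

    class Demo:
        # trace_datasource expects the object to carry an action_path
        # and a _trace_datasource() method describing the call.
        action_path = ["input", "mars"]

        def _trace_datasource(self, *args, **kwargs):
            return "datasource()"

        @trace_datasource
        def datasource(self):
            return ["2t", "msl"]

    enable_trace(1)
    Demo().datasource()
    # prints something like:
    # 🌍  => [input.mars] datasource()
    # 🍎  <= [input.mars] ['2t', 'msl']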
anemoi/datasets/create/utils.py

@@ -7,15 +7,11 @@
 # nor does it submit to any jurisdiction.
 #

-import json
 import os
 from contextlib import contextmanager

 import numpy as np
-import yaml
 from earthkit.data import settings
-from earthkit.data.utils.humanize import seconds  # noqa: F401
-from tqdm.auto import tqdm


 def cache_context(dirname):

@@ -31,26 +27,6 @@ def cache_context(dirname):
     return settings.temporary({"cache-policy": "user", "user-cache-directory": dirname})


-def bytes(n):
-    """>>> bytes(4096)
-    '4 KiB'
-    >>> bytes(4000)
-    '3.9 KiB'
-    """
-    if n < 0:
-        sign = "-"
-        n -= 0
-    else:
-        sign = ""
-
-    u = ["", " KiB", " MiB", " GiB", " TiB", " PiB", " EiB", " ZiB", " YiB"]
-    i = 0
-    while n >= 1024:
-        n /= 1024.0
-        i += 1
-    return "%s%g%s" % (sign, int(n * 10 + 0.5) / 10.0, u[i])
-
-
 def to_datetime_list(*args, **kwargs):
     from earthkit.data.utils.dates import to_datetime_list as to_datetime_list_

@@ -63,15 +39,6 @@ def to_datetime(*args, **kwargs):
     return to_datetime_(*args, **kwargs)


-def load_json_or_yaml(path):
-    with open(path, "r") as f:
-        if path.endswith(".json"):
-            return json.load(f)
-        if path.endswith(".yaml") or path.endswith(".yml"):
-            return yaml.safe_load(f)
-        raise ValueError(f"Cannot read file {path}. Need json or yaml with appropriate extension.")
-
-
 def make_list_int(value):
     if isinstance(value, str):
         if "/" not in value:

@@ -118,18 +85,3 @@ def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"):
         assert d1 == d2, (i, d1, d2)

     return dates_
-
-
-def progress_bar(*, iterable=None, total=None, initial=0, desc=None):
-    return tqdm(
-        iterable=iterable,
-        total=total,
-        initial=initial,
-        unit_scale=True,
-        unit_divisor=1024,
-        unit="B",
-        disable=False,
-        leave=False,
-        desc=desc,
-        # dynamic_ncols=True, # make this the default?
-    )
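
utils.py now keeps only the earthkit-backed helpers; the local bytes() and progress_bar() functions were dropped (anemoi-utils presumably provides equivalents). A hedged usage sketch of the remaining cache_context helper, assuming settings.temporary behaves as an earthkit.data context manager (the directory name is a placeholder):

    from anemoi.datasets.create.utils import cache_context

    # Point earthkit.data's cache at a job-specific directory for the
    # duration of the block, then restore the previous settings.
    with cache_context("/tmp/anemoi-cache"):
        ...  # downloads issued here are cached under /tmp/anemoi-cache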
anemoi/datasets/create/zarr.py

@@ -24,8 +24,12 @@ def add_zarr_dataset(
     shape=None,
     array=None,
     overwrite=True,
+    dimensions=None,
     **kwargs,
 ):
+    assert dimensions is not None, "Please pass dimensions to add_zarr_dataset."
+    assert isinstance(dimensions, (tuple, list))
+
     if dtype is None:
         assert array is not None, (name, shape, array, dtype, zarr_root)
         dtype = array.dtype

@@ -44,6 +48,7 @@ def add_zarr_dataset(
             **kwargs,
         )
         a[...] = array
+        a.attrs["_ARRAY_DIMENSIONS"] = dimensions
         return a

     if "fill_value" not in kwargs:

@@ -69,6 +74,7 @@ def add_zarr_dataset(
        overwrite=overwrite,
        **kwargs,
    )
+    a.attrs["_ARRAY_DIMENSIONS"] = dimensions
    return a


@@ -79,22 +85,27 @@ class ZarrBuiltRegistry:
     flags = None
     z = None

-    def __init__(self, path, synchronizer_path=None):
+    def __init__(self, path, synchronizer_path=None, use_threads=False):
         import zarr

         assert isinstance(path, str), path
         self.zarr_path = path

-        if synchronizer_path is None:
-            synchronizer_path = self.zarr_path + ".sync"
-        self.synchronizer_path = synchronizer_path
-        self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path)
+        if use_threads:
+            self.synchronizer = zarr.ThreadSynchronizer()
+            self.synchronizer_path = None
+        else:
+            if synchronizer_path is None:
+                synchronizer_path = self.zarr_path + ".sync"
+            self.synchronizer_path = synchronizer_path
+            self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path)

     def clean(self):
-        try:
-            shutil.rmtree(self.synchronizer_path)
-        except FileNotFoundError:
-            pass
+        if self.synchronizer_path is not None:
+            try:
+                shutil.rmtree(self.synchronizer_path)
+            except FileNotFoundError:
+                pass

     def _open_write(self):
         import zarr

@@ -112,7 +123,7 @@ class ZarrBuiltRegistry:
     def new_dataset(self, *args, **kwargs):
         z = self._open_write()
         zarr_root = z["_build"]
-        add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, **kwargs)
+        add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs)

     def add_to_history(self, action, **kwargs):
         new = dict(

@@ -143,6 +154,9 @@ class ZarrBuiltRegistry:
         z.attrs["latest_write_timestamp"] = datetime.datetime.utcnow().isoformat()
         z["_build"][self.name_flags][i] = value

+    def ready(self):
+        return all(self.get_flags())
+
     def create(self, lengths, overwrite=False):
         self.new_dataset(name=self.name_lengths, array=np.array(lengths, dtype="i4"))
         self.new_dataset(name=self.name_flags, array=np.array([False] * len(lengths), dtype=bool))
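
_ARRAY_DIMENSIONS is the attribute the Zarr v2/xarray convention uses to attach dimension names to arrays, so stamping it on every array makes the resulting store readable by xarray. A minimal sketch ("output.zarr" is a placeholder path):

    import xarray as xr

    # With _ARRAY_DIMENSIONS present on each array, xarray can assemble
    # the store into a Dataset with named dimensions.
    ds = xr.open_zarr("output.zarr")
    print(ds.sizes)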
anemoi/datasets/data/misc.py

@@ -8,63 +8,35 @@
 import calendar
 import datetime
 import logging
-import os
 import re
 from pathlib import PurePath

 import numpy as np
 import zarr
+from anemoi.utils.config import load_config as load_settings

 from .dataset import Dataset

 LOG = logging.getLogger(__name__)

-CONFIG = None

-try:
-    import tomllib  # Only available since 3.11
-except ImportError:
-    import tomli as tomllib
+def load_config():
+    return load_settings(defaults={"datasets": {"named": {}, "path": []}})


 def add_named_dataset(name, path, **kwargs):
-    load_config()
-    if name in CONFIG["datasets"]["named"]:
+    config = load_config()
+    if name in config["datasets"]["named"]:
         raise ValueError(f"Dataset {name} already exists")

-    CONFIG["datasets"]["named"][name] = path
+    config["datasets"]["named"][name] = path


 def add_dataset_path(path):
-    load_config()
-
-    if path not in CONFIG["datasets"]["path"]:
-        CONFIG["datasets"]["path"].append(path)
-
-    # save_config()
-
-
-def load_config():
-    global CONFIG
-    if CONFIG is not None:
-        return CONFIG
-
-    conf = os.path.expanduser("~/.config/anemoi/settings.toml")
-    if not os.path.exists(conf):
-        conf = os.path.expanduser("~/.anemoi.toml")
-
-    if os.path.exists(conf):
-
-        with open(conf, "rb") as f:
-            CONFIG = tomllib.load(f)
-    else:
-        CONFIG = {}
-
-    CONFIG.setdefault("datasets", {})
-    CONFIG["datasets"].setdefault("path", [])
-    CONFIG["datasets"].setdefault("named", {})
+    config = load_config()

-    return CONFIG
+    if path not in config["datasets"]["path"]:
+        config["datasets"]["path"].append(path)


 def _frequency_to_hours(frequency):
anemoi/datasets/data/stores.py

@@ -9,6 +9,7 @@ import logging
 import os
 import warnings
 from functools import cached_property
+from urllib.parse import urlparse

 import numpy as np
 import zarr

@@ -40,7 +41,9 @@ class ReadOnlyStore(zarr.storage.BaseStore):


 class HTTPStore(ReadOnlyStore):
-    """We write our own HTTPStore because the one used by zarr (fsspec) does not play well with fork() and multiprocessing."""
+    """We write our own HTTPStore because the one used by zarr (s3fs)
+    does not play well with fork() and multiprocessing.
+    """

     def __init__(self, url):
         self.url = url

@@ -58,17 +61,16 @@ class HTTPStore(ReadOnlyStore):


 class S3Store(ReadOnlyStore):
-    """We write our own S3Store because the one used by zarr (fsspec)
-    does not play well with fork() and multiprocessing. Also, we get
-    to control the s3 client.
+    """We write our own S3Store because the one used by zarr (s3fs)
+    does not play well with fork(). We also get to control the s3 client
+    options using the anemoi configs.
     """

-    def __init__(self, url):
+    def __init__(self, url, region=None):
         from anemoi.utils.s3 import s3_client

         _, _, self.bucket, self.key = url.split("/", 3)
-
-        self.s3 = s3_client(self.bucket)
+        self.s3 = s3_client(self.bucket, region=region)

     def __getitem__(self, key):
         try:

@@ -101,15 +103,27 @@ class DebugStore(ReadOnlyStore):
         return key in self.store


-def open_zarr(path, dont_fail=False, cache=None):
-    try:
-        store = path
+def name_to_zarr_store(path_or_url):
+    store = path_or_url
+
+    if store.startswith("s3://"):
+        store = S3Store(store)

-        if store.startswith("http://") or store.startswith("https://"):
+    elif store.startswith("http://") or store.startswith("https://"):
+        parsed = urlparse(store)
+        bits = parsed.netloc.split(".")
+        if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"):
+            s3_url = f"s3://{bits[0]}{parsed.path}"
+            store = S3Store(s3_url, region=bits[2])
+        else:
             store = HTTPStore(store)

-        elif store.startswith("s3://"):
-            store = S3Store(store)
+    return store
+
+
+def open_zarr(path, dont_fail=False, cache=None):
+    try:
+        store = name_to_zarr_store(path)

         if DEBUG_ZARR_LOADING:
             if isinstance(store, str):

@@ -117,7 +131,8 @@ def open_zarr(path, dont_fail=False, cache=None):

             if not os.path.isdir(store):
                 raise NotImplementedError(
-                    "DEBUG_ZARR_LOADING is only implemented for DirectoryStore. Please disable it for other backends."
+                    "DEBUG_ZARR_LOADING is only implemented for DirectoryStore. "
+                    "Please disable it for other backends."
                 )
             store = zarr.storage.DirectoryStore(store)
             store = DebugStore(store)
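
name_to_zarr_store now also rewrites virtual-hosted-style S3 URLs (bucket.s3.region.amazonaws.com) into s3:// form so they go through the fork-safe S3Store with the right region instead of plain HTTP. Tracing the parsing logic on an invented URL:

    from urllib.parse import urlparse

    url = "https://my-bucket.s3.eu-west-2.amazonaws.com/datasets/example.zarr"
    parsed = urlparse(url)
    bits = parsed.netloc.split(".")
    # bits == ['my-bucket', 's3', 'eu-west-2', 'amazonaws', 'com']
    assert len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com")
    print(f"s3://{bits[0]}{parsed.path}")  # s3://my-bucket/datasets/example.zarr
    print(bits[2])                         # eu-west-2, passed as region to S3Store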
anemoi/datasets/dates/__init__.py

@@ -96,7 +96,7 @@ class ValuesDates(Dates):


 class StartEndDates(Dates):
-    def __init__(self, start, end, frequency=1, **kwargs):
+    def __init__(self, start, end, frequency=1, months=None, **kwargs):
         frequency = frequency_to_hours(frequency)

         def _(x):

@@ -128,6 +128,12 @@ class StartEndDates(Dates):
         date = start
         self.values = []
         while date <= end:
+
+            if months is not None:
+                if date.month not in months:
+                    date += increment
+                    continue
+
             self.values.append(date)
             date += increment

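The new months argument lets StartEndDates keep its regular stepping while discarding any date whose calendar month is not listed. A standalone sketch of the loop's effect (dates and months invented for illustration):

    import datetime

    start = datetime.datetime(2020, 1, 1)
    end = datetime.datetime(2020, 12, 31)
    increment = datetime.timedelta(hours=24)
    months = [6, 7, 8]  # keep June, July and August only

    values = []
    date = start
    while date <= end:
        if months is not None and date.month not in months:
            date += increment
            continue
        values.append(date)
        date += increment

    print(len(values))  # 92 daily dates: 30 (Jun) + 31 (Jul) + 31 (Aug)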
anemoi/datasets/dates/groups.py

@@ -61,6 +61,9 @@ class Groups:
             count += 1
         return count

+    def __repr__(self):
+        return f"{self.__class__.__name__}(dates={len(self)})"
+

 class Filter:
     def __init__(self, missing):