anemoi-datasets 0.4.5__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/create.py +3 -2
  3. anemoi/datasets/create/__init__.py +30 -32
  4. anemoi/datasets/create/config.py +4 -3
  5. anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +57 -0
  6. anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +57 -0
  7. anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +54 -0
  8. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +59 -0
  9. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +115 -0
  10. anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +390 -0
  11. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +77 -0
  12. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +55 -0
  13. anemoi/datasets/create/functions/sources/grib.py +86 -1
  14. anemoi/datasets/create/functions/sources/hindcasts.py +14 -73
  15. anemoi/datasets/create/functions/sources/mars.py +9 -3
  16. anemoi/datasets/create/functions/sources/xarray/field.py +7 -1
  17. anemoi/datasets/create/functions/sources/xarray/metadata.py +13 -11
  18. anemoi/datasets/create/input.py +39 -17
  19. anemoi/datasets/create/persistent.py +1 -1
  20. anemoi/datasets/create/utils.py +3 -0
  21. anemoi/datasets/data/dataset.py +11 -1
  22. anemoi/datasets/data/debug.py +5 -1
  23. anemoi/datasets/data/masked.py +2 -2
  24. anemoi/datasets/data/rescale.py +147 -0
  25. anemoi/datasets/data/stores.py +20 -7
  26. anemoi/datasets/dates/__init__.py +112 -30
  27. anemoi/datasets/dates/groups.py +84 -19
  28. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/METADATA +10 -19
  29. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/RECORD +33 -24
  30. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/WHEEL +1 -1
  31. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/LICENSE +0 -0
  32. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/entry_points.txt +0 -0
  33. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/top_level.txt +0 -0
@@ -6,7 +6,6 @@
6
6
  # granted to it by virtue of its status as an intergovernmental organisation
7
7
  # nor does it submit to any jurisdiction.
8
8
  #
9
- import datetime
10
9
  import logging
11
10
 
12
11
  from earthkit.data.core.fieldlist import MultiFieldList
@@ -14,7 +13,6 @@ from earthkit.data.core.fieldlist import MultiFieldList
14
13
  from anemoi.datasets.create.functions.sources.mars import mars
15
14
 
16
15
  LOGGER = logging.getLogger(__name__)
17
- DEBUG = True
18
16
 
19
17
 
20
18
  def _to_list(x):
@@ -23,91 +21,34 @@ def _to_list(x):
23
21
  return [x]
24
22
 
25
23
 
26
- class HindcastCompute:
27
- def __init__(self, base_times, available_steps, request):
28
- self.base_times = base_times
29
- self.available_steps = available_steps
30
- self.request = request
31
-
32
- def compute_hindcast(self, date):
33
- result = []
34
- for step in sorted(self.available_steps): # Use the shortest step
35
- start_date = date - datetime.timedelta(hours=step)
36
- hours = start_date.hour
37
- if hours in self.base_times:
38
- r = self.request.copy()
39
- r["date"] = start_date
40
- r["time"] = f"{start_date.hour:02d}00"
41
- r["step"] = step
42
- result.append(r)
43
-
44
- if not result:
45
- raise ValueError(
46
- f"Cannot find data for {self.request} for {date} (base_times={self.base_times}, "
47
- f"available_steps={self.available_steps})"
48
- )
49
-
50
- if len(result) > 1:
51
- raise ValueError(
52
- f"Multiple requests for {self.request} for {date} (base_times={self.base_times}, "
53
- f"available_steps={self.available_steps})"
54
- )
55
-
56
- return result[0]
57
-
58
-
59
- def use_reference_year(reference_year, request):
60
- request = request.copy()
61
- hdate = request.pop("date")
62
-
63
- if hdate.year >= reference_year:
64
- return None, False
24
+ def hindcasts(context, dates, **request):
65
25
 
66
- try:
67
- date = datetime.datetime(reference_year, hdate.month, hdate.day)
68
- except ValueError:
69
- if hdate.month == 2 and hdate.day == 29:
70
- return None, False
71
- raise
26
+ from anemoi.datasets.dates import HindcastsDates
72
27
 
73
- request.update(date=date.strftime("%Y-%m-%d"), hdate=hdate.strftime("%Y-%m-%d"))
74
- return request, True
28
+ provider = context.dates_provider
29
+ assert isinstance(provider, HindcastsDates)
75
30
 
31
+ context.trace("H️", f"hindcasts {len(dates)=}")
76
32
 
77
- def hindcasts(context, dates, **request):
78
33
  request["param"] = _to_list(request["param"])
79
- request["step"] = _to_list(request["step"])
34
+ request["step"] = _to_list(request.get("step", 0))
80
35
  request["step"] = [int(_) for _ in request["step"]]
81
36
 
82
- if request.get("stream") == "enfh" and "base_times" not in request:
83
- request["base_times"] = [0]
84
-
85
- available_steps = request.pop("step")
86
- available_steps = _to_list(available_steps)
87
-
88
- base_times = request.pop("base_times")
89
-
90
- reference_year = request.pop("reference_year")
37
+ context.trace("H️", f"hindcast {request}")
91
38
 
92
- context.trace("H️", f"hindcast {request} {base_times} {available_steps} {reference_year}")
93
-
94
- c = HindcastCompute(base_times, available_steps, request)
95
39
  requests = []
96
40
  for d in dates:
97
- req = c.compute_hindcast(d)
98
- req, ok = use_reference_year(reference_year, req)
99
- if ok:
100
- requests.append(req)
101
-
102
- # print("HINDCASTS requests", reference_year, base_times, available_steps)
103
- # print("HINDCASTS dates", compress_dates(dates))
41
+ r = request.copy()
42
+ hindcast = provider.mapping[d]
43
+ r["hdate"] = hindcast.hdate.strftime("%Y-%m-%d")
44
+ r["date"] = hindcast.refdate.strftime("%Y-%m-%d")
45
+ r["time"] = hindcast.refdate.strftime("%H")
46
+ r["step"] = hindcast.step
47
+ requests.append(r)
104
48
 
105
49
  if len(requests) == 0:
106
- # print("HINDCASTS no requests")
107
50
  return MultiFieldList([])
108
51
 
109
- # print("HINDCASTS requests", requests)
110
-
111
52
  return mars(
112
53
  context,
113
54
  dates,
@@ -203,16 +203,22 @@ def mars(context, dates, *requests, request_already_using_valid_datetime=False,
203
203
  request_already_using_valid_datetime=request_already_using_valid_datetime,
204
204
  date_key=date_key,
205
205
  )
206
+
207
+ requests = list(requests)
208
+
206
209
  ds = from_source("empty")
210
+ context.trace("✅", f"{[str(d) for d in dates]}")
211
+ context.trace("✅", f"Will run {len(requests)} requests")
212
+ for r in requests:
213
+ r = {k: v for k, v in r.items() if v != ("-",)}
214
+ context.trace("✅", f"mars {r}")
215
+
207
216
  for r in requests:
208
217
  r = {k: v for k, v in r.items() if v != ("-",)}
209
218
 
210
219
  if context.use_grib_paramid and "param" in r:
211
220
  r = use_grib_paramid(r)
212
221
 
213
- if DEBUG:
214
- context.trace("✅", f"from_source(mars, {r}")
215
-
216
222
  for k, v in r.items():
217
223
  if k not in MARS_KEYS:
218
224
  raise ValueError(
@@ -7,6 +7,7 @@
7
7
  # nor does it submit to any jurisdiction.
8
8
  #
9
9
 
10
+ import datetime
10
11
  import logging
11
12
 
12
13
  from earthkit.data.core.fieldlist import Field
@@ -103,7 +104,12 @@ class XArrayField(Field):
103
104
 
104
105
  @property
105
106
  def forecast_reference_time(self):
106
- return self.owner.forecast_reference_time
107
+ date, time = self.metadata("date", "time")
108
+ assert len(time) == 4, time
109
+ assert len(date) == 8, date
110
+ yyyymmdd = int(date)
111
+ time = int(time) // 100
112
+ return datetime.datetime(yyyymmdd // 10000, yyyymmdd // 100 % 100, yyyymmdd % 100, time)
107
113
 
108
114
  def __repr__(self):
109
115
  return repr(self._metadata)
@@ -70,15 +70,17 @@ class XArrayMetadata(RawMetadata):
70
70
  return self._as_mars()
71
71
 
72
72
  def _as_mars(self):
73
- return dict(
74
- param=self["variable"],
75
- step=self["step"],
76
- levelist=self["level"],
77
- levtype=self["levtype"],
78
- number=self["number"],
79
- date=self["date"],
80
- time=self["time"],
81
- )
73
+ return {}
74
+ # p = dict(
75
+ # param=self.get("variable", self.get("param")),
76
+ # step=self.get("step"),
77
+ # levelist=self.get("levelist", self.get("level")),
78
+ # levtype=self.get("levtype"),
79
+ # number=self.get("number"),
80
+ # date=self.get("date"),
81
+ # time=self.get("time"),
82
+ # )
83
+ # return {k: v for k, v in p.items() if v is not None}
82
84
 
83
85
  def _base_datetime(self):
84
86
  return self._field.forecast_reference_time
@@ -135,12 +137,12 @@ class XArrayFieldGeography(Geography):
135
137
  # TODO: implement resolution
136
138
  return None
137
139
 
138
- @property
140
+ # @property
139
141
  def mars_grid(self):
140
142
  # TODO: implement mars_grid
141
143
  return None
142
144
 
143
- @property
145
+ # @property
144
146
  def mars_area(self):
145
147
  # TODO: code me
146
148
  # return [self.north, self.west, self.south, self.east]
@@ -23,7 +23,7 @@ from earthkit.data.core.fieldlist import FieldList
23
23
  from earthkit.data.core.fieldlist import MultiFieldList
24
24
  from earthkit.data.core.order import build_remapping
25
25
 
26
- from anemoi.datasets.dates import Dates
26
+ from anemoi.datasets.dates import DatesProvider
27
27
 
28
28
  from .functions import import_function
29
29
  from .template import Context
@@ -75,7 +75,7 @@ def time_delta_to_string(delta):
75
75
 
76
76
 
77
77
  def is_function(name, kind):
78
- name, delta = parse_function_name(name) # noqa
78
+ name, _ = parse_function_name(name)
79
79
  try:
80
80
  import_function(name, kind)
81
81
  return True
@@ -204,11 +204,15 @@ class Result:
204
204
  _coords_already_built = False
205
205
 
206
206
  def __init__(self, context, action_path, dates):
207
+ from anemoi.datasets.dates.groups import GroupOfDates
208
+
209
+ assert isinstance(dates, GroupOfDates), dates
210
+
207
211
  assert isinstance(context, ActionContext), type(context)
208
212
  assert isinstance(action_path, list), action_path
209
213
 
210
214
  self.context = context
211
- self.dates = dates
215
+ self.group_of_dates = dates
212
216
  self.action_path = action_path
213
217
 
214
218
  @property
@@ -405,10 +409,10 @@ class Result:
405
409
  more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])
406
410
 
407
411
  dates = " no-dates"
408
- if self.dates is not None:
409
- dates = f" {len(self.dates)} dates"
412
+ if self.group_of_dates is not None:
413
+ dates = f" {len(self.group_of_dates)} dates"
410
414
  dates += " ("
411
- dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.dates)
415
+ dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.group_of_dates)
412
416
  if len(dates) > 100:
413
417
  dates = dates[:100] + "..."
414
418
  dates += ")"
@@ -423,7 +427,7 @@ class Result:
423
427
  raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
424
428
 
425
429
  def _trace_datasource(self, *args, **kwargs):
426
- return f"{self.__class__.__name__}({shorten(self.dates)})"
430
+ return f"{self.__class__.__name__}({self.group_of_dates})"
427
431
 
428
432
  def build_coords(self):
429
433
  if self._coords_already_built:
@@ -513,7 +517,7 @@ class Result:
513
517
  @cached_property
514
518
  def shape(self):
515
519
  return [
516
- len(self.dates),
520
+ len(self.group_of_dates),
517
521
  len(self.variables),
518
522
  len(self.ensembles),
519
523
  len(self.grid_values),
@@ -522,7 +526,7 @@ class Result:
522
526
  @cached_property
523
527
  def coords(self):
524
528
  return {
525
- "dates": self.dates,
529
+ "dates": list(self.group_of_dates),
526
530
  "variables": self.variables,
527
531
  "ensembles": self.ensembles,
528
532
  "values": self.grid_values,
@@ -573,7 +577,7 @@ class FunctionResult(Result):
573
577
  self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs))
574
578
 
575
579
  def _trace_datasource(self, *args, **kwargs):
576
- return f"{self.action.name}({shorten(self.dates)})"
580
+ return f"{self.action.name}({self.group_of_dates})"
577
581
 
578
582
  @cached_property
579
583
  @assert_fieldlist
@@ -583,14 +587,21 @@ class FunctionResult(Result):
583
587
  args, kwargs = resolve(self.context, (self.args, self.kwargs))
584
588
 
585
589
  try:
586
- return _tidy(self.action.function(FunctionContext(self), self.dates, *args, **kwargs))
590
+ return _tidy(
591
+ self.action.function(
592
+ FunctionContext(self),
593
+ list(self.group_of_dates), # Will provide a list of datetime objects
594
+ *args,
595
+ **kwargs,
596
+ )
597
+ )
587
598
  except Exception:
588
599
  LOG.error(f"Error in {self.action.function.__name__}", exc_info=True)
589
600
  raise
590
601
 
591
602
  def __repr__(self):
592
603
  try:
593
- return f"{self.action.name}({shorten(self.dates)})"
604
+ return f"{self.action.name}({self.group_of_dates})"
594
605
  except Exception:
595
606
  return f"{self.__class__.__name__}(unitialised)"
596
607
 
@@ -609,7 +620,7 @@ class JoinResult(Result):
609
620
  @notify_result
610
621
  @trace_datasource
611
622
  def datasource(self):
612
- ds = EmptyResult(self.context, self.action_path, self.dates).datasource
623
+ ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
613
624
  for i in self.results:
614
625
  ds += i.datasource
615
626
  return _tidy(ds)
@@ -824,7 +835,7 @@ class ConcatResult(Result):
824
835
  @notify_result
825
836
  @trace_datasource
826
837
  def datasource(self):
827
- ds = EmptyResult(self.context, self.action_path, self.dates).datasource
838
+ ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
828
839
  for i in self.results:
829
840
  ds += i.datasource
830
841
  return _tidy(ds)
@@ -904,7 +915,7 @@ class ConcatAction(Action):
904
915
  cfg = deepcopy(cfg)
905
916
  dates_cfg = cfg.pop("dates")
906
917
  assert isinstance(dates_cfg, dict), dates_cfg
907
- filtering_dates = Dates.from_config(**dates_cfg)
918
+ filtering_dates = DatesProvider.from_config(**dates_cfg)
908
919
  action = action_factory(cfg, context, action_path + [str(i)])
909
920
  parts.append((filtering_dates, action))
910
921
  self.parts = parts
@@ -915,9 +926,11 @@ class ConcatAction(Action):
915
926
 
916
927
  @trace_select
917
928
  def select(self, dates):
929
+ from anemoi.datasets.dates.groups import GroupOfDates
930
+
918
931
  results = []
919
932
  for filtering_dates, action in self.parts:
920
- newdates = sorted(set(dates) & set(filtering_dates))
933
+ newdates = GroupOfDates(sorted(set(dates) & set(filtering_dates)), dates.provider)
921
934
  if newdates:
922
935
  results.append(action.select(newdates))
923
936
  if not results:
@@ -953,8 +966,10 @@ def action_factory(config, context, action_path):
953
966
 
954
967
  if isinstance(config[key], list):
955
968
  args, kwargs = config[key], {}
956
- if isinstance(config[key], dict):
969
+ elif isinstance(config[key], dict):
957
970
  args, kwargs = [], config[key]
971
+ else:
972
+ raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}")
958
973
 
959
974
  cls = {
960
975
  # "date_shift": DateShiftAction,
@@ -1021,6 +1036,13 @@ class FunctionContext:
1021
1036
  def trace(self, emoji, *args):
1022
1037
  trace(emoji, *args)
1023
1038
 
1039
+ def info(self, *args, **kwargs):
1040
+ LOG.info(*args, **kwargs)
1041
+
1042
+ @property
1043
+ def dates_provider(self):
1044
+ return self.owner.group_of_dates.provider
1045
+
1024
1046
 
1025
1047
  class ActionContext(Context):
1026
1048
  def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid):
@@ -68,7 +68,7 @@ class PersistentDict:
68
68
  path = os.path.join(self.dirname, f"{h}.pickle")
69
69
 
70
70
  if os.path.exists(path):
71
- LOG.warn(f"{path} already exists")
71
+ LOG.warning(f"{path} already exists")
72
72
 
73
73
  tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}"
74
74
  with open(tmp_path, "wb") as f:
@@ -62,6 +62,9 @@ def make_list_int(value):
62
62
 
63
63
 
64
64
  def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"):
65
+
66
+ dates = [d.hdate if hasattr(d, "hdate") else d for d in dates]
67
+
65
68
  assert isinstance(frequency, datetime.timedelta), frequency
66
69
  start = np.datetime64(start)
67
70
  end = np.datetime64(end)
@@ -23,7 +23,11 @@ LOG = logging.getLogger(__name__)
23
23
  class Dataset:
24
24
  arguments = {}
25
25
 
26
- def mutate(self):
26
+ def mutate(self) -> "Dataset":
27
+ """
28
+ Give an opportunity to a subclass to return a new Dataset
29
+ object of a different class, if needed.
30
+ """
27
31
  return self
28
32
 
29
33
  def swap_with_parent(self, parent):
@@ -90,6 +94,12 @@ class Dataset:
90
94
  rename = kwargs.pop("rename")
91
95
  return Rename(self, rename)._subset(**kwargs).mutate()
92
96
 
97
+ if "rescale" in kwargs:
98
+ from .rescale import Rescale
99
+
100
+ rescale = kwargs.pop("rescale")
101
+ return Rescale(self, rescale)._subset(**kwargs).mutate()
102
+
93
103
  if "statistics" in kwargs:
94
104
  from ..data import open_dataset
95
105
  from .statistics import Statistics
@@ -209,10 +209,14 @@ def _debug_indexing(method):
209
209
  return wrapper
210
210
 
211
211
 
212
+ def _identity(x):
213
+ return x
214
+
215
+
212
216
  if DEBUG_ZARR_INDEXING:
213
217
  debug_indexing = _debug_indexing
214
218
  else:
215
- debug_indexing = lambda x: x # noqa
219
+ debug_indexing = _identity
216
220
 
217
221
 
218
222
  def debug_zarr_loading(on_off):
@@ -112,5 +112,5 @@ class Cropping(Masked):
112
112
  def tree(self):
113
113
  return Node(self, [self.forward.tree()], area=self.area)
114
114
 
115
- def metadata_specific(self, **kwargs):
116
- return super().metadata_specific(area=self.area, **kwargs)
115
+ def subclass_metadata_specific(self):
116
+ return dict(area=self.area)
@@ -0,0 +1,147 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+ import logging
9
+ from functools import cached_property
10
+
11
+ import numpy as np
12
+
13
+ from .debug import Node
14
+ from .debug import debug_indexing
15
+ from .forwards import Forwards
16
+ from .indexing import apply_index_to_slices_changes
17
+ from .indexing import expand_list_indexing
18
+ from .indexing import index_to_slices
19
+ from .indexing import update_tuple
20
+
21
+ LOG = logging.getLogger(__name__)
22
+
23
+
24
+ def make_rescale(variable, rescale):
25
+
26
+ if isinstance(rescale, (tuple, list)):
27
+
28
+ assert len(rescale) == 2, rescale
29
+
30
+ if isinstance(rescale[0], (int, float)):
31
+ return rescale
32
+
33
+ from cfunits import Units
34
+
35
+ u0 = Units(rescale[0])
36
+ u1 = Units(rescale[1])
37
+
38
+ x1, x2 = 0.0, 1.0
39
+ y1, y2 = Units.conform([x1, x2], u0, u1)
40
+
41
+ a = (y2 - y1) / (x2 - x1)
42
+ b = y1 - a * x1
43
+
44
+ return a, b
45
+
46
+ return rescale
47
+
48
+ if isinstance(rescale, dict):
49
+ assert "scale" in rescale, rescale
50
+ assert "offset" in rescale, rescale
51
+ return rescale["scale"], rescale["offset"]
52
+
53
+ assert False
54
+
55
+
56
+ class Rescale(Forwards):
57
+ def __init__(self, dataset, rescale):
58
+ super().__init__(dataset)
59
+ for n in rescale:
60
+ assert n in dataset.variables, n
61
+
62
+ variables = dataset.variables
63
+
64
+ self._a = np.ones(len(variables))
65
+ self._b = np.zeros(len(variables))
66
+
67
+ self.rescale = {}
68
+ for i, v in enumerate(variables):
69
+ if v in rescale:
70
+ a, b = make_rescale(v, rescale[v])
71
+ self.rescale[v] = a, b
72
+ self._a[i], self._b[i] = a, b
73
+
74
+ self._a = self._a[np.newaxis, :, np.newaxis, np.newaxis]
75
+ self._b = self._b[np.newaxis, :, np.newaxis, np.newaxis]
76
+
77
+ self._a = self._a.astype(self.forward.dtype)
78
+ self._b = self._b.astype(self.forward.dtype)
79
+
80
+ def tree(self):
81
+ return Node(self, [self.forward.tree()], rescale=self.rescale)
82
+
83
+ def subclass_metadata_specific(self):
84
+ return dict(rescale=self.rescale)
85
+
86
+ @debug_indexing
87
+ @expand_list_indexing
88
+ def _get_tuple(self, index):
89
+ index, changes = index_to_slices(index, self.shape)
90
+ index, previous = update_tuple(index, 1, slice(None))
91
+ result = self.forward[index]
92
+ result = result * self._a + self._b
93
+ result = result[:, previous]
94
+ result = apply_index_to_slices_changes(result, changes)
95
+ return result
96
+
97
+ @debug_indexing
98
+ def __get_slice_(self, n):
99
+ data = self.forward[n]
100
+ return data * self._a + self._b
101
+
102
+ @debug_indexing
103
+ def __getitem__(self, n):
104
+
105
+ if isinstance(n, tuple):
106
+ return self._get_tuple(n)
107
+
108
+ if isinstance(n, slice):
109
+ return self.__get_slice_(n)
110
+
111
+ data = self.forward[n]
112
+
113
+ return data * self._a[0] + self._b[0]
114
+
115
+ @cached_property
116
+ def statistics(self):
117
+ result = {}
118
+ a = self._a.squeeze()
119
+ assert np.all(a >= 0)
120
+
121
+ b = self._b.squeeze()
122
+ for k, v in self.forward.statistics.items():
123
+ if k in ("maximum", "minimum", "mean"):
124
+ result[k] = v * a + b
125
+ continue
126
+
127
+ if k in ("stdev",):
128
+ result[k] = v * a
129
+ continue
130
+
131
+ raise NotImplementedError("rescale statistics", k)
132
+
133
+ return result
134
+
135
+ def statistics_tendencies(self, delta=None):
136
+ result = {}
137
+ a = self._a.squeeze()
138
+ assert np.all(a >= 0)
139
+
140
+ for k, v in self.forward.statistics_tendencies(delta).items():
141
+ if k in ("maximum", "minimum", "mean", "stdev"):
142
+ result[k] = v * a
143
+ continue
144
+
145
+ raise NotImplementedError("rescale tendencies statistics", k)
146
+
147
+ return result
@@ -5,6 +5,7 @@
5
5
  # granted to it by virtue of its status as an intergovernmental organisation
6
6
  # nor does it submit to any jurisdiction.
7
7
 
8
+
8
9
  import logging
9
10
  import os
10
11
  import warnings
@@ -83,6 +84,8 @@ class S3Store(ReadOnlyStore):
83
84
 
84
85
 
85
86
  class DebugStore(ReadOnlyStore):
87
+ """A store to debug the zarr loading."""
88
+
86
89
  def __init__(self, store):
87
90
  assert not isinstance(store, DebugStore)
88
91
  self.store = store
@@ -148,6 +151,8 @@ def open_zarr(path, dont_fail=False, cache=None):
148
151
 
149
152
 
150
153
  class Zarr(Dataset):
154
+ """A zarr dataset."""
155
+
151
156
  def __init__(self, path):
152
157
  if isinstance(path, zarr.hierarchy.Group):
153
158
  self.was_zarr = True
@@ -244,14 +249,20 @@ class Zarr(Dataset):
244
249
  delta = self.frequency
245
250
  if isinstance(delta, int):
246
251
  delta = f"{delta}h"
247
- from anemoi.datasets.create.loaders import TendenciesStatisticsAddition
252
+ from anemoi.utils.dates import frequency_to_string
253
+ from anemoi.utils.dates import frequency_to_timedelta
254
+
255
+ delta = frequency_to_timedelta(delta)
256
+ delta = frequency_to_string(delta)
257
+
258
+ def func(k):
259
+ return f"statistics_tendencies_{delta}_{k}"
248
260
 
249
- func = TendenciesStatisticsAddition.final_storage_name_from_delta
250
261
  return dict(
251
- mean=self.z[func("mean", delta)][:],
252
- stdev=self.z[func("stdev", delta)][:],
253
- maximum=self.z[func("maximum", delta)][:],
254
- minimum=self.z[func("minimum", delta)][:],
262
+ mean=self.z[func("mean")][:],
263
+ stdev=self.z[func("stdev")][:],
264
+ maximum=self.z[func("maximum")][:],
265
+ minimum=self.z[func("minimum")][:],
255
266
  )
256
267
 
257
268
  @property
@@ -322,11 +333,13 @@ class Zarr(Dataset):
322
333
 
323
334
 
324
335
  class ZarrWithMissingDates(Zarr):
336
+ """A zarr dataset with missing dates."""
337
+
325
338
  def __init__(self, path):
326
339
  super().__init__(path)
327
340
 
328
341
  missing_dates = self.z.attrs.get("missing_dates", [])
329
- missing_dates = [np.datetime64(x) for x in missing_dates]
342
+ missing_dates = set([np.datetime64(x) for x in missing_dates])
330
343
  self.missing_to_dates = {i: d for i, d in enumerate(self.dates) if d in missing_dates}
331
344
  self.missing = set(self.missing_to_dates)
332
345