anemoi-datasets 0.5.0__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/inspect.py +1 -1
  3. anemoi/datasets/commands/publish.py +30 -0
  4. anemoi/datasets/create/__init__.py +42 -3
  5. anemoi/datasets/create/check.py +6 -0
  6. anemoi/datasets/create/functions/filters/rename.py +2 -3
  7. anemoi/datasets/create/functions/sources/__init__.py +7 -1
  8. anemoi/datasets/create/functions/sources/accumulations.py +2 -0
  9. anemoi/datasets/create/functions/sources/grib.py +1 -1
  10. anemoi/datasets/create/functions/sources/xarray/__init__.py +7 -2
  11. anemoi/datasets/create/functions/sources/xarray/coordinates.py +12 -1
  12. anemoi/datasets/create/functions/sources/xarray/field.py +13 -4
  13. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +16 -16
  14. anemoi/datasets/create/functions/sources/xarray/flavour.py +130 -13
  15. anemoi/datasets/create/functions/sources/xarray/grid.py +106 -17
  16. anemoi/datasets/create/functions/sources/xarray/metadata.py +3 -11
  17. anemoi/datasets/create/functions/sources/xarray/time.py +1 -5
  18. anemoi/datasets/create/functions/sources/xarray/variable.py +10 -10
  19. anemoi/datasets/create/input/__init__.py +69 -0
  20. anemoi/datasets/create/input/action.py +123 -0
  21. anemoi/datasets/create/input/concat.py +92 -0
  22. anemoi/datasets/create/input/context.py +59 -0
  23. anemoi/datasets/create/input/data_sources.py +71 -0
  24. anemoi/datasets/create/input/empty.py +42 -0
  25. anemoi/datasets/create/input/filter.py +76 -0
  26. anemoi/datasets/create/input/function.py +122 -0
  27. anemoi/datasets/create/input/join.py +57 -0
  28. anemoi/datasets/create/input/misc.py +85 -0
  29. anemoi/datasets/create/input/pipe.py +33 -0
  30. anemoi/datasets/create/input/repeated_dates.py +217 -0
  31. anemoi/datasets/create/input/result.py +413 -0
  32. anemoi/datasets/create/input/step.py +99 -0
  33. anemoi/datasets/create/{template.py → input/template.py} +0 -42
  34. anemoi/datasets/create/statistics/__init__.py +1 -1
  35. anemoi/datasets/create/zarr.py +4 -2
  36. anemoi/datasets/dates/__init__.py +1 -0
  37. anemoi/datasets/dates/groups.py +12 -4
  38. anemoi/datasets/fields.py +66 -0
  39. anemoi/datasets/utils/fields.py +47 -0
  40. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/METADATA +1 -1
  41. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/RECORD +46 -30
  42. anemoi/datasets/create/input.py +0 -1087
  43. anemoi/datasets/create/{trace.py → input/trace.py} +0 -0
  44. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/LICENSE +0 -0
  45. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/WHEEL +0 -0
  46. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/entry_points.txt +0 -0
  47. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/top_level.txt +0 -0
@@ -8,39 +8,128 @@
8
8
  #
9
9
 
10
10
 
11
+ import logging
12
+ from functools import cached_property
13
+
11
14
  import numpy as np
12
15
 
16
+ LOG = logging.getLogger(__name__)
17
+
13
18
 
14
19
  class Grid:
15
- def __init__(self, lat, lon):
16
- self.lat = lat
17
- self.lon = lon
20
+
21
+ def __init__(self):
22
+ pass
18
23
 
19
24
  @property
20
25
  def latitudes(self):
21
- return self.grid_points()[0]
26
+ return self.grid_points[0]
22
27
 
23
28
  @property
24
29
  def longitudes(self):
25
- return self.grid_points()[1]
30
+ return self.grid_points[1]
31
+
32
+
33
+ class LatLonGrid(Grid):
34
+ def __init__(self, lat, lon, variable_dims):
35
+ super().__init__()
36
+ self.lat = lat
37
+ self.lon = lon
38
+
26
39
 
40
+ class XYGrid(Grid):
41
+ def __init__(self, x, y):
42
+ self.x = x
43
+ self.y = y
27
44
 
28
- class MeshedGrid(Grid):
29
- _cache = None
30
45
 
46
+ class MeshedGrid(LatLonGrid):
47
+
48
+ @cached_property
31
49
  def grid_points(self):
32
- if self._cache is not None:
33
- return self._cache
34
- lat = self.lat.variable.values
35
- lon = self.lon.variable.values
36
50
 
37
- lat, lon = np.meshgrid(lat, lon)
38
- self._cache = (lat.flatten(), lon.flatten())
39
- return self._cache
51
+ lat, lon = np.meshgrid(
52
+ self.lat.variable.values,
53
+ self.lon.variable.values,
54
+ )
55
+
56
+ return lat.flatten(), lon.flatten()
57
+
40
58
 
59
+ class UnstructuredGrid(LatLonGrid):
41
60
 
42
- class UnstructuredGrid(Grid):
61
+ def __init__(self, lat, lon, variable_dims):
62
+ super().__init__(lat, lon, variable_dims)
63
+ assert len(lat) == len(lon), (len(lat), len(lon))
64
+ self.variable_dims = variable_dims
65
+ self.grid_dims = lat.variable.dims
66
+ assert lon.variable.dims == self.grid_dims, (lon.variable.dims, self.grid_dims)
67
+ assert set(self.variable_dims) == set(self.grid_dims), (self.variable_dims, self.grid_dims)
68
+
69
+ @cached_property
43
70
  def grid_points(self):
44
- lat = self.lat.variable.values.flatten()
45
- lon = self.lon.variable.values.flatten()
71
+
72
+ assert 1 <= len(self.variable_dims) <= 2
73
+
74
+ if len(self.variable_dims) == 1:
75
+ return self.lat.variable.values.flatten(), self.lon.variable.values.flatten()
76
+
77
+ if len(self.variable_dims) == 2 and self.variable_dims == self.grid_dims:
78
+ return self.lat.variable.values.flatten(), self.lon.variable.values.flatten()
79
+
80
+ LOG.warning(
81
+ "UnstructuredGrid: variable indexing %s does not match grid indexing %s", self.variable_dims, self.grid_dims
82
+ )
83
+
84
+ lat = self.lat.variable.values.transpose().flatten()
85
+ lon = self.lon.variable.values.transpose().flatten()
86
+
46
87
  return lat, lon
88
+
89
+
90
+ class ProjectionGrid(XYGrid):
91
+ def __init__(self, x, y, projection):
92
+ super().__init__(x, y)
93
+ self.projection = projection
94
+
95
+ def transformer(self):
96
+ from pyproj import CRS
97
+ from pyproj import Transformer
98
+
99
+ if isinstance(self.projection, dict):
100
+ data_crs = CRS.from_cf(self.projection)
101
+ else:
102
+ data_crs = self.projection
103
+ wgs84_crs = CRS.from_epsg(4326) # WGS84
104
+
105
+ return Transformer.from_crs(data_crs, wgs84_crs, always_xy=True)
106
+
107
+
108
+ class MeshProjectionGrid(ProjectionGrid):
109
+
110
+ @cached_property
111
+ def grid_points(self):
112
+
113
+ transformer = self.transformer()
114
+ xv, yv = np.meshgrid(self.x.variable.values, self.y.variable.values) # , indexing="ij")
115
+ lon, lat = transformer.transform(xv, yv)
116
+ return lat.flatten(), lon.flatten()
117
+
118
+
119
+ class UnstructuredProjectionGrid(XYGrid):
120
+ @cached_property
121
+ def grid_points(self):
122
+ assert False, "Not implemented"
123
+
124
+ # lat, lon = transformer.transform(
125
+ # self.y.variable.values.flatten(),
126
+ # self.x.variable.values.flatten(),
127
+
128
+ # )
129
+
130
+ # lat = lat[::len(lat)//100]
131
+ # lon = lon[::len(lon)//100]
132
+
133
+ # print(len(lat), len(lon))
134
+
135
+ # return np.meshgrid(lat, lon)
@@ -40,7 +40,9 @@ class _MDMapping:
40
40
  return f"MDMapping({self.mapping})"
41
41
 
42
42
  def fill_time_metadata(self, field, md):
43
- md["valid_datetime"] = as_datetime(self.variable.time.fill_time_metadata(field._md, md)).isoformat()
43
+ valid_datetime = self.variable.time.fill_time_metadata(field._md, md)
44
+ if valid_datetime is not None:
45
+ md["valid_datetime"] = as_datetime(valid_datetime).isoformat()
44
46
 
45
47
 
46
48
  class XArrayMetadata(RawMetadata):
@@ -71,16 +73,6 @@ class XArrayMetadata(RawMetadata):
71
73
 
72
74
  def _as_mars(self):
73
75
  return {}
74
- # p = dict(
75
- # param=self.get("variable", self.get("param")),
76
- # step=self.get("step"),
77
- # levelist=self.get("levelist", self.get("level")),
78
- # levtype=self.get("levtype"),
79
- # number=self.get("number"),
80
- # date=self.get("date"),
81
- # time=self.get("time"),
82
- # )
83
- # return {k: v for k, v in p.items() if v is not None}
84
76
 
85
77
  def _base_datetime(self):
86
78
  return self._field.forecast_reference_time
@@ -42,11 +42,7 @@ class Time:
42
42
  class Constant(Time):
43
43
 
44
44
  def fill_time_metadata(self, coords_values, metadata):
45
- raise NotImplementedError("Constant time not implemented")
46
- # print("Constant", coords_values, metadata)
47
- # metadata["date"] = time.strftime("%Y%m%d")
48
- # metadata["time"] = time.strftime("%H%M")
49
- # metadata["step"] = 0
45
+ return None
50
46
 
51
47
 
52
48
  class Analysis(Time):
@@ -24,7 +24,7 @@ class Variable:
24
24
  self,
25
25
  *,
26
26
  ds,
27
- var,
27
+ variable,
28
28
  coordinates,
29
29
  grid,
30
30
  time,
@@ -32,13 +32,13 @@ class Variable:
32
32
  array_backend=None,
33
33
  ):
34
34
  self.ds = ds
35
- self.var = var
35
+ self.variable = variable
36
36
 
37
37
  self.grid = grid
38
38
  self.coordinates = coordinates
39
39
 
40
40
  self._metadata = metadata.copy()
41
- self._metadata.update({"variable": var.name})
41
+ self._metadata.update({"variable": variable.name})
42
42
 
43
43
  self.time = time
44
44
 
@@ -51,20 +51,20 @@ class Variable:
51
51
 
52
52
  @property
53
53
  def name(self):
54
- return self.var.name
54
+ return self.variable.name
55
55
 
56
56
  def __len__(self):
57
57
  return self.length
58
58
 
59
59
  @property
60
60
  def grid_mapping(self):
61
- grid_mapping = self.var.attrs.get("grid_mapping", None)
61
+ grid_mapping = self.variable.attrs.get("grid_mapping", None)
62
62
  if grid_mapping is None:
63
63
  return None
64
64
  return self.ds[grid_mapping].attrs
65
65
 
66
66
  def grid_points(self):
67
- return self.grid.grid_points()
67
+ return self.grid.grid_points
68
68
 
69
69
  @property
70
70
  def latitudes(self):
@@ -76,7 +76,7 @@ class Variable:
76
76
 
77
77
  def __repr__(self):
78
78
  return "Variable[name=%s,coordinates=%s,metadata=%s]" % (
79
- self.var.name,
79
+ self.variable.name,
80
80
  self.coordinates,
81
81
  self._metadata,
82
82
  )
@@ -90,7 +90,7 @@ class Variable:
90
90
 
91
91
  coords = np.unravel_index(i, self.shape)
92
92
  kwargs = {k: v for k, v in zip(self.names, coords)}
93
- return XArrayField(self, self.var.isel(kwargs))
93
+ return XArrayField(self, self.variable.isel(kwargs))
94
94
 
95
95
  def sel(self, missing, **kwargs):
96
96
 
@@ -117,7 +117,7 @@ class Variable:
117
117
 
118
118
  variable = Variable(
119
119
  ds=self.ds,
120
- var=self.var.isel({k: i}),
120
+ var=self.variable.isel({k: i}),
121
121
  coordinates=coordinates,
122
122
  grid=self.grid,
123
123
  time=self.time,
@@ -136,7 +136,7 @@ class Variable:
136
136
  name = kwargs.pop("variable")
137
137
  if not isinstance(name, (list, tuple)):
138
138
  name = [name]
139
- if self.var.name not in name:
139
+ if self.variable.name not in name:
140
140
  return False, None
141
141
  return True, kwargs
142
142
  return True, kwargs
@@ -0,0 +1,69 @@
1
+ # (C) Copyright 2023 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation
7
+ # nor does it submit to any jurisdiction.
8
+ #
9
+ import datetime
10
+ import itertools
11
+ import logging
12
+ import math
13
+ import time
14
+ from collections import defaultdict
15
+ from copy import deepcopy
16
+ from functools import cached_property
17
+ from functools import wraps
18
+
19
+ import numpy as np
20
+ from anemoi.utils.dates import as_datetime as as_datetime
21
+ from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
22
+
23
+ from anemoi.datasets.dates import DatesProvider as DatesProvider
24
+ from anemoi.datasets.fields import FieldArray as FieldArray
25
+ from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
26
+
27
+ from .trace import trace_select
28
+
29
+ LOG = logging.getLogger(__name__)
30
+
31
+
32
+ class InputBuilder:
33
+ def __init__(self, config, data_sources, **kwargs):
34
+ self.kwargs = kwargs
35
+
36
+ config = deepcopy(config)
37
+ if data_sources:
38
+ config = dict(
39
+ data_sources=dict(
40
+ sources=data_sources,
41
+ input=config,
42
+ )
43
+ )
44
+ self.config = config
45
+ self.action_path = ["input"]
46
+
47
+ @trace_select
48
+ def select(self, group_of_dates):
49
+ from .action import ActionContext
50
+ from .action import action_factory
51
+
52
+ """This changes the context."""
53
+ context = ActionContext(**self.kwargs)
54
+ action = action_factory(self.config, context, self.action_path)
55
+ return action.select(group_of_dates)
56
+
57
+ def __repr__(self):
58
+ from .action import ActionContext
59
+ from .action import action_factory
60
+
61
+ context = ActionContext(**self.kwargs)
62
+ a = action_factory(self.config, context, self.action_path)
63
+ return repr(a)
64
+
65
+ def _trace_select(self, group_of_dates):
66
+ return f"InputBuilder({group_of_dates})"
67
+
68
+
69
+ build_input = InputBuilder
@@ -0,0 +1,123 @@
1
+ # (C) Copyright 2024 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation
7
+ # nor does it submit to any jurisdiction.
8
+ #
9
+ import logging
10
+ from copy import deepcopy
11
+
12
+ from anemoi.utils.dates import as_datetime as as_datetime
13
+ from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
14
+ from earthkit.data.core.order import build_remapping
15
+
16
+ from anemoi.datasets.dates import DatesProvider as DatesProvider
17
+ from anemoi.datasets.fields import FieldArray as FieldArray
18
+ from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
19
+
20
+ from .context import Context
21
+ from .misc import is_function
22
+
23
+ LOG = logging.getLogger(__name__)
24
+
25
+
26
+ class Action:
27
+ def __init__(self, context, action_path, /, *args, **kwargs):
28
+ if "args" in kwargs and "kwargs" in kwargs:
29
+ """We have:
30
+ args = []
31
+ kwargs = {args: [...], kwargs: {...}}
32
+ move the content of kwargs to args and kwargs.
33
+ """
34
+ assert len(kwargs) == 2, (args, kwargs)
35
+ assert not args, (args, kwargs)
36
+ args = kwargs.pop("args")
37
+ kwargs = kwargs.pop("kwargs")
38
+
39
+ assert isinstance(context, ActionContext), type(context)
40
+ self.context = context
41
+ self.kwargs = kwargs
42
+ self.args = args
43
+ self.action_path = action_path
44
+
45
+ @classmethod
46
+ def _short_str(cls, x):
47
+ x = str(x)
48
+ if len(x) < 1000:
49
+ return x
50
+ return x[:1000] + "..."
51
+
52
+ def __repr__(self, *args, _indent_="\n", _inline_="", **kwargs):
53
+ more = ",".join([str(a)[:5000] for a in args])
54
+ more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])
55
+
56
+ more = more[:5000]
57
+ txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}"
58
+ if _indent_:
59
+ txt = txt.replace("\n", "\n ")
60
+ return txt
61
+
62
+ def select(self, dates, **kwargs):
63
+ self._raise_not_implemented()
64
+
65
+ def _raise_not_implemented(self):
66
+ raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
67
+
68
+ def _trace_select(self, group_of_dates):
69
+ return f"{self.__class__.__name__}({group_of_dates})"
70
+
71
+
72
+ class ActionContext(Context):
73
+ def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid):
74
+ super().__init__()
75
+ self.order_by = order_by
76
+ self.flatten_grid = flatten_grid
77
+ self.remapping = build_remapping(remapping)
78
+ self.use_grib_paramid = use_grib_paramid
79
+
80
+
81
+ def action_factory(config, context, action_path):
82
+
83
+ from .concat import ConcatAction
84
+ from .data_sources import DataSourcesAction
85
+ from .function import FunctionAction
86
+ from .join import JoinAction
87
+ from .pipe import PipeAction
88
+ from .repeated_dates import RepeatedDatesAction
89
+
90
+ # from .data_sources import DataSourcesAction
91
+
92
+ assert isinstance(context, Context), (type, context)
93
+ if not isinstance(config, dict):
94
+ raise ValueError(f"Invalid input config {config}")
95
+ if len(config) != 1:
96
+ raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}")
97
+
98
+ config = deepcopy(config)
99
+ key = list(config.keys())[0]
100
+
101
+ if isinstance(config[key], list):
102
+ args, kwargs = config[key], {}
103
+ elif isinstance(config[key], dict):
104
+ args, kwargs = [], config[key]
105
+ else:
106
+ raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}")
107
+
108
+ cls = {
109
+ "data_sources": DataSourcesAction,
110
+ "concat": ConcatAction,
111
+ "join": JoinAction,
112
+ "pipe": PipeAction,
113
+ "function": FunctionAction,
114
+ "repeated_dates": RepeatedDatesAction,
115
+ }.get(key)
116
+
117
+ if cls is None:
118
+ if not is_function(key, "sources"):
119
+ raise ValueError(f"Unknown action '{key}' in {config}")
120
+ cls = FunctionAction
121
+ args = [key] + args
122
+
123
+ return cls(context, action_path + [key], *args, **kwargs)
@@ -0,0 +1,92 @@
1
+ # (C) Copyright 2024 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation
7
+ # nor does it submit to any jurisdiction.
8
+ #
9
+ import logging
10
+ from copy import deepcopy
11
+ from functools import cached_property
12
+
13
+ from anemoi.datasets.dates import DatesProvider
14
+
15
+ from .action import Action
16
+ from .action import action_factory
17
+ from .empty import EmptyResult
18
+ from .misc import _tidy
19
+ from .misc import assert_fieldlist
20
+ from .result import Result
21
+ from .template import notify_result
22
+ from .trace import trace_datasource
23
+ from .trace import trace_select
24
+
25
+ LOG = logging.getLogger(__name__)
26
+
27
+
28
+ class ConcatResult(Result):
29
+ def __init__(self, context, action_path, group_of_dates, results, **kwargs):
30
+ super().__init__(context, action_path, group_of_dates)
31
+ self.results = [r for r in results if not r.empty]
32
+
33
+ @cached_property
34
+ @assert_fieldlist
35
+ @notify_result
36
+ @trace_datasource
37
+ def datasource(self):
38
+ ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
39
+ for i in self.results:
40
+ ds += i.datasource
41
+ return _tidy(ds)
42
+
43
+ @property
44
+ def variables(self):
45
+ """Check that all the results objects have the same variables."""
46
+ variables = None
47
+ for f in self.results:
48
+ if f.empty:
49
+ continue
50
+ if variables is None:
51
+ variables = f.variables
52
+ assert variables == f.variables, (variables, f.variables)
53
+ assert variables is not None, self.results
54
+ return variables
55
+
56
+ def __repr__(self):
57
+ content = "\n".join([str(i) for i in self.results])
58
+ return super().__repr__(content)
59
+
60
+
61
+ class ConcatAction(Action):
62
+ def __init__(self, context, action_path, *configs):
63
+ super().__init__(context, action_path, *configs)
64
+ parts = []
65
+ for i, cfg in enumerate(configs):
66
+ if "dates" not in cfg:
67
+ raise ValueError(f"Missing 'dates' in {cfg}")
68
+ cfg = deepcopy(cfg)
69
+ dates_cfg = cfg.pop("dates")
70
+ assert isinstance(dates_cfg, dict), dates_cfg
71
+ filtering_dates = DatesProvider.from_config(**dates_cfg)
72
+ action = action_factory(cfg, context, action_path + [str(i)])
73
+ parts.append((filtering_dates, action))
74
+ self.parts = parts
75
+
76
+ def __repr__(self):
77
+ content = "\n".join([str(i) for i in self.parts])
78
+ return super().__repr__(content)
79
+
80
+ @trace_select
81
+ def select(self, group_of_dates):
82
+ from anemoi.datasets.dates.groups import GroupOfDates
83
+
84
+ results = []
85
+ for filtering_dates, action in self.parts:
86
+ newdates = GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider)
87
+ if newdates:
88
+ results.append(action.select(newdates))
89
+ if not results:
90
+ return EmptyResult(self.context, self.action_path, group_of_dates)
91
+
92
+ return ConcatResult(self.context, self.action_path, group_of_dates, results)
@@ -0,0 +1,59 @@
1
+ # (C) Copyright 2024 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation
7
+ # nor does it submit to any jurisdiction.
8
+ #
9
+ import logging
10
+ import textwrap
11
+
12
+ from anemoi.utils.dates import as_datetime as as_datetime
13
+ from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
14
+ from anemoi.utils.humanize import plural
15
+
16
+ from anemoi.datasets.dates import DatesProvider as DatesProvider
17
+ from anemoi.datasets.fields import FieldArray as FieldArray
18
+ from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
19
+
20
+ from .trace import step
21
+ from .trace import trace
22
+
23
+ LOG = logging.getLogger(__name__)
24
+
25
+
26
+ class Context:
27
+ def __init__(self):
28
+ # used_references is a set of reference paths that will be needed
29
+ self.used_references = set()
30
+ # results is a dictionary of reference path -> obj
31
+ self.results = {}
32
+
33
+ def will_need_reference(self, key):
34
+ assert isinstance(key, (list, tuple)), key
35
+ key = tuple(key)
36
+ self.used_references.add(key)
37
+
38
+ def notify_result(self, key, result):
39
+ trace(
40
+ "🎯",
41
+ step(key),
42
+ "notify result",
43
+ textwrap.shorten(repr(result).replace(",", ", "), width=40),
44
+ plural(len(result), "field"),
45
+ )
46
+ assert isinstance(key, (list, tuple)), key
47
+ key = tuple(key)
48
+ if key in self.used_references:
49
+ if key in self.results:
50
+ raise ValueError(f"Duplicate result {key}")
51
+ self.results[key] = result
52
+
53
+ def get_result(self, key):
54
+ assert isinstance(key, (list, tuple)), key
55
+ key = tuple(key)
56
+ if key in self.results:
57
+ return self.results[key]
58
+ all_keys = sorted(list(self.results.keys()))
59
+ raise ValueError(f"Cannot find result {key} in {all_keys}")
@@ -0,0 +1,71 @@
1
+ # (C) Copyright 2024 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation
7
+ # nor does it submit to any jurisdiction.
8
+ #
9
+ import logging
10
+ from functools import cached_property
11
+
12
+ from anemoi.utils.dates import as_datetime as as_datetime
13
+ from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
14
+
15
+ from anemoi.datasets.dates import DatesProvider as DatesProvider
16
+ from anemoi.datasets.fields import FieldArray as FieldArray
17
+ from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
18
+
19
+ from .action import Action
20
+ from .action import action_factory
21
+ from .misc import _tidy
22
+ from .result import Result
23
+
24
+ LOG = logging.getLogger(__name__)
25
+
26
+
27
+ class DataSourcesAction(Action):
28
+ def __init__(self, context, action_path, sources, input):
29
+ super().__init__(context, ["data_sources"], *sources)
30
+ if isinstance(sources, dict):
31
+ configs = [(str(k), c) for k, c in sources.items()]
32
+ elif isinstance(sources, list):
33
+ configs = [(str(i), c) for i, c in enumerate(sources)]
34
+ else:
35
+ raise ValueError(f"Invalid data_sources, expecting list or dict, got {type(sources)}: {sources}")
36
+
37
+ self.sources = [action_factory(config, context, ["data_sources"] + [a_path]) for a_path, config in configs]
38
+ self.input = action_factory(input, context, ["input"])
39
+
40
+ def select(self, group_of_dates):
41
+ sources_results = [a.select(group_of_dates) for a in self.sources]
42
+ return DataSourcesResult(
43
+ self.context,
44
+ self.action_path,
45
+ group_of_dates,
46
+ self.input.select(group_of_dates),
47
+ sources_results,
48
+ )
49
+
50
+ def __repr__(self):
51
+ content = "\n".join([str(i) for i in self.sources])
52
+ return super().__repr__(content)
53
+
54
+
55
+ class DataSourcesResult(Result):
56
+ def __init__(self, context, action_path, dates, input_result, sources_results):
57
+ super().__init__(context, action_path, dates)
58
+ # result is the main input result
59
+ self.input_result = input_result
60
+ # sources_results is the list of the sources_results
61
+ self.sources_results = sources_results
62
+
63
+ @cached_property
64
+ def datasource(self):
65
+ for i in self.sources_results:
66
+ # for each result trigger the datasource to be computed
67
+ # and saved in context
68
+ self.context.notify_result(i.action_path[:-1], i.datasource)
69
+ # then return the input result
70
+ # which can use the datasources of the included results
71
+ return _tidy(self.input_result.datasource)