anemoi-datasets 0.5.27__py3-none-any.whl → 0.5.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/recipe/__init__.py +93 -0
  3. anemoi/datasets/commands/recipe/format.py +55 -0
  4. anemoi/datasets/commands/recipe/migrate.py +555 -0
  5. anemoi/datasets/create/__init__.py +46 -13
  6. anemoi/datasets/create/config.py +52 -53
  7. anemoi/datasets/create/input/__init__.py +43 -63
  8. anemoi/datasets/create/input/action.py +296 -236
  9. anemoi/datasets/create/input/context/__init__.py +71 -0
  10. anemoi/datasets/create/input/context/field.py +54 -0
  11. anemoi/datasets/create/input/data_sources.py +2 -1
  12. anemoi/datasets/create/input/misc.py +0 -71
  13. anemoi/datasets/create/input/repeated_dates.py +0 -114
  14. anemoi/datasets/create/input/result/__init__.py +17 -0
  15. anemoi/datasets/create/input/{result.py → result/field.py} +10 -92
  16. anemoi/datasets/create/sources/accumulate.py +517 -0
  17. anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
  18. anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
  19. anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +149 -0
  20. anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
  21. anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
  22. anemoi/datasets/create/sources/constants.py +39 -38
  23. anemoi/datasets/create/sources/empty.py +26 -22
  24. anemoi/datasets/create/sources/forcings.py +29 -28
  25. anemoi/datasets/create/sources/grib.py +92 -72
  26. anemoi/datasets/create/sources/grib_index.py +102 -54
  27. anemoi/datasets/create/sources/hindcasts.py +56 -55
  28. anemoi/datasets/create/sources/legacy.py +10 -62
  29. anemoi/datasets/create/sources/mars.py +159 -154
  30. anemoi/datasets/create/sources/netcdf.py +28 -24
  31. anemoi/datasets/create/sources/opendap.py +28 -24
  32. anemoi/datasets/create/sources/recentre.py +42 -41
  33. anemoi/datasets/create/sources/repeated_dates.py +44 -0
  34. anemoi/datasets/create/sources/source.py +26 -48
  35. anemoi/datasets/create/sources/xarray_support/__init__.py +30 -24
  36. anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
  37. anemoi/datasets/create/sources/xarray_support/field.py +4 -4
  38. anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
  39. anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
  40. anemoi/datasets/create/sources/xarray_zarr.py +28 -24
  41. anemoi/datasets/create/sources/zenodo.py +43 -39
  42. anemoi/datasets/create/utils.py +0 -42
  43. anemoi/datasets/data/complement.py +26 -17
  44. anemoi/datasets/data/dataset.py +12 -0
  45. anemoi/datasets/data/grids.py +0 -152
  46. anemoi/datasets/data/masked.py +74 -13
  47. anemoi/datasets/data/missing.py +5 -0
  48. anemoi/datasets/data/rolling_average.py +141 -0
  49. anemoi/datasets/data/stores.py +7 -9
  50. anemoi/datasets/dates/__init__.py +2 -0
  51. anemoi/datasets/dumper.py +76 -0
  52. anemoi/datasets/grids.py +1 -178
  53. anemoi/datasets/schemas/recipe.json +131 -0
  54. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/METADATA +9 -6
  55. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/RECORD +59 -57
  56. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/WHEEL +1 -1
  57. anemoi/datasets/create/filter.py +0 -47
  58. anemoi/datasets/create/input/concat.py +0 -161
  59. anemoi/datasets/create/input/context.py +0 -86
  60. anemoi/datasets/create/input/empty.py +0 -53
  61. anemoi/datasets/create/input/filter.py +0 -117
  62. anemoi/datasets/create/input/function.py +0 -232
  63. anemoi/datasets/create/input/join.py +0 -129
  64. anemoi/datasets/create/input/pipe.py +0 -66
  65. anemoi/datasets/create/input/step.py +0 -173
  66. anemoi/datasets/create/input/template.py +0 -161
  67. anemoi/datasets/create/sources/accumulations.py +0 -1062
  68. anemoi/datasets/create/sources/accumulations2.py +0 -647
  69. anemoi/datasets/create/sources/tendencies.py +0 -198
  70. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/entry_points.txt +0 -0
  71. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/licenses/LICENSE +0 -0
  72. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/sources/repeated_dates.py
@@ -0,0 +1,44 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+from typing import Any
+
+from anemoi.transform.fields import new_field_with_valid_datetime
+from anemoi.transform.fields import new_fieldlist_from_list
+
+from anemoi.datasets.create.input.repeated_dates import DateMapper
+from anemoi.datasets.create.source import Source
+from anemoi.datasets.create.sources import source_registry
+
+LOG = logging.getLogger(__name__)
+
+
+@source_registry.register("repeated_dates")
+class RepeatedDatesSource(Source):
+
+    def __init__(self, context, source: Any, mode: str, **kwargs) -> None:
+        # assert False, (context, source, mode, kwargs)
+        super().__init__(context, **kwargs)
+        self.mapper = DateMapper.from_mode(mode, source, kwargs)
+        self.source = source
+
+    def execute(self, group_of_dates):
+        source = self.context.create_source(self.source, "data_sources", str(id(self)))
+
+        result = []
+        for one_date_group, many_dates_group in self.mapper.transform(group_of_dates):
+            print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}")
+            source_results = source(self.context, one_date_group)
+            for field in source_results:
+                for date in many_dates_group:
+                    result.append(new_field_with_valid_datetime(field, date))
+
+        return new_fieldlist_from_list(result)
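The core of the new source is a fan-out: fields retrieved for one anchor date are duplicated, each copy re-stamped with one valid datetime from the matching group. A self-contained sketch of that fan-out, with plain dicts standing in for anemoi field objects (all names below are illustrative, not anemoi API):

# Hypothetical illustration of the fan-out performed by RepeatedDatesSource.execute.
from datetime import datetime, timedelta

def fan_out(fields, date_groups):
    """date_groups maps one retrieval date to the many valid datetimes it covers."""
    result = []
    for one_date, many_dates in date_groups.items():
        for field in fields.get(one_date, []):
            for date in many_dates:
                # mirrors new_field_with_valid_datetime(field, date)
                result.append({**field, "valid_datetime": date})
    return result

anchor = datetime(2024, 1, 1)
fields = {anchor: [{"param": "z", "values": [0.0]}]}
groups = {anchor: [anchor + timedelta(hours=h) for h in (0, 6, 12)]}
print(len(fan_out(fields, groups)))  # 3: one field stamped with three valid times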
anemoi/datasets/create/sources/source.py
@@ -12,58 +12,36 @@ from typing import Any
 
 from earthkit.data import from_source
 
-from anemoi.datasets.create.utils import to_datetime_list
+from anemoi.datasets.create.sources import source_registry
 
-from .legacy import legacy_source
+from .legacy import LegacySource
 
 
-@legacy_source(__file__)
-def source(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any:
-    """Generates a source based on the provided context, dates, and additional keyword arguments.
+@source_registry.register("source")
+class GenericSource(LegacySource):
 
-    Parameters
-    ----------
-    context : Optional[Any]
-        The context in which the source is generated.
-    dates : List[datetime]
-        A list of datetime objects representing the dates.
-    **kwargs : Any
-        Additional keyword arguments for the source generation.
+    @staticmethod
+    def _execute(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any:
+        """Generates a source based on the provided context, dates, and additional keyword arguments.
 
-    Returns
-    -------
-    Any
-        The generated source.
-    """
-    name = kwargs.pop("name")
-    context.trace("✅", f"from_source({name}, {dates}, {kwargs}")
-    if kwargs["date"] == "$from_dates":
-        kwargs["date"] = list({d.strftime("%Y%m%d") for d in dates})
-    if kwargs["time"] == "$from_dates":
-        kwargs["time"] = list({d.strftime("%H%M") for d in dates})
-    return from_source(name, **kwargs)
+        Parameters
+        ----------
+        context : Optional[Any]
+            The context in which the source is generated.
+        dates : List[datetime]
+            A list of datetime objects representing the dates.
+        **kwargs : Any
+            Additional keyword arguments for the source generation.
 
-
-execute = source
-
-if __name__ == "__main__":
-    import yaml
-
-    config: dict[str, Any] = yaml.safe_load(
+        Returns
+        -------
+        Any
+            The generated source.
         """
-        name: mars
-        class: ea
-        expver: '0001'
-        grid: 20.0/20.0
-        levtype: sfc
-        param: [2t]
-        number: [0, 1]
-        date: $from_dates
-        time: $from_dates
-        """
-    )
-    dates: list[str] = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]")
-    dates = to_datetime_list(dates)
-
-    for f in source(None, dates, **config):
-        print(f, f.to_numpy().mean())
+        name = kwargs.pop("name")
+        context.trace("✅", f"from_source({name}, {dates}, {kwargs}")
+        if kwargs["date"] == "$from_dates":
+            kwargs["date"] = list({d.strftime("%Y%m%d") for d in dates})
+        if kwargs["time"] == "$from_dates":
+            kwargs["time"] = list({d.strftime("%H%M") for d in dates})
+        return from_source(name, **kwargs)
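The "$from_dates" placeholders are expanded from the requested datetimes before the request reaches from_source. A standalone illustration of that expansion (sorted here for deterministic output; the source builds unordered sets):

# Plain-Python sketch of the "$from_dates" expansion in GenericSource._execute.
from datetime import datetime

dates = [datetime(2022, 12, 30, 18), datetime(2022, 12, 31, 0), datetime(2022, 12, 31, 6)]
request = {"date": "$from_dates", "time": "$from_dates", "param": "2t"}

if request["date"] == "$from_dates":
    request["date"] = sorted({d.strftime("%Y%m%d") for d in dates})  # unique days
if request["time"] == "$from_dates":
    request["time"] = sorted({d.strftime("%H%M") for d in dates})    # unique times

print(request)
# {'date': ['20221230', '20221231'], 'time': ['0000', '0600', '1800'], 'param': '2t'}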
anemoi/datasets/create/sources/xarray_support/__init__.py
@@ -17,7 +17,8 @@ from earthkit.data.core.fieldlist import MultiFieldList
 
 from anemoi.datasets.create.sources.patterns import iterate_patterns
 
-from ..legacy import legacy_source
+from .. import source_registry
+from ..legacy import LegacySource
 from .fieldlist import XarrayFieldList
 
 LOG = logging.getLogger(__name__)
@@ -96,6 +97,7 @@ def load_one(
     if isinstance(dataset, xr.Dataset):
         data = dataset
     else:
+        print(f"Opening dataset {dataset} with options {options}")
        data = xr.open_dataset(dataset, **options)
 
     fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)
@@ -152,26 +154,30 @@ def load_many(emoji: str, context: Any, dates: list[datetime.datetime], pattern:
     return MultiFieldList(result)
 
 
-@legacy_source("xarray")
-def execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList:
-    """Executes the loading of datasets.
-
-    Parameters
-    ----------
-    context : Any
-        Context object.
-    dates : List[str]
-        List of dates.
-    url : str
-        URL pattern for loading datasets.
-    *args : Any
-        Additional arguments.
-    **kwargs : Any
-        Additional keyword arguments.
-
-    Returns
-    -------
-    ekd.FieldList
-        The loaded datasets.
-    """
-    return load_many("🌐", context, dates, url, *args, **kwargs)
+@source_registry.register("xarray")
+class LegacyXarraySource(LegacySource):
+    name = "xarray"
+
+    @staticmethod
+    def _execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList:
+        """Executes the loading of datasets.
+
+        Parameters
+        ----------
+        context : Any
+            Context object.
+        dates : List[str]
+            List of dates.
+        url : str
+            URL pattern for loading datasets.
+        *args : Any
+            Additional arguments.
+        **kwargs : Any
+            Additional keyword arguments.
+
+        Returns
+        -------
+        ekd.FieldList
+            The loaded datasets.
+        """
+        return load_many("🌐", context, dates, url, *args, **kwargs)
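This is the migration pattern applied throughout this release: a module-level execute function decorated with legacy_source becomes a LegacySource subclass registered by name, with the old body moved into a static _execute. A minimal self-contained sketch of the pattern, where Registry and LegacySource are simplified stand-ins for the anemoi-datasets internals, not the real classes:

# Self-contained sketch of the decorator-to-class migration pattern.
from typing import Any, Callable


class Registry:
    def __init__(self) -> None:
        self._sources: dict[str, type] = {}

    def register(self, name: str) -> Callable[[type], type]:
        def wrapper(cls: type) -> type:
            self._sources[name] = cls  # map the recipe name to the class
            return cls
        return wrapper


class LegacySource:
    """Adapter base: subclasses supply the old function body as _execute."""

    @staticmethod
    def _execute(context: Any, dates: list, **kwargs: Any) -> Any:
        raise NotImplementedError


source_registry = Registry()


@source_registry.register("example")
class ExampleSource(LegacySource):
    @staticmethod
    def _execute(context: Any, dates: list, **kwargs: Any) -> Any:
        return [("example", d) for d in dates]


print(source_registry._sources["example"]._execute(None, ["2024-01-01"]))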
anemoi/datasets/create/sources/xarray_support/coordinates.py
@@ -223,13 +223,10 @@ class Coordinate:
         # Assume the array is sorted
 
         index = np.searchsorted(values, value)
-        index = index[index < len(values)]
-
-        if np.all(values[index] == value):
+        if np.all(index < len(values)) and np.all(values[index] == value):
            return index
 
         # If not found, we need to check if the value is in the array
-
         index = np.where(np.isin(values, value))[0]
 
         # We could also return incomplete matches
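The rewritten check avoids indexing with out-of-range positions: np.searchsorted returns len(values) for anything beyond the last element, and the short-circuiting `and` now guards the element comparison instead of silently dropping indices first. A small demonstration:

# Why the bounds check matters for np.searchsorted results.
import numpy as np

values = np.array([10, 20, 30])
value = np.array([20, 99])          # 99 is not in the array

index = np.searchsorted(values, value)
print(index)                        # [1 3] -- 3 is out of bounds

# New-style check: only index into `values` once all positions are in range.
if np.all(index < len(values)) and np.all(values[index] == value):
    print("exact match", index)
else:
    # Fallback mirrors the source: exact membership lookup.
    print("fallback", np.where(np.isin(values, value))[0])  # [1]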
anemoi/datasets/create/sources/xarray_support/field.py
@@ -121,16 +121,16 @@ class XArrayField(Field):
            Index to select a specific element, by default None.
        """
        if index is not None:
-            values = self.selection[index]
+            values = self.selection[index].values
        else:
-            values = self.selection
+            values = self.selection.values
 
        assert dtype is None
 
        if flatten:
-            return values.values.flatten()
+            return values.flatten()
 
-        return values  # .reshape(self.shape)
+        return values
 
    @cached_property
    def _metadata(self) -> XArrayMetadata:
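The change converts to numpy once, up front, so the flattened and unflattened paths both return an ndarray rather than the unflattened path leaking an xarray.DataArray. A minimal demonstration of the normalized behaviour:

# Demonstration that .values yields an ndarray for both return paths.
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=("y", "x"))
selection = da                      # stands in for XArrayField.selection

values = selection.values          # ndarray, as in the new code
print(values.flatten())            # [0. 1. 2. 3. 4. 5.]
print(type(values))                # <class 'numpy.ndarray'> either way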
anemoi/datasets/create/sources/xarray_support/flavour.py
@@ -557,10 +557,10 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         super().__init__(ds)
 
     def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> PointCoordinate | None:
-        if attributes.standard_name in ["cell", "station", "poi", "point"]:
+        if attributes.standard_name in ["location", "cell", "id", "station", "poi", "point"]:
            return PointCoordinate(c)
 
-        if attributes.name in ["cell", "station", "poi", "point"]:  # WeatherBench
+        if attributes.name in ["location", "cell", "id", "station", "poi", "point"]:  # WeatherBench
            return PointCoordinate(c)
 
        return None
anemoi/datasets/create/sources/xarray_support/patch.py
@@ -10,13 +10,14 @@
 
 import logging
 from typing import Any
+from typing import Literal
 
 import xarray as xr
 
 LOG = logging.getLogger(__name__)
 
 
-def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> Any:
+def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> xr.Dataset:
    """Patch the attributes of the dataset.
 
    Parameters
@@ -38,7 +39,7 @@ def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> A
    return ds
 
 
-def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> Any:
+def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> xr.Dataset:
    """Patch the coordinates of the dataset.
 
    Parameters
@@ -59,7 +60,7 @@ def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> Any:
    return ds
 
 
-def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
+def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> xr.Dataset:
    """Rename variables in the dataset.
 
    Parameters
@@ -77,7 +78,7 @@ def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
    return ds.rename(renames)
 
 
-def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> Any:
+def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> xr.Dataset:
    """Sort the coordinates of the dataset.
 
    Parameters
@@ -98,11 +99,175 @@ def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> Any:
    return ds
 
 
+def patch_subset_dataset(ds: xr.Dataset, selection: dict[str, Any]) -> xr.Dataset:
+    """Select a subset of the dataset using xarray's sel method.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    selection : dict[str, Any]
+        Dictionary mapping dimension names to selection criteria.
+        Keys must be existing dimension names in the dataset.
+        Values can be any type accepted by xarray's sel method, including:
+        - Single values (int, float, str, datetime)
+        - Lists or arrays of values
+        - Slices (using slice() objects)
+        - Boolean arrays
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset containing only the selected subset.
+
+    Examples
+    --------
+    >>> # Select specific time and pressure level
+    >>> patch_subset_dataset(ds, {
+    ...     'time': '2020-01-01',
+    ...     'pressure': 500
+    ... })
+
+    >>> # Select a range using slice
+    >>> patch_subset_dataset(ds, {
+    ...     'lat': slice(-90, 90),
+    ...     'lon': slice(0, 180)
+    ... })
+    """
+
+    ds = ds.sel(selection)
+
+    return ds
+
+
+def patch_analysis_lead_to_valid_time(
+    ds: xr.Dataset,
+    time_coord_names: dict[Literal["analysis_time_coordinate", "lead_time_coordinate", "valid_time_coordinate"], str],
+) -> xr.Dataset:
+    """Convert analysis time and lead time coordinates to valid time.
+
+    This function creates a new valid time coordinate by adding the analysis time
+    and lead time coordinates, then stacks and reorganizes the dataset to use
+    valid time as the primary time dimension.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    time_coord_names : dict[str, str]
+        Dictionary mapping required keys to coordinate names in the dataset:
+
+        - 'analysis_time_coordinate' : str
+            Name of the analysis/initialization time coordinate.
+        - 'lead_time_coordinate' : str
+            Name of the forecast lead time coordinate.
+        - 'valid_time_coordinate' : str
+            Name for the new valid time coordinate to create.
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset with valid time as the primary time coordinate.
+        The analysis and lead time coordinates are removed.
+
+    Examples
+    --------
+    >>> patch_analysis_lead_to_valid_time(ds, {
+    ...     'analysis_time_coordinate': 'forecast_reference_time',
+    ...     'lead_time_coordinate': 'step',
+    ...     'valid_time_coordinate': 'time'
+    ... })
+    """
+
+    assert time_coord_names.keys() == {
+        "analysis_time_coordinate",
+        "lead_time_coordinate",
+        "valid_time_coordinate",
+    }, "time_coord_names must contain exactly keys 'analysis_time_coordinate', 'lead_time_coordinate', and 'valid_time_coordinate'"
+
+    analysis_time_coordinate = time_coord_names["analysis_time_coordinate"]
+    lead_time_coordinate = time_coord_names["lead_time_coordinate"]
+    valid_time_coordinate = time_coord_names["valid_time_coordinate"]
+
+    valid_time = ds[analysis_time_coordinate] + ds[lead_time_coordinate]
+
+    ds = (
+        ds.assign_coords({valid_time_coordinate: valid_time})
+        .stack(time_index=[analysis_time_coordinate, lead_time_coordinate])
+        .set_index(time_index=valid_time_coordinate)
+        .rename(time_index=valid_time_coordinate)
+        .drop_vars([analysis_time_coordinate, lead_time_coordinate])
+    )
+
+    return ds
+
+
+def patch_rolling_operation(
+    ds: xr.Dataset, vars_operation_config: dict[Literal["dim", "steps", "vars", "operation"], str | int | list[str]]
+) -> xr.Dataset:
+    """Apply a rolling operation to specified variables in the dataset.
+
+    This function calculates a rolling operation over a specified dimension for selected
+    variables. The rolling window requires all periods to be present (min_periods=steps).
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    vars_operation_config : dict
+        Configuration for the rolling operation with the following keys:
+
+        - 'dim' : str
+            The dimension along which to apply the rolling operation (e.g., 'time').
+        - 'steps' : int
+            The number of steps in the rolling window.
+        - 'vars' : list[str]
+            List of variable names to apply the rolling operation to.
+        - 'operation' : str
+            The operation to apply ('sum', 'mean', 'min', 'max', 'std', etc.).
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset with rolling operations applied to the specified variables.
+
+    Examples
+    --------
+    >>> patch_rolling_operation(ds, {
+    ...     'dim': 'time',
+    ...     'steps': 3,
+    ...     'vars': ['precipitation', 'radiation'],
+    ...     'operation': 'sum'
+    ... })
+    """
+
+    assert vars_operation_config.keys() == {
+        "dim",
+        "steps",
+        "vars",
+        "operation",
+    }, "vars_operation_config must contain exactly keys 'dim', 'steps', 'vars', and 'operation'"
+
+    dim = vars_operation_config["dim"]
+    steps = vars_operation_config["steps"]
+    vars = vars_operation_config["vars"]
+    operation = vars_operation_config["operation"]
+
+    for var in vars:
+        rolling = ds[var].rolling(dim={dim: steps}, min_periods=steps)
+        ds[var] = getattr(rolling, operation)()
+
+    return ds
+
+
 PATCHES = {
     "attributes": patch_attributes,
     "coordinates": patch_coordinates,
     "rename": patch_rename,
     "sort_coordinates": patch_sort_coordinate,
+    "analysis_lead_to_valid_time": patch_analysis_lead_to_valid_time,
+    "rolling_operation": patch_rolling_operation,
+    "subset_dataset": patch_subset_dataset,
 }
 
 
@@ -122,7 +287,15 @@ def patch_dataset(ds: xr.Dataset, patch: dict[str, dict[str, Any]]) -> Any:
        The patched dataset.
    """
 
-    ORDER = ["coordinates", "attributes", "rename", "sort_coordinates"]
+    ORDER = [
+        "coordinates",
+        "attributes",
+        "rename",
+        "sort_coordinates",
+        "subset_dataset",
+        "analysis_lead_to_valid_time",
+        "rolling_operation",
+    ]
    for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])):
        if what not in PATCHES:
            raise ValueError(f"Unknown patch type {what!r}")
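Taken together with the extended ORDER in patch_dataset, the three new patch types can be combined in one patch specification: subsetting runs first, then the time reshaping, then the rolling operation. A hypothetical example (the keys come from PATCHES above; the coordinate and variable names are illustrative):

# Hypothetical patch specification exercising the three new patch types.
patch = {
    "subset_dataset": {"member": 0},
    "analysis_lead_to_valid_time": {
        "analysis_time_coordinate": "forecast_reference_time",
        "lead_time_coordinate": "step",
        "valid_time_coordinate": "time",
    },
    "rolling_operation": {
        "dim": "time",
        "steps": 3,
        "vars": ["precipitation"],
        "operation": "sum",
    },
}
# ds = patch_dataset(xr.open_dataset("forecast.nc"), patch)  # hypothetical file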
anemoi/datasets/create/sources/xarray_zarr.py
@@ -11,30 +11,34 @@ from typing import Any
 
 import earthkit.data as ekd
 
-from .legacy import legacy_source
+from . import source_registry
+from .legacy import LegacySource
 from .xarray import load_many
 
 
-@legacy_source(__file__)
-def execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList:
-    """Execute the data loading process.
-
-    Parameters
-    ----------
-    context : Any
-        The context in which the execution occurs.
-    dates : List[str]
-        List of dates for which data is to be loaded.
-    url : str
-        The URL from which data is to be loaded.
-    *args : tuple
-        Additional positional arguments.
-    **kwargs : dict
-        Additional keyword arguments.
-
-    Returns
-    -------
-    ekd.FieldList
-        The loaded data.
-    """
-    return load_many("🇿", context, dates, url, *args, **kwargs)
+@source_registry.register("xarray_zarr")
+class XarrayZarrSource(LegacySource):
+
+    @staticmethod
+    def _execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList:
+        """Execute the data loading process.
+
+        Parameters
+        ----------
+        context : Any
+            The context in which the execution occurs.
+        dates : List[str]
+            List of dates for which data is to be loaded.
+        url : str
+            The URL from which data is to be loaded.
+        *args : tuple
+            Additional positional arguments.
+        **kwargs : dict
+            Additional keyword arguments.
+
+        Returns
+        -------
+        ekd.FieldList
+            The loaded data.
+        """
+        return load_many("🇿", context, dates, url, *args, **kwargs)
anemoi/datasets/create/sources/zenodo.py
@@ -14,54 +14,58 @@ import earthkit.data as ekd
 from earthkit.data.core.fieldlist import MultiFieldList
 from earthkit.data.sources.url import download_and_cache
 
-from .legacy import legacy_source
+from . import source_registry
+from .legacy import LegacySource
 from .patterns import iterate_patterns
 from .xarray import load_one
 
 
-@legacy_source(__file__)
-def execute(context: Any, dates: Any, record_id: str, file_key: str, *args: Any, **kwargs: Any) -> ekd.FieldList:
-    """Executes the download and processing of files from Zenodo.
+@source_registry.register("zenodo")
+class ZenodoSource(LegacySource):
 
-    Parameters
-    ----------
-    context : Any
-        The context in which the function is executed.
-    dates : Any
-        The dates for which the data is required.
-    record_id : str
-        The Zenodo record ID.
-    file_key : str
-        The key to identify the file.
-    *args : Any
-        Additional arguments.
-    **kwargs : Any
-        Additional keyword arguments.
+    @staticmethod
+    def _execute(context: Any, dates: Any, record_id: str, file_key: str, *args: Any, **kwargs: Any) -> ekd.FieldList:
+        """Executes the download and processing of files from Zenodo.
 
-    Returns
-    -------
-    MultiFieldList
-        A list of fields loaded from the downloaded files.
-    """
-    import requests
+        Parameters
+        ----------
+        context : Any
+            The context in which the function is executed.
+        dates : Any
+            The dates for which the data is required.
+        record_id : str
+            The Zenodo record ID.
+        file_key : str
+            The key to identify the file.
+        *args : Any
+            Additional arguments.
+        **kwargs : Any
+            Additional keyword arguments.
 
-    result: list[Any] = []
+        Returns
+        -------
+        MultiFieldList
+            A list of fields loaded from the downloaded files.
+        """
+        import requests
 
-    URLPATTERN = "https://zenodo.org/api/records/{record_id}"
-    url = URLPATTERN.format(record_id=record_id)
-    r = requests.get(url)
-    r.raise_for_status()
-    record: dict[str, Any] = r.json()
+        result: list[Any] = []
 
-    urls: dict[str, str] = {}
-    for file in record["files"]:
-        urls[file["key"]] = file["links"]["self"]
+        URLPATTERN = "https://zenodo.org/api/records/{record_id}"
+        url = URLPATTERN.format(record_id=record_id)
+        r = requests.get(url)
+        r.raise_for_status()
+        record: dict[str, Any] = r.json()
 
-    for url, dates in iterate_patterns(file_key, dates, **kwargs):
-        if url not in urls:
-            continue
+        urls: dict[str, str] = {}
+        for file in record["files"]:
+            urls[file["key"]] = file["links"]["self"]
 
-        path = download_and_cache(urls[url])
-        result.append(load_one("?", context, dates, path, options={}, flavour=None, **kwargs))
+        for url, dates in iterate_patterns(file_key, dates, **kwargs):
+            if url not in urls:
+                continue
 
-    return MultiFieldList(result)
+            path = download_and_cache(urls[url])
+            result.append(load_one("?", context, dates, path, options={}, flavour=None, **kwargs))
+
+        return MultiFieldList(result)
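The Zenodo lookup itself is a plain records-API call: fetch the record's metadata, then map each file's key to its download URL, exactly as ZenodoSource._execute does above. A standalone sketch (the record ID is hypothetical):

# Standalone sketch of the Zenodo file-URL lookup used by ZenodoSource.
import requests

def zenodo_file_urls(record_id: str) -> dict[str, str]:
    r = requests.get(f"https://zenodo.org/api/records/{record_id}")
    r.raise_for_status()
    record = r.json()
    # Each entry in record["files"] carries a "key" (file name) and a
    # "links"."self" download URL.
    return {f["key"]: f["links"]["self"] for f in record["files"]}

if __name__ == "__main__":
    print(zenodo_file_urls("1234567"))  # hypothetical record ID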