anemoi-datasets 0.5.26__py3-none-any.whl → 0.5.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. anemoi/datasets/__init__.py +1 -2
  2. anemoi/datasets/_version.py +16 -3
  3. anemoi/datasets/commands/check.py +1 -1
  4. anemoi/datasets/commands/copy.py +1 -2
  5. anemoi/datasets/commands/create.py +1 -1
  6. anemoi/datasets/commands/inspect.py +27 -35
  7. anemoi/datasets/commands/recipe/__init__.py +93 -0
  8. anemoi/datasets/commands/recipe/format.py +55 -0
  9. anemoi/datasets/commands/recipe/migrate.py +555 -0
  10. anemoi/datasets/commands/validate.py +59 -0
  11. anemoi/datasets/compute/recentre.py +3 -6
  12. anemoi/datasets/create/__init__.py +64 -26
  13. anemoi/datasets/create/check.py +10 -12
  14. anemoi/datasets/create/chunks.py +1 -2
  15. anemoi/datasets/create/config.py +5 -6
  16. anemoi/datasets/create/input/__init__.py +44 -65
  17. anemoi/datasets/create/input/action.py +296 -238
  18. anemoi/datasets/create/input/context/__init__.py +71 -0
  19. anemoi/datasets/create/input/context/field.py +54 -0
  20. anemoi/datasets/create/input/data_sources.py +7 -9
  21. anemoi/datasets/create/input/misc.py +2 -75
  22. anemoi/datasets/create/input/repeated_dates.py +11 -130
  23. anemoi/datasets/{utils → create/input/result}/__init__.py +10 -1
  24. anemoi/datasets/create/input/{result.py → result/field.py} +36 -120
  25. anemoi/datasets/create/input/trace.py +1 -1
  26. anemoi/datasets/create/patch.py +1 -2
  27. anemoi/datasets/create/persistent.py +3 -5
  28. anemoi/datasets/create/size.py +1 -3
  29. anemoi/datasets/create/sources/accumulations.py +120 -145
  30. anemoi/datasets/create/sources/accumulations2.py +20 -53
  31. anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
  32. anemoi/datasets/create/sources/constants.py +39 -40
  33. anemoi/datasets/create/sources/empty.py +22 -19
  34. anemoi/datasets/create/sources/fdb.py +133 -0
  35. anemoi/datasets/create/sources/forcings.py +29 -29
  36. anemoi/datasets/create/sources/grib.py +94 -78
  37. anemoi/datasets/create/sources/grib_index.py +57 -55
  38. anemoi/datasets/create/sources/hindcasts.py +57 -59
  39. anemoi/datasets/create/sources/legacy.py +10 -62
  40. anemoi/datasets/create/sources/mars.py +121 -149
  41. anemoi/datasets/create/sources/netcdf.py +28 -25
  42. anemoi/datasets/create/sources/opendap.py +28 -26
  43. anemoi/datasets/create/sources/patterns.py +4 -6
  44. anemoi/datasets/create/sources/recentre.py +46 -48
  45. anemoi/datasets/create/sources/repeated_dates.py +44 -0
  46. anemoi/datasets/create/sources/source.py +26 -51
  47. anemoi/datasets/create/sources/tendencies.py +68 -98
  48. anemoi/datasets/create/sources/xarray.py +4 -6
  49. anemoi/datasets/create/sources/xarray_support/__init__.py +40 -36
  50. anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -12
  51. anemoi/datasets/create/sources/xarray_support/field.py +20 -16
  52. anemoi/datasets/create/sources/xarray_support/fieldlist.py +11 -15
  53. anemoi/datasets/create/sources/xarray_support/flavour.py +42 -42
  54. anemoi/datasets/create/sources/xarray_support/grid.py +15 -9
  55. anemoi/datasets/create/sources/xarray_support/metadata.py +19 -128
  56. anemoi/datasets/create/sources/xarray_support/patch.py +4 -6
  57. anemoi/datasets/create/sources/xarray_support/time.py +10 -13
  58. anemoi/datasets/create/sources/xarray_support/variable.py +21 -21
  59. anemoi/datasets/create/sources/xarray_zarr.py +28 -25
  60. anemoi/datasets/create/sources/zenodo.py +43 -41
  61. anemoi/datasets/create/statistics/__init__.py +3 -6
  62. anemoi/datasets/create/testing.py +4 -0
  63. anemoi/datasets/create/typing.py +1 -2
  64. anemoi/datasets/create/utils.py +0 -43
  65. anemoi/datasets/create/zarr.py +7 -2
  66. anemoi/datasets/data/__init__.py +15 -6
  67. anemoi/datasets/data/complement.py +7 -12
  68. anemoi/datasets/data/concat.py +5 -8
  69. anemoi/datasets/data/dataset.py +48 -47
  70. anemoi/datasets/data/debug.py +7 -9
  71. anemoi/datasets/data/ensemble.py +4 -6
  72. anemoi/datasets/data/fill_missing.py +7 -10
  73. anemoi/datasets/data/forwards.py +22 -26
  74. anemoi/datasets/data/grids.py +12 -168
  75. anemoi/datasets/data/indexing.py +9 -12
  76. anemoi/datasets/data/interpolate.py +7 -15
  77. anemoi/datasets/data/join.py +8 -12
  78. anemoi/datasets/data/masked.py +6 -11
  79. anemoi/datasets/data/merge.py +5 -9
  80. anemoi/datasets/data/misc.py +41 -45
  81. anemoi/datasets/data/missing.py +11 -16
  82. anemoi/datasets/data/observations/__init__.py +8 -14
  83. anemoi/datasets/data/padded.py +3 -5
  84. anemoi/datasets/data/records/backends/__init__.py +2 -2
  85. anemoi/datasets/data/rescale.py +5 -12
  86. anemoi/datasets/data/rolling_average.py +141 -0
  87. anemoi/datasets/data/select.py +13 -16
  88. anemoi/datasets/data/statistics.py +4 -7
  89. anemoi/datasets/data/stores.py +22 -29
  90. anemoi/datasets/data/subset.py +8 -11
  91. anemoi/datasets/data/unchecked.py +7 -11
  92. anemoi/datasets/data/xy.py +25 -21
  93. anemoi/datasets/dates/__init__.py +15 -18
  94. anemoi/datasets/dates/groups.py +7 -10
  95. anemoi/datasets/dumper.py +76 -0
  96. anemoi/datasets/grids.py +4 -185
  97. anemoi/datasets/schemas/recipe.json +131 -0
  98. anemoi/datasets/testing.py +93 -7
  99. anemoi/datasets/validate.py +598 -0
  100. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/METADATA +7 -4
  101. anemoi_datasets-0.5.28.dist-info/RECORD +134 -0
  102. anemoi/datasets/create/filter.py +0 -48
  103. anemoi/datasets/create/input/concat.py +0 -164
  104. anemoi/datasets/create/input/context.py +0 -89
  105. anemoi/datasets/create/input/empty.py +0 -54
  106. anemoi/datasets/create/input/filter.py +0 -118
  107. anemoi/datasets/create/input/function.py +0 -233
  108. anemoi/datasets/create/input/join.py +0 -130
  109. anemoi/datasets/create/input/pipe.py +0 -66
  110. anemoi/datasets/create/input/step.py +0 -177
  111. anemoi/datasets/create/input/template.py +0 -162
  112. anemoi_datasets-0.5.26.dist-info/RECORD +0 -131
  113. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/WHEEL +0 -0
  114. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/entry_points.txt +0 -0
  115. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/licenses/LICENSE +0 -0
  116. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/__init__.py

@@ -16,8 +16,6 @@ import uuid
  import warnings
  from functools import cached_property
  from typing import Any
- from typing import Optional
- from typing import Union

  import cftime
  import numpy as np
@@ -102,8 +100,8 @@ def json_tidy(o: Any) -> Any:

  def build_statistics_dates(
  dates: list[datetime.datetime],
- start: Optional[datetime.datetime],
- end: Optional[datetime.datetime],
+ start: datetime.datetime | None,
+ end: datetime.datetime | None,
  ) -> tuple[str, str]:
  """Compute the start and end dates for the statistics.

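Most of the one-line changes in this file (and in the files that follow) apply the same modernisation: `typing.Optional[X]` and `typing.Union[X, Y]` become PEP 604 unions (`X | None`, `X | Y`), so the `typing` imports can be dropped. The union syntax is evaluated when the function is defined, so it needs Python 3.10+ (or `from __future__ import annotations` on older interpreters). A minimal before/after sketch with made-up names, not code from the package:

    # Before: pre-PEP 604 spellings, extra typing imports required
    from typing import Optional, Union

    def make_label(name: Optional[str] = None, sep: Union[str, int] = "-") -> Optional[str]:
        return None if name is None else f"{name}{sep}"

    # After: built-in union syntax, no typing imports needed (Python >= 3.10)
    def make_label_new(name: str | None = None, sep: str | int = "-") -> str | None:
        return None if name is None else f"{name}{sep}"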
@@ -359,7 +357,7 @@ class Actor: # TODO: rename to Creator

  dataset_class = WritableDataset

- def __init__(self, path: str, cache: Optional[str] = None):
+ def __init__(self, path: str, cache: str | None = None):
  """Initialize an Actor instance.

  Parameters
@@ -577,10 +575,10 @@ class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
  check_name: bool = False,
  overwrite: bool = False,
  use_threads: bool = False,
- statistics_temp_dir: Optional[str] = None,
+ statistics_temp_dir: str | None = None,
  progress: Any = None,
  test: bool = False,
- cache: Optional[str] = None,
+ cache: str | None = None,
  **kwargs: Any,
  ):
  """Initialize an Init instance.
@@ -809,11 +807,11 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
  def __init__(
  self,
  path: str,
- parts: Optional[str] = None,
+ parts: str | None = None,
  use_threads: bool = False,
- statistics_temp_dir: Optional[str] = None,
+ statistics_temp_dir: str | None = None,
  progress: Any = None,
- cache: Optional[str] = None,
+ cache: str | None = None,
  **kwargs: Any,
  ):
  """Initialize a Load instance.
@@ -867,7 +865,7 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
  # assert isinstance(group[0], datetime.datetime), type(group[0])
  LOG.debug(f"Building data for group {igroup}/{self.n_groups}")

- result = self.input.select(group_of_dates=group)
+ result = self.input.select(argument=group)
  assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group)

  # There are several groups.
@@ -907,8 +905,8 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
  print("Requested dates", compress_dates(dates))
  print("Cube dates", compress_dates(dates_in_data))

- a = set(as_datetime(_) for _ in dates)
- b = set(as_datetime(_) for _ in dates_in_data)
+ a = {as_datetime(_) for _ in dates}
+ b = {as_datetime(_) for _ in dates_in_data}

  print("Missing dates", compress_dates(a - b))
  print("Extra dates", compress_dates(b - a))
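The date bookkeeping above also swaps `set(generator)` for set-comprehension syntax; the two build identical sets, and the subsequent set differences are what report missing and extra dates. A tiny self-contained illustration with hypothetical dates (not the dataset's):

    import datetime

    requested = {datetime.datetime(2024, 1, 1, h) for h in (0, 6, 12, 18)}
    in_data = {datetime.datetime(2024, 1, 1, h) for h in (0, 6, 12)}

    print("Missing dates", sorted(requested - in_data))  # requested but absent from the cube
    print("Extra dates", sorted(in_data - requested))    # present in the cube but not requested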
@@ -958,7 +956,7 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
  array.flush()
  LOG.info("Flushed data array")

- def _get_allow_nans(self) -> Union[bool, list]:
+ def _get_allow_nans(self) -> bool | list:
  """Get the allow_nans configuration.

  Returns
@@ -991,7 +989,7 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
  total = cube.count(reading_chunks)
  LOG.debug(f"Loading datacube: {cube}")

- def position(x: Any) -> Optional[int]:
+ def position(x: Any) -> int | None:
  if isinstance(x, str) and "/" in x:
  x = x.split("/")
  return int(x[0])
@@ -1038,7 +1036,7 @@ class Cleanup(Actor, HasRegistryMixin, HasStatisticTempMixin):
  def __init__(
  self,
  path: str,
- statistics_temp_dir: Optional[str] = None,
+ statistics_temp_dir: str | None = None,
  delta: list = [],
  use_threads: bool = False,
  **kwargs: Any,
@@ -1217,19 +1215,19 @@ class _InitAdditions(Actor, HasRegistryMixin, AdditionsMixin):
  LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}")


- class _RunAdditions(Actor, HasRegistryMixin, AdditionsMixin):
+ class _LoadAdditions(Actor, HasRegistryMixin, AdditionsMixin):
  """A class to run dataset additions."""

  def __init__(
  self,
  path: str,
  delta: str,
- parts: Optional[str] = None,
+ parts: str | None = None,
  use_threads: bool = False,
  progress: Any = None,
  **kwargs: Any,
  ):
- """Initialize a _RunAdditions instance.
+ """Initialize a _LoadAdditions instance.

  Parameters
  ----------
@@ -1469,7 +1467,7 @@ def multi_addition(cls: type) -> type:


  InitAdditions = multi_addition(_InitAdditions)
- RunAdditions = multi_addition(_RunAdditions)
+ LoadAdditions = multi_addition(_LoadAdditions)
  FinaliseAdditions = multi_addition(_FinaliseAdditions)


@@ -1480,7 +1478,7 @@ class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin):
  self,
  path: str,
  use_threads: bool = False,
- statistics_temp_dir: Optional[str] = None,
+ statistics_temp_dir: str | None = None,
  progress: Any = None,
  **kwargs: Any,
  ):
@@ -1539,7 +1537,7 @@ class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin):
  LOG.info(f"Wrote statistics in {self.path}")

  @cached_property
- def allow_nans(self) -> Union[bool, list]:
+ def allow_nans(self) -> bool | list:
  """Check if NaNs are allowed."""
  import zarr

@@ -1581,7 +1579,7 @@ def chain(tasks: list) -> type:
  return Chain


- def creator_factory(name: str, trace: Optional[str] = None, **kwargs: Any) -> Any:
+ def creator_factory(name: str, trace: str | None = None, **kwargs: Any) -> Any:
  """Create a dataset creator.

  Parameters
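`chain(tasks)` returns a composite task class (`return Chain` above), and the next hunk wires it into `creator_factory` alongside `multi_addition`. The actual `Chain` implementation is not part of this diff; the sketch below is only a generic illustration of the pattern, assuming each task class exposes a `run()` method:

    # Illustrative only: compose several task classes into one class that runs them in order.
    def chain_sketch(tasks: list) -> type:
        class Chain:
            def __init__(self, **kwargs):
                # Every step receives the same keyword arguments (e.g. path, use_threads).
                self.steps = [task(**kwargs) for task in tasks]

            def run(self) -> None:
                for step in self.steps:
                    step.run()

        return Chain

    class Hello:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

        def run(self) -> None:
            print("hello", self.kwargs)

    class Goodbye(Hello):
        def run(self) -> None:
            print("goodbye", self.kwargs)

    chain_sketch([Hello, Goodbye])(path="dataset.zarr").run()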
@@ -1612,10 +1610,50 @@ def creator_factory(name: str, trace: Optional[str] = None, **kwargs: Any) -> An
  cleanup=Cleanup,
  verify=Verify,
  init_additions=InitAdditions,
- load_additions=RunAdditions,
- run_additions=RunAdditions,
+ load_additions=LoadAdditions,
  finalise_additions=chain([FinaliseAdditions, Size]),
- additions=chain([InitAdditions, RunAdditions, FinaliseAdditions, Size, Cleanup]),
+ additions=chain([InitAdditions, LoadAdditions, FinaliseAdditions, Size, Cleanup]),
  )[name]
  LOG.debug(f"Creating {cls.__name__} with {kwargs}")
  return cls(**kwargs)
+
+
+ def validate_config(config: Any) -> None:
+
+ import json
+
+ import jsonschema
+
+ def _tidy(d):
+ if isinstance(d, dict):
+ return {k: _tidy(v) for k, v in d.items()}
+
+ if isinstance(d, list):
+ return [_tidy(v) for v in d if v is not None]
+
+ # jsonschema does not support datetime.date
+ if isinstance(d, datetime.datetime):
+ return d.isoformat()
+
+ if isinstance(d, datetime.date):
+ return d.isoformat()
+
+ return d
+
+ # https://json-schema.org
+
+ with open(
+ os.path.join(
+ os.path.dirname(os.path.dirname(__file__)),
+ "schemas",
+ "recipe.json",
+ )
+ ) as f:
+ schema = json.load(f)
+
+ try:
+ jsonschema.validate(instance=_tidy(config), schema=schema)
+ except jsonschema.exceptions.ValidationError as e:
+ LOG.error("❌ Config validation failed (jsonschema):")
+ LOG.error(e.message)
+ raise
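The new `validate_config` helper converts dates to ISO strings (`_tidy`) and validates the recipe against `anemoi/datasets/schemas/recipe.json` with `jsonschema.validate`. The real schema's contents are not shown in this diff, so the following is a hedged, self-contained illustration of the same pattern using a toy schema and a placeholder recipe:

    import datetime

    import jsonschema

    # Toy schema standing in for schemas/recipe.json (whose contents are not in this diff).
    schema = {
        "type": "object",
        "properties": {"input": {"type": "object"}, "dates": {"type": "object"}},
        "required": ["input"],
    }

    recipe = {"input": {"example-source": {}}, "dates": {"start": datetime.date(2020, 1, 1)}}

    # Mirror _tidy: jsonschema cannot handle datetime objects, so serialise them first.
    recipe["dates"] = {k: v.isoformat() if isinstance(v, datetime.date) else v for k, v in recipe["dates"].items()}

    try:
        jsonschema.validate(instance=recipe, schema=schema)
        print("recipe is valid")
    except jsonschema.exceptions.ValidationError as e:
        print("validation failed:", e.message)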
anemoi/datasets/create/check.py

@@ -12,10 +12,8 @@ import datetime
  import logging
  import re
  import warnings
+ from collections.abc import Callable
  from typing import Any
- from typing import Callable
- from typing import Optional
- from typing import Union

  import numpy as np
  from anemoi.utils.config import load_config
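The import block above also moves `Callable` from `typing` to `collections.abc`; since Python 3.9 the `typing` aliases of the ABCs are deprecated in favour of the `collections.abc` originals, which can be used directly in annotations. A minimal sketch with an illustrative function (not from check.py):

    from collections.abc import Callable

    def run_check(check: Callable[[str], bool], value: str) -> bool:
        # Any callable taking a str and returning a bool fits this annotation.
        return check(value)

    print(run_check(str.isalnum, "era5"))      # True
    print(run_check(str.isalnum, "era5-o96"))  # False: '-' is not alphanumeric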
@@ -31,10 +29,10 @@ class DatasetName:
  def __init__(
  self,
  name: str,
- resolution: Optional[str] = None,
- start_date: Optional[datetime.date] = None,
- end_date: Optional[datetime.date] = None,
- frequency: Optional[datetime.timedelta] = None,
+ resolution: str | None = None,
+ start_date: datetime.date | None = None,
+ end_date: datetime.date | None = None,
+ frequency: datetime.timedelta | None = None,
  ):
  """Initialize a DatasetName instance.

@@ -146,7 +144,7 @@ class DatasetName:
  "https://anemoi-registry.readthedocs.io/en/latest/naming-conventions.html"
  )

- def check_resolution(self, resolution: Optional[str]) -> None:
+ def check_resolution(self, resolution: str | None) -> None:
  """Check if the resolution matches the expected format.

  Parameters
@@ -175,7 +173,7 @@
  if not c.isalnum() and c not in "-":
  self.messages.append(f"the {self.name} should only contain alphanumeric characters and '-'.")

- def check_frequency(self, frequency: Optional[datetime.timedelta]) -> None:
+ def check_frequency(self, frequency: datetime.timedelta | None) -> None:
  """Check if the frequency matches the expected format.

  Parameters
@@ -189,7 +187,7 @@
  self._check_missing("frequency", frequency_str)
  self._check_mismatch("frequency", frequency_str)

- def check_start_date(self, start_date: Optional[datetime.date]) -> None:
+ def check_start_date(self, start_date: datetime.date | None) -> None:
  """Check if the start date matches the expected format.

  Parameters
@@ -203,7 +201,7 @@
  self._check_missing("start_date", start_date_str)
  self._check_mismatch("start_date", start_date_str)

- def check_end_date(self, end_date: Optional[datetime.date]) -> None:
+ def check_end_date(self, end_date: datetime.date | None) -> None:
  """Check if the end date matches the expected format.

  Parameters
@@ -251,7 +249,7 @@ class StatisticsValueError(ValueError):


  def check_data_values(
- arr: NDArray[Any], *, name: str, log: list = [], allow_nans: Union[bool, list, set, tuple, dict] = False
+ arr: NDArray[Any], *, name: str, log: list = [], allow_nans: bool | list | set | tuple | dict = False
  ) -> None:
  """Check the values in the data array for validity.

anemoi/datasets/create/chunks.py

@@ -9,7 +9,6 @@

  import logging
  import warnings
- from typing import Union

  LOG = logging.getLogger(__name__)

@@ -27,7 +26,7 @@ class ChunkFilter:
  The chunks that are allowed to be processed.
  """

- def __init__(self, *, parts: Union[str, list], total: int):
+ def __init__(self, *, parts: str | list, total: int):
  """Initializes the ChunkFilter with the given parts and total number of chunks.

  Parameters
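`ChunkFilter` itself only changes its annotation here (`parts: str | list`). Elsewhere in this diff, `position()` splits strings such as "2/10" on "/", which suggests the usual 'part i of n' convention for `parts`; the helper below is my own sketch of that convention, not the `ChunkFilter` implementation:

    # Hypothetical helper illustrating an "i/n" parts convention; not ChunkFilter itself.
    def select_chunks(parts: str | list, total: int) -> list[int]:
        if isinstance(parts, list):
            wanted = set(parts)
            return [chunk for chunk in range(total) if chunk in wanted]
        part, n_parts = (int(p) for p in parts.split("/"))
        # Distribute `total` chunks round-robin over `n_parts` parts (1-based part index).
        return [chunk for chunk in range(total) if chunk % n_parts == part - 1]

    print(select_chunks("2/3", total=10))   # -> [1, 4, 7]
    print(select_chunks([0, 5], total=10))  # -> [0, 5]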
anemoi/datasets/create/config.py

@@ -12,8 +12,6 @@ import logging
  import os
  from copy import deepcopy
  from typing import Any
- from typing import Optional
- from typing import Union

  import yaml
  from anemoi.utils.config import DotDict
@@ -25,7 +23,7 @@ from anemoi.datasets.dates.groups import Groups
  LOG = logging.getLogger(__name__)


- def _get_first_key_if_dict(x: Union[str, dict]) -> str:
+ def _get_first_key_if_dict(x: str | dict) -> str:
  """Returns the first key if the input is a dictionary, otherwise returns the input string.

  Parameters
@@ -97,7 +95,7 @@ def check_dict_value_and_set(dic: dict, key: str, value: Any) -> None:
  dic[key] = value


- def resolve_includes(config: Union[dict, list]) -> Union[dict, list]:
+ def resolve_includes(config: dict | list) -> dict | list:
  """Resolves '<<' includes in a configuration dictionary or list.

  Parameters
@@ -123,7 +121,7 @@ def resolve_includes(config: Union[dict, list]) -> Union[dict, list]:
  class Config(DotDict):
  """Configuration class that extends DotDict to handle configuration loading and processing."""

- def __init__(self, config: Optional[Union[str, dict]] = None, **kwargs):
+ def __init__(self, config: str | dict | None = None, **kwargs):
  """Initializes the Config object.

  Parameters
@@ -134,7 +132,6 @@ class Config(DotDict):
  Additional keyword arguments to update the configuration.
  """
  if isinstance(config, str):
- self.config_path = os.path.realpath(config)
  config = load_any_dict_format(config)
  else:
  config = deepcopy(config if config is not None else {})
@@ -282,6 +279,8 @@ class LoadersConfig(Config):

  self.output.order_by = normalize_order_by(self.output.order_by)

+ self.setdefault("dates", Config())
+
  self.dates["group_by"] = self.build.group_by

  ###########
anemoi/datasets/create/input/__init__.py

@@ -1,4 +1,4 @@
- # (C) Copyright 2024 Anemoi contributors.
+ # (C) Copyright 2024-2025 Anemoi contributors.
  #
  # This software is licensed under the terms of the Apache Licence Version 2.0
  # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
@@ -7,100 +7,79 @@
  # granted to it by virtue of its status as an intergovernmental organisation
  # nor does it submit to any jurisdiction.

- import logging
  from copy import deepcopy
+ from functools import cached_property
+ from typing import TYPE_CHECKING
  from typing import Any
- from typing import Union

- from anemoi.datasets.dates.groups import GroupOfDates
+ from anemoi.datasets.create.input.context.field import FieldContext

- from .trace import trace_select
-
- LOG = logging.getLogger(__name__)
-
-
- class Context:
- """Context for building input data."""
-
- pass
+ if TYPE_CHECKING:
+ from anemoi.datasets.create.input.action import Recipe


  class InputBuilder:
  """Builder class for creating input data from configuration and data sources."""

- def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) -> None:
+ def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> None:
  """Initialize the InputBuilder.

  Parameters
  ----------
  config : dict
  Configuration dictionary.
- data_sources : Union[dict, list]
+ data_sources : dict
  Data sources.
  **kwargs : Any
  Additional keyword arguments.
  """
  self.kwargs = kwargs
+ self.config = deepcopy(config)
+ self.data_sources = deepcopy(dict(data_sources=data_sources))

- config = deepcopy(config)
- if data_sources:
- config = dict(
- data_sources=dict(
- sources=data_sources,
- input=config,
- )
- )
- self.config = config
- self.action_path = ["input"]
-
- @trace_select
- def select(self, group_of_dates: GroupOfDates) -> Any:
- """Select data based on the group of dates.
-
- Parameters
- ----------
- group_of_dates : GroupOfDates
- Group of dates to select data for.
-
- Returns
- -------
- Any
- Selected data.
- """
- from .action import ActionContext
+ @cached_property
+ def action(self) -> "Recipe":
+ """Returns the action object based on the configuration."""
+ from .action import Recipe
  from .action import action_factory

- """This changes the context."""
- context = ActionContext(**self.kwargs)
- action = action_factory(self.config, context, self.action_path)
- return action.select(group_of_dates)
-
- def __repr__(self) -> str:
- """Return a string representation of the InputBuilder.
-
- Returns
- -------
- str
- String representation.
- """
- from .action import ActionContext
- from .action import action_factory
+ sources = action_factory(self.data_sources, "data_sources")
+ input = action_factory(self.config, "input")

- context = ActionContext(**self.kwargs)
- a = action_factory(self.config, context, self.action_path)
- return repr(a)
+ return Recipe(input, sources)

- def _trace_select(self, group_of_dates: GroupOfDates) -> str:
- """Trace the select operation.
+ def select(self, argument) -> Any:
+ """Select data based on the group of dates.

  Parameters
  ----------
- group_of_dates : GroupOfDates
+ argument : GroupOfDates
  Group of dates to select data for.

  Returns
  -------
- str
- Trace string.
+ Any
+ Selected data.
  """
- return f"InputBuilder({group_of_dates})"
+ context = FieldContext(argument, **self.kwargs)
+ return context.create_result(self.action(context, argument))
+
+
+ def build_input(config: dict, data_sources: dict | list, **kwargs: Any) -> InputBuilder:
+ """Build an InputBuilder instance.
+
+ Parameters
+ ----------
+ config : dict
+ Configuration dictionary.
+ data_sources : Union[dict, list]
+ Data sources.
+ **kwargs : Any
+ Additional keyword arguments.
+
+ Returns
+ -------
+ InputBuilder
+ An instance of InputBuilder.
+ """
+ return InputBuilder(config, data_sources, **kwargs)
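The rewritten module drops the old `ActionContext`/`trace_select` plumbing: the parsed recipe is now a cached `Recipe` action, `select()` builds a fresh `FieldContext` per call, and a `build_input()` helper is exposed. Based only on the signatures shown above, a hedged sketch of how the new entry point might be driven by the loader (the `group_of_dates` value is a placeholder supplied by the caller):

    from typing import Any

    from anemoi.datasets.create.input import build_input


    def load_one_group(input_config: dict, data_sources: dict | list, group_of_dates: Any) -> Any:
        """Sketch only: resolve a single group of dates through the new InputBuilder."""
        builder = build_input(input_config, data_sources)
        # Load.build() above now calls self.input.select(argument=group).
        return builder.select(argument=group_of_dates)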