anemoi-datasets 0.5.16__py3-none-any.whl → 0.5.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. anemoi/datasets/__init__.py +4 -1
  2. anemoi/datasets/__main__.py +12 -2
  3. anemoi/datasets/_version.py +9 -4
  4. anemoi/datasets/commands/cleanup.py +17 -2
  5. anemoi/datasets/commands/compare.py +18 -2
  6. anemoi/datasets/commands/copy.py +196 -14
  7. anemoi/datasets/commands/create.py +50 -7
  8. anemoi/datasets/commands/finalise-additions.py +17 -2
  9. anemoi/datasets/commands/finalise.py +17 -2
  10. anemoi/datasets/commands/init-additions.py +17 -2
  11. anemoi/datasets/commands/init.py +16 -2
  12. anemoi/datasets/commands/inspect.py +283 -62
  13. anemoi/datasets/commands/load-additions.py +16 -2
  14. anemoi/datasets/commands/load.py +16 -2
  15. anemoi/datasets/commands/patch.py +17 -2
  16. anemoi/datasets/commands/publish.py +17 -2
  17. anemoi/datasets/commands/scan.py +31 -3
  18. anemoi/datasets/compute/recentre.py +47 -11
  19. anemoi/datasets/create/__init__.py +612 -85
  20. anemoi/datasets/create/check.py +142 -20
  21. anemoi/datasets/create/chunks.py +64 -4
  22. anemoi/datasets/create/config.py +185 -21
  23. anemoi/datasets/create/filter.py +50 -0
  24. anemoi/datasets/create/filters/__init__.py +33 -0
  25. anemoi/datasets/create/filters/empty.py +37 -0
  26. anemoi/datasets/create/filters/legacy.py +93 -0
  27. anemoi/datasets/create/filters/noop.py +37 -0
  28. anemoi/datasets/create/filters/orog_to_z.py +58 -0
  29. anemoi/datasets/create/{functions/filters → filters}/pressure_level_relative_humidity_to_specific_humidity.py +33 -10
  30. anemoi/datasets/create/{functions/filters → filters}/pressure_level_specific_humidity_to_relative_humidity.py +32 -8
  31. anemoi/datasets/create/filters/rename.py +205 -0
  32. anemoi/datasets/create/{functions/filters → filters}/rotate_winds.py +43 -28
  33. anemoi/datasets/create/{functions/filters → filters}/single_level_dewpoint_to_relative_humidity.py +32 -9
  34. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_dewpoint.py +33 -9
  35. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_specific_humidity.py +55 -7
  36. anemoi/datasets/create/{functions/filters → filters}/single_level_specific_humidity_to_relative_humidity.py +98 -37
  37. anemoi/datasets/create/filters/speeddir_to_uv.py +95 -0
  38. anemoi/datasets/create/{functions/filters → filters}/sum.py +24 -27
  39. anemoi/datasets/create/filters/transform.py +53 -0
  40. anemoi/datasets/create/{functions/filters → filters}/unrotate_winds.py +27 -18
  41. anemoi/datasets/create/filters/uv_to_speeddir.py +94 -0
  42. anemoi/datasets/create/{functions/filters → filters}/wz_to_w.py +51 -33
  43. anemoi/datasets/create/input/__init__.py +76 -5
  44. anemoi/datasets/create/input/action.py +149 -13
  45. anemoi/datasets/create/input/concat.py +81 -10
  46. anemoi/datasets/create/input/context.py +39 -4
  47. anemoi/datasets/create/input/data_sources.py +72 -6
  48. anemoi/datasets/create/input/empty.py +21 -3
  49. anemoi/datasets/create/input/filter.py +60 -12
  50. anemoi/datasets/create/input/function.py +154 -37
  51. anemoi/datasets/create/input/join.py +86 -14
  52. anemoi/datasets/create/input/misc.py +67 -17
  53. anemoi/datasets/create/input/pipe.py +33 -6
  54. anemoi/datasets/create/input/repeated_dates.py +189 -41
  55. anemoi/datasets/create/input/result.py +202 -87
  56. anemoi/datasets/create/input/step.py +119 -22
  57. anemoi/datasets/create/input/template.py +100 -13
  58. anemoi/datasets/create/input/trace.py +62 -7
  59. anemoi/datasets/create/patch.py +52 -4
  60. anemoi/datasets/create/persistent.py +134 -17
  61. anemoi/datasets/create/size.py +15 -1
  62. anemoi/datasets/create/source.py +51 -0
  63. anemoi/datasets/create/sources/__init__.py +36 -0
  64. anemoi/datasets/create/{functions/sources → sources}/accumulations.py +296 -30
  65. anemoi/datasets/create/{functions/sources → sources}/constants.py +27 -2
  66. anemoi/datasets/create/{functions/sources → sources}/eccc_fstd.py +7 -3
  67. anemoi/datasets/create/sources/empty.py +37 -0
  68. anemoi/datasets/create/{functions/sources → sources}/forcings.py +25 -1
  69. anemoi/datasets/create/sources/grib.py +297 -0
  70. anemoi/datasets/create/{functions/sources → sources}/hindcasts.py +38 -4
  71. anemoi/datasets/create/sources/legacy.py +93 -0
  72. anemoi/datasets/create/{functions/sources → sources}/mars.py +168 -20
  73. anemoi/datasets/create/sources/netcdf.py +42 -0
  74. anemoi/datasets/create/sources/opendap.py +43 -0
  75. anemoi/datasets/create/{functions/sources/__init__.py → sources/patterns.py} +35 -4
  76. anemoi/datasets/create/sources/recentre.py +150 -0
  77. anemoi/datasets/create/{functions/sources → sources}/source.py +27 -5
  78. anemoi/datasets/create/{functions/sources → sources}/tendencies.py +64 -7
  79. anemoi/datasets/create/sources/xarray.py +92 -0
  80. anemoi/datasets/create/sources/xarray_kerchunk.py +36 -0
  81. anemoi/datasets/create/sources/xarray_support/README.md +1 -0
  82. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/__init__.py +109 -8
  83. anemoi/datasets/create/sources/xarray_support/coordinates.py +442 -0
  84. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/field.py +94 -16
  85. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/fieldlist.py +90 -25
  86. anemoi/datasets/create/sources/xarray_support/flavour.py +1036 -0
  87. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/grid.py +92 -31
  88. anemoi/datasets/create/sources/xarray_support/metadata.py +395 -0
  89. anemoi/datasets/create/sources/xarray_support/patch.py +91 -0
  90. anemoi/datasets/create/sources/xarray_support/time.py +391 -0
  91. anemoi/datasets/create/sources/xarray_support/variable.py +331 -0
  92. anemoi/datasets/create/sources/xarray_zarr.py +41 -0
  93. anemoi/datasets/create/{functions/sources → sources}/zenodo.py +34 -5
  94. anemoi/datasets/create/statistics/__init__.py +233 -44
  95. anemoi/datasets/create/statistics/summary.py +52 -6
  96. anemoi/datasets/create/testing.py +76 -0
  97. anemoi/datasets/create/{functions/filters/noop.py → typing.py} +6 -3
  98. anemoi/datasets/create/utils.py +97 -6
  99. anemoi/datasets/create/writer.py +26 -4
  100. anemoi/datasets/create/zarr.py +170 -23
  101. anemoi/datasets/data/__init__.py +51 -4
  102. anemoi/datasets/data/complement.py +191 -40
  103. anemoi/datasets/data/concat.py +141 -16
  104. anemoi/datasets/data/dataset.py +558 -62
  105. anemoi/datasets/data/debug.py +197 -26
  106. anemoi/datasets/data/ensemble.py +93 -8
  107. anemoi/datasets/data/fill_missing.py +165 -18
  108. anemoi/datasets/data/forwards.py +428 -56
  109. anemoi/datasets/data/grids.py +323 -97
  110. anemoi/datasets/data/indexing.py +112 -19
  111. anemoi/datasets/data/interpolate.py +92 -12
  112. anemoi/datasets/data/join.py +158 -19
  113. anemoi/datasets/data/masked.py +129 -15
  114. anemoi/datasets/data/merge.py +137 -23
  115. anemoi/datasets/data/misc.py +172 -16
  116. anemoi/datasets/data/missing.py +233 -29
  117. anemoi/datasets/data/rescale.py +111 -10
  118. anemoi/datasets/data/select.py +168 -26
  119. anemoi/datasets/data/statistics.py +67 -6
  120. anemoi/datasets/data/stores.py +149 -64
  121. anemoi/datasets/data/subset.py +159 -25
  122. anemoi/datasets/data/unchecked.py +168 -57
  123. anemoi/datasets/data/xy.py +168 -25
  124. anemoi/datasets/dates/__init__.py +191 -16
  125. anemoi/datasets/dates/groups.py +189 -47
  126. anemoi/datasets/grids.py +270 -31
  127. anemoi/datasets/testing.py +28 -1
  128. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.18.dist-info}/METADATA +9 -6
  129. anemoi_datasets-0.5.18.dist-info/RECORD +137 -0
  130. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.18.dist-info}/WHEEL +1 -1
  131. anemoi/datasets/create/functions/__init__.py +0 -66
  132. anemoi/datasets/create/functions/filters/__init__.py +0 -9
  133. anemoi/datasets/create/functions/filters/empty.py +0 -17
  134. anemoi/datasets/create/functions/filters/orog_to_z.py +0 -58
  135. anemoi/datasets/create/functions/filters/rename.py +0 -79
  136. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +0 -78
  137. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +0 -56
  138. anemoi/datasets/create/functions/sources/empty.py +0 -15
  139. anemoi/datasets/create/functions/sources/grib.py +0 -150
  140. anemoi/datasets/create/functions/sources/netcdf.py +0 -15
  141. anemoi/datasets/create/functions/sources/opendap.py +0 -15
  142. anemoi/datasets/create/functions/sources/recentre.py +0 -60
  143. anemoi/datasets/create/functions/sources/xarray/coordinates.py +0 -255
  144. anemoi/datasets/create/functions/sources/xarray/flavour.py +0 -472
  145. anemoi/datasets/create/functions/sources/xarray/metadata.py +0 -148
  146. anemoi/datasets/create/functions/sources/xarray/patch.py +0 -44
  147. anemoi/datasets/create/functions/sources/xarray/time.py +0 -177
  148. anemoi/datasets/create/functions/sources/xarray/variable.py +0 -188
  149. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +0 -42
  150. anemoi/datasets/create/functions/sources/xarray_zarr.py +0 -15
  151. anemoi/datasets/utils/fields.py +0 -47
  152. anemoi_datasets-0.5.16.dist-info/RECORD +0 -129
  153. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.18.dist-info}/entry_points.txt +0 -0
  154. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.18.dist-info/licenses}/LICENSE +0 -0
  155. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.18.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/check.py

@@ -8,25 +8,48 @@
 # nor does it submit to any jurisdiction.
 
 
+import datetime
 import logging
 import re
 import warnings
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import Union
 
 import numpy as np
 from anemoi.utils.dates import frequency_to_string
+from numpy.typing import NDArray
 
 LOG = logging.getLogger(__name__)
 
 
 class DatasetName:
+    """Class to validate and parse dataset names according to naming conventions."""
+
     def __init__(
         self,
-        name,
-        resolution=None,
-        start_date=None,
-        end_date=None,
-        frequency=None,
+        name: str,
+        resolution: Optional[str] = None,
+        start_date: Optional[datetime.date] = None,
+        end_date: Optional[datetime.date] = None,
+        frequency: Optional[datetime.timedelta] = None,
     ):
+        """Initialize a DatasetName instance.
+
+        Parameters
+        ----------
+        name : str
+            The name of the dataset.
+        resolution : Optional[str], optional
+            The resolution of the dataset.
+        start_date : Optional[datetime.date], optional
+            The start date of the dataset.
+        end_date : Optional[datetime.date], optional
+            The end date of the dataset.
+        frequency : Optional[datetime.timedelta], optional
+            The frequency of the dataset.
+        """
         self.name = name
         self.parsed = self._parse(name)
         print("---------------")
@@ -45,19 +68,39 @@ class DatasetName:
         self.messages.append(f"{self} is parsed as :" + "/".join(f"{k}={v}" for k, v in self.parsed.items()))
 
     @property
-    def error_message(self):
+    def error_message(self) -> str:
+        """Generate an error message based on the collected messages."""
         out = " And ".join(self.messages)
         if out:
-            out = out[0].upper() + out[1:]
+            out[0].upper() + out[1:]
         return out
 
-    def raise_if_not_valid(self, print=print):
+    def raise_if_not_valid(self, print: Callable = print) -> None:
+        """Raise a ValueError if the dataset name is not valid.
+
+        Parameters
+        ----------
+        print : Callable
+            The function to use for printing messages.
+        """
         if self.messages:
             for m in self.messages:
                 print(m)
             raise ValueError(self.error_message)
 
-    def _parse(self, name):
+    def _parse(self, name: str) -> dict:
+        """Parse the dataset name into its components.
+
+        Parameters
+        ----------
+        name : str
+            The name of the dataset.
+
+        Returns
+        -------
+        dict
+            The parsed components of the dataset name.
+        """
         pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h|\d+m)-v(\d+)-?([a-zA-Z0-9-]+)?$"
         match = re.match(pattern, name)
 
@@ -81,10 +124,12 @@
 
         return parsed
 
-    def __str__(self):
+    def __str__(self) -> str:
+        """Return the string representation of the dataset name."""
         return self.name
 
-    def check_parsed(self):
+    def check_parsed(self) -> None:
+        """Check if the dataset name was parsed correctly."""
        if not self.parsed:
            self.messages.append(
                f"the dataset name {self} does not follow naming convention. "
@@ -92,7 +137,14 @@
                "https://anemoi-registry.readthedocs.io/en/latest/naming-conventions.html"
            )
 
-    def check_resolution(self, resolution):
+    def check_resolution(self, resolution: Optional[str]) -> None:
+        """Check if the resolution matches the expected format.
+
+        Parameters
+        ----------
+        resolution : str or None
+            The expected resolution.
+        """
         if self.parsed.get("resolution") and self.parsed["resolution"][0] not in "0123456789on":
             self.messages.append(
                 f"the resolution {self.parsed['resolution'] } should start "
@@ -105,42 +157,97 @@ class DatasetName:
         self._check_missing("resolution", resolution_str)
         self._check_mismatch("resolution", resolution_str)
 
-    def check_frequency(self, frequency):
+    def check_frequency(self, frequency: Optional[datetime.timedelta]) -> None:
+        """Check if the frequency matches the expected format.
+
+        Parameters
+        ----------
+        frequency : datetime.timedelta or None
+            The expected frequency.
+        """
         if frequency is None:
             return
         frequency_str = frequency_to_string(frequency)
         self._check_missing("frequency", frequency_str)
         self._check_mismatch("frequency", frequency_str)
 
-    def check_start_date(self, start_date):
+    def check_start_date(self, start_date: Optional[datetime.date]) -> None:
+        """Check if the start date matches the expected format.
+
+        Parameters
+        ----------
+        start_date : datetime.date or None
+            The expected start date.
+        """
         if start_date is None:
             return
         start_date_str = str(start_date.year)
         self._check_missing("start_date", start_date_str)
         self._check_mismatch("start_date", start_date_str)
 
-    def check_end_date(self, end_date):
+    def check_end_date(self, end_date: Optional[datetime.date]) -> None:
+        """Check if the end date matches the expected format.
+
+        Parameters
+        ----------
+        end_date : datetime.date or None
+            The expected end date.
+        """
         if end_date is None:
             return
         end_date_str = str(end_date.year)
         self._check_missing("end_date", end_date_str)
         self._check_mismatch("end_date", end_date_str)
 
-    def _check_missing(self, key, value):
+    def _check_missing(self, key: str, value: str) -> None:
+        """Check if a component is missing from the dataset name.
+
+        Parameters
+        ----------
+        key : str
+            The component key.
+        value : str
+            The expected value.
+        """
         if value not in self.name:
             self.messages.append(f"the {key} is {value}, but is missing in {self.name}.")
 
-    def _check_mismatch(self, key, value):
+    def _check_mismatch(self, key: str, value: str) -> None:
+        """Check if a component value mismatches the expected value.
+
+        Parameters
+        ----------
+        key : str
+            The component key.
+        value : str
+            The expected value.
+        """
         if self.parsed.get(key) and self.parsed[key] != value:
             self.messages.append(f"the {key} is {value}, but is {self.parsed[key]} in {self.name}.")
 
 
 class StatisticsValueError(ValueError):
-    pass
+    """Custom error for statistics value issues."""
 
+    pass
 
-def check_data_values(arr, *, name: str, log=[], allow_nans=False):
 
+def check_data_values(
+    arr: NDArray[Any], *, name: str, log: list = [], allow_nans: Union[bool, list, set, tuple, dict] = False
+) -> None:
+    """Check the values in the data array for validity.
+
+    Parameters
+    ----------
+    arr : NDArray[Any]
+        The data array to check.
+    name : str
+        The name of the data array.
+    log : list, optional
+        A list to log messages.
+    allow_nans : bool or list or set or tuple or dict, optional
+        Whether to allow NaNs in the data array.
+    """
     shape = arr.shape
 
     if (isinstance(allow_nans, (set, list, tuple, dict)) and name in allow_nans) or allow_nans:
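As an aside, the naming-convention regex carried over unchanged in the `_parse` hunk above can be exercised on its own. This is a minimal sketch: the dataset name below is made up for illustration (it is not taken from this diff), and the output simply lists the pattern's capture groups in order.

import re

# Pattern copied verbatim from DatasetName._parse in the hunk above.
pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h|\d+m)-v(\d+)-?([a-zA-Z0-9-]+)?$"

# Hypothetical, illustration-only dataset name.
name = "aifs-ea-an-oper-2020-2023-6h-v1"
match = re.match(pattern, name)
print(match.groups())
# ('aifs', 'ea', 'an', 'oper', '2020', '2023', '6h', '1', None)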
@@ -182,7 +289,22 @@ def check_data_values(arr, *, name: str, log=[], allow_nans=False):
         )
 
 
-def check_stats(minimum, maximum, mean, msg, **kwargs):
+def check_stats(minimum: float, maximum: float, mean: float, msg: str, **kwargs: Any) -> None:
+    """Check if the mean value is within the min/max interval.
+
+    Parameters
+    ----------
+    minimum : float
+        The minimum value.
+    maximum : float
+        The maximum value.
+    mean : float
+        The mean value.
+    msg : str
+        The message to include in the error.
+    **kwargs : Any
+        Additional keyword arguments.
+    """
     tolerance = (abs(minimum) + abs(maximum)) * 0.01
     if (mean - minimum < -tolerance) or (mean - minimum < -tolerance):
         raise StatisticsValueError(
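Note that, as transcribed in this hunk, both sides of the `or` in `check_stats` test `mean - minimum`, so only the lower bound is effectively guarded. A quick worked example of the 1% tolerance arithmetic, with made-up numbers:

# Illustration of the tolerance used by check_stats; the values are made up.
minimum, maximum = -5.0, 10.0
tolerance = (abs(minimum) + abs(maximum)) * 0.01   # (5 + 10) * 0.01 = 0.15

mean = -5.2
print(mean - minimum < -tolerance)   # True  -> check_stats would raise StatisticsValueError

mean = -4.9
print(mean - minimum < -tolerance)   # False -> the statistics pass the check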
anemoi/datasets/create/chunks.py

@@ -9,6 +9,7 @@
 
 import logging
 import warnings
+from typing import Union
 
 LOG = logging.getLogger(__name__)
 
@@ -16,7 +17,35 @@ ALL = object()
 
 
 class ChunkFilter:
-    def __init__(self, *, parts, total):
+    """A filter to determine which chunks to process based on the specified parts.
+
+    Attributes
+    ----------
+    total : int
+        The total number of chunks.
+    allowed : object or list
+        The chunks that are allowed to be processed.
+    """
+
+    def __init__(self, *, parts: Union[str, list], total: int):
+        """Initializes the ChunkFilter with the given parts and total number of chunks.
+
+        Parameters
+        ----------
+        parts : str or list
+            The parts to process, specified as 'i/n' or a list of such strings.
+        total : int
+            The total number of chunks.
+
+        Raises
+        ------
+        ValueError
+            If the parts format is invalid.
+        AssertionError
+            If the chunk number is invalid.
+        Warning
+            If the number of chunks is larger than the total number of chunks.
+        """
         self.total = total
 
         if isinstance(parts, list):
@@ -62,7 +91,24 @@ class ChunkFilter:
 
         self.allowed = parts
 
-    def __call__(self, i):
+    def __call__(self, i: int) -> bool:
+        """Checks if the given chunk number is allowed to be processed.
+
+        Parameters
+        ----------
+        i : int
+            The chunk number to check.
+
+        Returns
+        -------
+        bool
+            True if the chunk is allowed, False otherwise.
+
+        Raises
+        ------
+        AssertionError
+            If the chunk number is invalid.
+        """
         if i < 0 or i >= self.total:
             raise AssertionError(f"Invalid chunk number {i}. Must be between 0 and {self.total - 1}.")
 
@@ -70,10 +116,24 @@
             return True
         return i in self.allowed
 
-    def __iter__(self):
+    def __iter__(self) -> iter:
+        """Iterates over the allowed chunks.
+
+        Yields
+        ------
+        int
+            The next allowed chunk number.
+        """
         for i in range(self.total):
             if self(i):
                 yield i
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """Returns the number of allowed chunks.
+
+        Returns
+        -------
+        int
+            The number of allowed chunks.
+        """
         return len([_ for _ in self])
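Based on the signature and docstrings added above, a minimal usage sketch of `ChunkFilter` might look as follows. The import path is inferred from the files-changed list (anemoi/datasets/create/chunks.py), and the exact set of indices selected for a given part is not shown in this diff, so no expected output is asserted.

from anemoi.datasets.create.chunks import ChunkFilter

# Ask for part 2 of 3 of a workload that is split into 10 chunks.
chunk_filter = ChunkFilter(parts="2/3", total=10)

print(len(chunk_filter))      # how many chunks this part will process
print(list(chunk_filter))     # the chunk indices assigned to this part
print(chunk_filter(0))        # True if chunk 0 belongs to this part, else False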
anemoi/datasets/create/config.py

@@ -11,6 +11,9 @@ import datetime
 import logging
 import os
 from copy import deepcopy
+from typing import Any
+from typing import Optional
+from typing import Union
 
 import yaml
 from anemoi.utils.config import DotDict
@@ -22,13 +25,41 @@ from anemoi.datasets.dates.groups import Groups
 LOG = logging.getLogger(__name__)
 
 
-def _get_first_key_if_dict(x):
+def _get_first_key_if_dict(x: Union[str, dict]) -> str:
+    """Returns the first key if the input is a dictionary, otherwise returns the input string.
+
+    Parameters
+    ----------
+    x : str or dict
+        Input string or dictionary.
+
+    Returns
+    -------
+    str
+        The first key if input is a dictionary, otherwise the input string.
+    """
     if isinstance(x, str):
         return x
     return list(x.keys())[0]
 
 
-def ensure_element_in_list(lst, elt, index):
+def ensure_element_in_list(lst: list, elt: str, index: int) -> list:
+    """Ensures that a specified element is present at a given index in a list.
+
+    Parameters
+    ----------
+    lst : list
+        The list to check.
+    elt : str
+        The element to ensure is in the list.
+    index : int
+        The index at which the element should be present.
+
+    Returns
+    -------
+    list
+        The modified list with the element at the specified index.
+    """
     if elt in lst:
         assert lst[index] == elt
         return lst
@@ -41,7 +72,23 @@ def ensure_element_in_list(lst, elt, index):
     return lst[:index] + [elt] + lst[index:]
 
 
-def check_dict_value_and_set(dic, key, value):
+def check_dict_value_and_set(dic: dict, key: str, value: Any) -> None:
+    """Checks if a dictionary contains a specific key-value pair and sets it if not present.
+
+    Parameters
+    ----------
+    dic : dict
+        The dictionary to check.
+    key : str
+        The key to check in the dictionary.
+    value : Any
+        The value to set if the key is not present.
+
+    Raises
+    ------
+    ValueError
+        If the key is present but with a different value.
+    """
     if key in dic:
         if dic[key] == value:
             return
@@ -50,7 +97,19 @@ def check_dict_value_and_set(dic, key, value):
     dic[key] = value
 
 
-def resolve_includes(config):
+def resolve_includes(config: Union[dict, list]) -> Union[dict, list]:
+    """Resolves '<<' includes in a configuration dictionary or list.
+
+    Parameters
+    ----------
+    config : dict or list
+        The configuration to resolve includes for.
+
+    Returns
+    -------
+    dict or list
+        The configuration with includes resolved.
+    """
     if isinstance(config, list):
         return [resolve_includes(c) for c in config]
     if isinstance(config, dict):
@@ -62,7 +121,18 @@
 
 
 class Config(DotDict):
-    def __init__(self, config=None, **kwargs):
+    """Configuration class that extends DotDict to handle configuration loading and processing."""
+
+    def __init__(self, config: Optional[Union[str, dict]] = None, **kwargs):
+        """Initializes the Config object.
+
+        Parameters
+        ----------
+        config : str or dict, optional
+            Path to the configuration file or a dictionary. Defaults to None.
+        **kwargs
+            Additional keyword arguments to update the configuration.
+        """
         if isinstance(config, str):
             self.config_path = os.path.realpath(config)
             config = load_any_dict_format(config)
@@ -74,7 +144,18 @@ class Config(DotDict):
 
 
 class OutputSpecs:
-    def __init__(self, config, parent):
+    """Class to handle output specifications for datasets."""
+
+    def __init__(self, config: Config, parent: Any):
+        """Initializes the OutputSpecs object.
+
+        Parameters
+        ----------
+        config : Config
+            The configuration object.
+        parent : Any
+            The parent object.
+        """
         self.config = config
         if "order_by" in config:
             assert isinstance(config.order_by, dict), config.order_by
@@ -82,15 +163,28 @@ class OutputSpecs:
         self.parent = parent
 
     @property
-    def dtype(self):
+    def dtype(self) -> str:
+        """Returns the data type for the output."""
         return self.config.dtype
 
     @property
-    def order_by_as_list(self):
-        # this is used when an ordered dict is not supported (e.g. zarr attributes)
+    def order_by_as_list(self) -> list[dict]:
+        """Returns the order_by configuration as a list of dictionaries."""
         return [{k: v} for k, v in self.config.order_by.items()]
 
-    def get_chunking(self, coords):
+    def get_chunking(self, coords: dict) -> tuple:
+        """Returns the chunking configuration based on coordinates.
+
+        Parameters
+        ----------
+        coords : dict
+            The coordinates dictionary.
+
+        Returns
+        -------
+        tuple
+            The chunking configuration.
+        """
         user = deepcopy(self.config.chunking)
         chunks = []
         for k, v in coords.items():
@@ -105,25 +199,41 @@
         return tuple(chunks)
 
     @property
-    def order_by(self):
+    def order_by(self) -> dict:
+        """Returns the order_by configuration."""
         return self.config.order_by
 
     @property
-    def remapping(self):
+    def remapping(self) -> dict:
+        """Returns the remapping configuration."""
         return self.config.remapping
 
    @property
-    def flatten_grid(self):
+    def flatten_grid(self) -> bool:
+        """Returns whether the grid should be flattened."""
         return self.config.flatten_grid
 
     @property
-    def statistics(self):
+    def statistics(self) -> str:
+        """Returns the statistics configuration."""
         return self.config.statistics
 
 
 class LoadersConfig(Config):
-    def __init__(self, config, *args, **kwargs):
-
+    """Configuration class for dataset loaders."""
+
+    def __init__(self, config: dict, *args, **kwargs):
+        """Initializes the LoadersConfig object.
+
+        Parameters
+        ----------
+        config : dict
+            The configuration dictionary.
+        *args
+            Additional positional arguments.
+        **kwargs
+            Additional keyword arguments.
+        """
         super().__init__(config, *args, **kwargs)
 
         # TODO: should use a json schema to validate the config
@@ -178,11 +288,30 @@ class LoadersConfig(Config):
 
         self.reading_chunks = self.get("reading_chunks")
 
-    def get_serialisable_dict(self):
+    def get_serialisable_dict(self) -> dict:
+        """Returns a serializable dictionary representation of the configuration.
+
+        Returns
+        -------
+        dict
+            The serializable dictionary.
+        """
         return _prepare_serialisation(self)
 
 
-def _prepare_serialisation(o):
+def _prepare_serialisation(o: Any) -> Any:
+    """Prepares an object for serialization.
+
+    Parameters
+    ----------
+    o : Any
+        The object to prepare.
+
+    Returns
+    -------
+    Any
+        The prepared object.
+    """
     if isinstance(o, dict):
         dic = {}
         for k, v in o.items():
@@ -212,7 +341,14 @@ def _prepare_serialisation(o):
     return str(o)
 
 
-def set_to_test_mode(cfg):
+def set_to_test_mode(cfg: dict) -> None:
+    """Modifies the configuration to run in test mode.
+
+    Parameters
+    ----------
+    cfg : dict
+        The configuration dictionary.
+    """
     NUMBER_OF_DATES = 4
 
     LOG.warning(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.")
@@ -251,7 +387,21 @@ def set_to_test_mode(cfg):
     set_element_to_test(cfg)
 
 
-def loader_config(config, is_test=False):
+def loader_config(config: dict, is_test: bool = False) -> LoadersConfig:
+    """Loads and validates the configuration for dataset loaders.
+
+    Parameters
+    ----------
+    config : dict
+        The configuration dictionary.
+    is_test : bool, optional
+        Whether to run in test mode. Defaults to False.
+
+    Returns
+    -------
+    LoadersConfig
+        The validated configuration object.
+    """
     config = Config(config)
     if is_test:
         set_to_test_mode(config)
@@ -273,5 +423,19 @@ def loader_config(config, is_test=False):
     return copy
 
 
-def build_output(*args, **kwargs):
+def build_output(*args, **kwargs) -> OutputSpecs:
+    """Builds the output specifications.
+
+    Parameters
+    ----------
+    *args
+        Additional positional arguments.
+    **kwargs
+        Additional keyword arguments.
+
+    Returns
+    -------
+    OutputSpecs
+        The output specifications object.
+    """
     return OutputSpecs(*args, **kwargs)
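Most of the config.py changes are type hints and docstrings; the helper behaviour is unchanged. Two of the visible snippets restated as a small, self-contained sketch (the example inputs and `order_by` keys below are illustrative, not taken from this diff):

def _get_first_key_if_dict(x):
    # Behaviour as shown in the hunk above: strings pass through, dicts yield their first key.
    if isinstance(x, str):
        return x
    return list(x.keys())[0]

print(_get_first_key_if_dict("mars"))                     # 'mars'
print(_get_first_key_if_dict({"mars": {"class": "ea"}}))  # 'mars'

# order_by_as_list turns the ordered mapping into a list of single-key dicts,
# which (per the removed inline comment) suits stores that cannot keep ordered
# dicts, e.g. zarr attributes. Example keys are made up.
order_by = {"valid_datetime": "ascending", "param_level": "ascending"}
print([{k: v} for k, v in order_by.items()])
# [{'valid_datetime': 'ascending'}, {'param_level': 'ascending'}]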